Merge pull request #1 from josh1593/package-exo-app

Package exo app
josh authored 2024-11-19 05:51:45 -08:00 · committed by GitHub
11 changed files with 144 additions and 45 deletions

.gitignore (2 changes)

@@ -4,6 +4,7 @@ test_weights.npz
.exo_used_ports
.exo_node_id
.idea
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -15,7 +16,6 @@ __pycache__/
# Distribution / packaging
/.Python
/build/
/develop-eggs/
/dist/
/downloads/

docs/exo-rounded.png (new binary file, 28 KiB)

exo/__init__.py

@@ -1 +1 @@
from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

exo/api/chatgpt_api.py

@@ -8,15 +8,16 @@ from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
import traceback
import os
import sys
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict
from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from typing import Callable, Optional
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
self.role = role
@@ -26,6 +27,7 @@ class Message:
return {"role": self.role, "content": self.content}
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
self.model = model
@@ -143,7 +145,6 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
self.node = node
@@ -172,13 +173,22 @@ class ChatGPTAPI:
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
await shutdown(signal.SIGINT, asyncio.get_event_loop())
async def timeout_middleware(self, app, handler):
async def middleware(request):
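
Note on the new /quit route above: it lets an outside process ask a running exo instance to shut itself down; the handler replies with 200 {"detail": "Quit signal received"} and then triggers the graceful shutdown path. A minimal sketch of exercising it from Python, assuming the ChatGPT-compatible API is reachable on localhost (the port below is only an illustration, use whatever the API was started on):

import asyncio
import aiohttp

async def request_quit(base_url: str = "http://localhost:52415") -> None:
  # Hypothetical client: POST /quit and print the acknowledgement before the node exits.
  async with aiohttp.ClientSession() as session:
    async with session.post(f"{base_url}/quit") as resp:
      print(resp.status, await resp.json())

if __name__ == "__main__":
  asyncio.run(request_quit())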

exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
import aiofiles.os as aios
from typing import Union
import asyncio
import aiohttp
import json
import os
import sys
import shutil
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, TypeVar, TypedDict
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG
from exo.helpers import DEBUG, is_frozen
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from exo.inference.shard import Shard
import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios
T = TypeVar("T")
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
refs_dir = get_repo_root(repo_id)/"refs"
refs_file = refs_dir/revision
@@ -99,9 +102,22 @@ async def get_auth_headers():
def get_repo_root(repo_id: str) -> Path:
"""Get the root directory for a given repo ID in the Hugging Face cache."""
sanitized_repo_id = repo_id.replace("/", "--")
sanitized_repo_id = str(repo_id).replace("/", "--")
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
async def move_models_to_hf(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = get_hf_home()/"hub"
await aios.makedirs(dest_dir, exist_ok=True)
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if dest_path.exists():
if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
else:
await aios.rename(str(path), str(dest_path))
async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -409,7 +425,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
elif shard.is_last_layer():
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = set("*.safetensors")
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
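
Note on the new move_models_to_hf helper above: it seeds the Hugging Face cache from model snapshots bundled with the packaged app, moving any "models--org--name" directory under the seed folder into ~/.cache/huggingface/hub. A minimal sketch of invoking it (the seed path is an assumption for illustration):

import asyncio
from pathlib import Path
from exo.download.hf.hf_helpers import move_models_to_hf

# Hypothetical location where a packaged app ships its bundled model snapshots.
seed_dir = Path.home()/"exo"/"seed-models"
asyncio.run(move_models_to_hf(seed_dir))  # moves models--* dirs into the HF hub cache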

exo/helpers.py

@@ -1,4 +1,5 @@
import os
import sys
import asyncio
from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return ["localhost"]
async def shutdown(signal, loop, server=None):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
print("Thank you for using exo.")
print_yellow_exo()
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in server_tasks: task.cancel()
print(f"Cancelling {len(server_tasks)} outstanding tasks")
await asyncio.gather(*server_tasks, return_exceptions=True)
if server is not None: await server.stop()  # only callers that own a gRPC server pass one in
loop.stop()
def is_frozen():
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
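
The is_frozen helper above reports whether exo is running as a packaged (Nuitka-compiled) binary rather than from a source checkout. A sketch of the kind of gating it enables; the asset layout below is an assumption, not exo's actual structure:

import sys
from pathlib import Path
from exo.helpers import is_frozen

def asset_dir() -> Path:
  # Packaged build: look next to the executable; source checkout: next to this file.
  if is_frozen():
    return Path(sys.executable).parent/"assets"
  return Path(__file__).parent/"assets"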

exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from exo import DEBUG
from exo.inference.tokenizers import resolve_tokenizer
from ..shard import Shard
@@ -183,7 +184,7 @@ async def load_shard(
processor.encode = processor.tokenizer.encode
return model, processor
else:
tokenizer = load_tokenizer(model_path, tokenizer_config)
tokenizer = await resolve_tokenizer(model_path)
return model, tokenizer

exo/inference/tinygrad/inference.py

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
from tinygrad.nn.state import load_state_dict
from tinygrad import Tensor, nn, Context
from exo.inference.inference_engine import InferenceEngine
from typing import Optional, Tuple
import numpy as np
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
logits = x[:, -1, :]
def sample_wrapper():
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
return out.numpy().astype(int)
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
return np.array(tokens)
return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
return tokens
return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
await self.ensure_shard(shard)
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
return output_data.numpy()
return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
async def ensure_shard(self, shard: Shard):
if self.shard == shard:

exo/main.py

@@ -3,6 +3,9 @@ import asyncio
import signal
import json
import logging
import platform
import os
import sys
import time
import traceback
import uuid
@@ -17,14 +20,14 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -34,6 +37,7 @@ parser.add_argument("--default-model", type=str, default=None, help="Default mod
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -129,7 +133,6 @@ node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
def preemptively_start_download(request_id: str, opaque_status: str):
try:
status = json.loads(opaque_status)
@@ -162,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
print("Thank you for using exo.")
print_yellow_exo()
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
[task.cancel() for task in server_tasks]
print(f"Cancelling {len(server_tasks)} outstanding tasks")
await asyncio.gather(*server_tasks, return_exceptions=True)
await server.stop()
loop.stop()
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
inference_class = inference_engine.__class__.__name__
shard = build_base_shard(model_name, inference_class)
@@ -219,13 +208,20 @@ async def main():
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
""")
if args.models_seed_dir is not None:
try:
await move_models_to_hf(args.models_seed_dir)
except Exception as e:
print(f"Error moving models to .cache/huggingface: {e}")
# Use a more direct approach to handle signals
def handle_exit():
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
if platform.system() != "Windows":
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
await node.start(wait_for_peers=args.wait_for_peers)
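
One note on the platform guard added above: asyncio event loops on Windows do not implement add_signal_handler, so registering SIGINT/SIGTERM handlers there raises NotImplementedError. A minimal restatement of the pattern (the helper name is illustrative, not exo's code):

import asyncio
import platform
import signal

def install_exit_handlers(loop: asyncio.AbstractEventLoop, handle_exit) -> None:
  # Only POSIX loops support add_signal_handler; on Windows exit is handled elsewhere.
  if platform.system() == "Windows":
    return
  for s in (signal.SIGINT, signal.SIGTERM):
    loop.add_signal_handler(s, handle_exit)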

scripts/build_exo.py (new file, 60 lines)

@@ -0,0 +1,60 @@
import site
import subprocess
import sys
import os
import pkgutil
def run():
site_packages = site.getsitepackages()[0]
command = [
f"{sys.executable}", "-m", "nuitka", "exo/main.py",
"--company-name=exolabs",
"--product-name=exo",
"--output-dir=dist",
"--follow-imports",
"--standalone",
"--output-filename=exo",
"--onefile",
"--python-flag=no_site"
]
if sys.platform == "darwin":
command.extend([
"--macos-app-name=exo",
"--macos-app-mode=gui",
"--macos-app-version=0.0.1",
"--macos-signed-app-name=com.exolabs.exo",
"--macos-sign-identity=auto",
"--macos-sign-notarization",
"--include-distribution-meta=mlx",
"--include-module=mlx._reprlib_fix",
"--include-module=mlx._os_warning",
f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
"--include-distribution-meta=pygments",
"--nofollow-import-to=tinygrad"
])
inference_modules = [
name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
]
for module in inference_modules:
command.append(f"--include-module=exo.inference.mlx.models.{module}")
elif sys.platform == "win32":
command.extend([
"--windows-icon-from-ico=docs/exo-logo-win.ico",
"--file-version=0.0.1",
"--product-version=0.0.1"
])
elif sys.platform.startswith("linux"):
command.extend([
"--include-distribution-metadata=pygments",
"--linux-icon=docs/exo-rounded.png"
])
try:
subprocess.run(command, check=True)
print("Build completed!")
except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
run()
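
The packaging script above is meant to be run from the repository root with the project's dependencies (including nuitka) installed. A minimal sketch of invoking it and where the output lands, assuming the defaults shown above:

import subprocess
import sys

# Runs the Nuitka build defined in scripts/build_exo.py with the current interpreter.
subprocess.run([sys.executable, "scripts/build_exo.py"], check=True)
# On success the self-contained binary is written to dist/exo
# (per --output-dir=dist and --output-filename=exo).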

setup.py

@@ -8,11 +8,12 @@ install_requires = [
"aiohttp==3.10.11",
"aiohttp_cors==0.7.0",
"aiofiles==24.1.0",
"grpcio==1.64.1",
"grpcio-tools==1.64.1",
"grpcio==1.68.0",
"grpcio-tools==1.68.0",
"Jinja2==3.1.4",
"netifaces==0.11.0",
"numpy==2.0.0",
"nuitka==2.4.10",
"nvidia-ml-py==12.560.30",
"pillow==10.4.0",
"prometheus-client==0.20.0",
@@ -21,10 +22,9 @@ install_requires = [
"pydantic==2.9.2",
"requests==2.32.3",
"rich==13.7.1",
"safetensors==0.4.3",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.43.3",
"transformers==4.46.3",
"uuid==1.30",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
]