Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)
.gitignore (vendored): 2 changes

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
+.exo_node_id
 .idea
 .DS_Store

 # Byte-compiled / optimized / DLL files
 __pycache__/

@@ -15,7 +16,6 @@ __pycache__/

 # Distribution / packaging
 /.Python
 /build/
 /develop-eggs/
 /dist/
 /downloads/

docs/exo-rounded.png: new binary file (28 KiB, binary content not shown)
@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
@@ -8,15 +8,16 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 import os
 import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
 from typing import Callable, Optional


 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role

@@ -26,6 +27,7 @@ class Message:
     return {"role": self.role, "content": self.content}


 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model

@@ -143,7 +145,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt


 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node

@@ -172,13 +173,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})

-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")

     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)

+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
+
   async def timeout_middleware(self, app, handler):
     async def middleware(request):

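Note: the new POST /quit route gives an external controller (for example the packaged desktop build) a way to stop a node over HTTP; the response is prepared and flushed before shutdown is awaited, so the caller still sees the 200 even though the server is about to exit. A minimal client sketch, assuming a node listening on a placeholder localhost:8000 (the host and port are not taken from this diff):

import aiohttp
import asyncio

async def quit_node(base_url: str = "http://localhost:8000"):
  # POST to the /quit endpoint added above; the node replies with
  # {"detail": "Quit signal received"} and then shuts itself down.
  async with aiohttp.ClientSession() as session:
    async with session.post(f"{base_url}/quit") as resp:
      print(resp.status, await resp.json())

asyncio.run(quit_node())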
@@ -1,7 +1,11 @@
 import aiofiles.os as aios
 from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
 import sys
 import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta

@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles

@@ -17,7 +21,6 @@ from aiofiles import os as aios

 T = TypeVar("T")


 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision

@@ -99,9 +102,22 @@ async def get_auth_headers():

 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

+
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))


 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
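Note: as committed, move_models_to_hf has two Python-level problems. pathlib.Path.iterdir() returns a plain synchronous generator, so "async for" over it raises TypeError, and Path objects have no startswith method (the check needs path.name). A corrected sketch under those assumptions, reusing the module's own imports; this is not the committed code:

async def move_models_to_hf(seed_dir):
  # Move pre-seeded "models--*" snapshot directories from the app's resources
  # folder into the Hugging Face cache.
  source_dir = Path(seed_dir)
  dest_dir = get_hf_home()/"hub"
  await aios.makedirs(dest_dir, exist_ok=True)
  for path in source_dir.iterdir():  # plain for: Path.iterdir() is synchronous
    if path.is_dir() and path.name.startswith("models--"):  # compare the directory name
      dest_path = dest_dir/path.name
      if dest_path.exists():
        if DEBUG >= 1: print(f"skipping moving {dest_path}. File already exists")
      else:
        await aios.rename(str(path), str(dest_path))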
@@ -409,7 +425,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   elif shard.is_last_layer():
     shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
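Note: the list wrapper is the whole fix. set("*.safetensors") iterates the string and produces a set of individual characters, while set(["*.safetensors"]) is the intended one-element set holding the glob pattern:

# Why the list wrapper matters:
assert set("*.safetensors") == {'*', '.', 'a', 'e', 'f', 'n', 'o', 'r', 's', 't'}  # characters, not a pattern
assert set(["*.safetensors"]) == {"*.safetensors"}                                 # the intended single glob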
@@ -1,4 +1,5 @@
 import os
 import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket

@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
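Note: this shutdown coroutine is the same one removed from exo/main.py further down, so it was relocated rather than rewritten; its body still awaits server.stop() on a name that is not defined in this module, an apparent carry-over from its old home where a module-level server existed. A minimal sketch of how the relocated helper gets wired up, mirroring the signal-handling hunk in exo/main.py below:

import asyncio, platform, signal
from exo.helpers import shutdown

loop = asyncio.new_event_loop()
# Windows event loops do not implement add_signal_handler, hence the platform guard.
if platform.system() != "Windows":
  for s in [signal.SIGINT, signal.SIGTERM]:
    loop.add_signal_handler(s, lambda s=s: asyncio.ensure_future(shutdown(s, loop)))
# ... loop.run_forever() elsewhere; the handler fires while the loop is running.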
@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper

 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard


@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
 from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader

@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)

   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)

   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
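Note: the common thread in these edits is that the final device-to-host conversion (.numpy(), and np.array(...) for token lists) now runs inside the callable handed to run_in_executor instead of on the event loop after the await. Realizing a tinygrad tensor and copying it to NumPy is blocking work, so keeping it in the executor thread keeps the asyncio loop responsive. A generic sketch of the pattern, not tied to exo's classes:

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def blocking_step(data):
  # stand-in for realize()/.numpy(): CPU/GPU-bound work that only returns when done
  return sum(x * x for x in data)

async def run_async(data):
  # everything blocking, including the result conversion, stays inside the executor
  return await asyncio.get_running_loop().run_in_executor(executor, blocking_step, data)

print(asyncio.run(run_async(range(1000))))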
exo/main.py: 34 changes

@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
 import platform
 import os
 import sys
 import time
 import traceback
 import uuid

@@ -17,14 +20,14 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")

@@ -34,6 +37,7 @@ parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")

@@ -129,7 +133,6 @@ node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )


 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)

@@ -162,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):

 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)


-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)

@@ -219,13 +208,20 @@ async def main():
   {"❌ No read access" if not has_read else ""}
   {"❌ No write access" if not has_write else ""}
   """)

+  if not args.models_seed_dir is None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)

   await node.start(wait_for_peers=args.wait_for_peers)

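Note: the platform guard is needed because loop.add_signal_handler is only implemented on Unix event loops; on Windows it raises NotImplementedError, which would break startup of a packaged Windows build. A hedged sketch of a fallback the diff does not include, using the plain signal module on Windows:

import asyncio, platform, signal

def install_exit_handlers(loop, handle_exit):
  if platform.system() != "Windows":
    for s in [signal.SIGINT, signal.SIGTERM]:
      loop.add_signal_handler(s, handle_exit)
  else:
    # Fallback sketch: signal.signal accepts SIGINT/SIGTERM on Windows, but the
    # handler runs outside the event loop, so schedule work thread-safely.
    for s in [signal.SIGINT, signal.SIGTERM]:
      signal.signal(s, lambda signum, frame: loop.call_soon_threadsafe(handle_exit))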
scripts/build_exo.py: new file, 60 lines

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os
+import pkgutil
+
+def run():
+  site_packages = site.getsitepackages()[0]
+  command = [
+    f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+    "--company-name=exolabs",
+    "--product-name=exo",
+    "--output-dir=dist",
+    "--follow-imports",
+    "--standalone",
+    "--output-filename=exo",
+    "--onefile",
+    "--python-flag=no_site"
+  ]
+
+  if sys.platform == "darwin":
+    command.extend([
+      "--macos-app-name=exo",
+      "--macos-app-mode=gui",
+      "--macos-app-version=0.0.1",
+      "--macos-signed-app-name=com.exolabs.exo",
+      "--macos-sign-identity=auto",
+      "--macos-sign-notarization",
+      "--include-distribution-meta=mlx",
+      "--include-module=mlx._reprlib_fix",
+      "--include-module=mlx._os_warning",
+      f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+      f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+      "--include-distribution-meta=pygments",
+      "--nofollow-import-to=tinygrad"
+    ])
+    inference_modules = [
+      name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+    ]
+    for module in inference_modules:
+      command.append(f"--include-module=exo.inference.mlx.models.{module}")
+  elif sys.platform == "win32":
+    command.extend([
+      "--windows-icon-from-ico=docs/exo-logo-win.ico",
+      "--file-version=0.0.1",
+      "--product-version=0.0.1"
+    ])
+  elif sys.platform.startswith("linux"):
+    command.extend([
+      "--include-distribution-metadata=pygments",
+      "--linux-icon=docs/exo-rounded.png"
+    ])
+
+  try:
+    subprocess.run(command, check=True)
+    print("Build completed!")
+  except subprocess.CalledProcessError as e:
+    print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+  run()
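Note: two details worth calling out. Nuitka's import follower only sees static imports, and exo selects its MLX model implementation dynamically by name, so the script enumerates every module under exo/inference/mlx/models with pkgutil.iter_modules and adds an explicit --include-module flag for each (this is my reading of the intent, not stated in the diff). Also, the macOS branch spells the option --include-distribution-meta while the Linux branch uses --include-distribution-metadata; one of the two spellings is presumably unintended. The relative paths (exo/main.py, docs/..., exo/inference/mlx/models) suggest the script is meant to be run from the repository root. A small sketch of what the enumeration produces, assuming a checkout with hypothetical model files such as llama.py:

import pkgutil

# Each discovered module becomes an explicit --include-module flag so Nuitka
# bundles it even though exo only imports it dynamically at runtime.
for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models']):
  print(f"--include-module=exo.inference.mlx.models.{name}")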
setup.py: 8 changes

@@ -8,11 +8,12 @@ install_requires = [
   "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",

@@ -21,10 +22,9 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]