From e991438e728a4cfa7ada7731bdb9de636544501d Mon Sep 17 00:00:00 2001
From: josh
Date: Mon, 18 Nov 2024 23:02:03 -0800
Subject: [PATCH] Apply PR review suggestions

---
 exo/api/chatgpt_api.py        |  8 +-------
 exo/download/hf/hf_helpers.py |  9 ++-------
 exo/helpers.py                | 20 ++++++++++++++++++++
 exo/main.py                   | 16 +---------------
 scripts/build_exo.py          |  6 +-----
 5 files changed, 25 insertions(+), 34 deletions(-)

diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
index fa6e6cee..5484645c 100644
--- a/exo/api/chatgpt_api.py
+++ b/exo/api/chatgpt_api.py
@@ -13,7 +13,7 @@ import signal
 import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
@@ -148,11 +148,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
-
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
@@ -193,7 +188,6 @@ class ChatGPTAPI:
 
   async def handle_quit(self, request):
     print("Received quit signal")
-    from exo.main import shutdown
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py
index a1af9789..3081ecf1 100644
--- a/exo/download/hf/hf_helpers.py
+++ b/exo/download/hf/hf_helpers.py
@@ -10,7 +10,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -18,11 +18,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or ('__compiled__' in globals())
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -105,7 +100,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
-  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+  if is_frozen():
     repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
     return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
diff --git a/exo/helpers.py b/exo/helpers.py
index b657b277..4defb98b 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop, server):
+  """Gracefully shut down the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
\ No newline at end of file
diff --git a/exo/main.py b/exo/main.py
index 3d970aa1..35cc4f9c 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
@@ -163,20 +163,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
diff --git a/scripts/build_exo.py b/scripts/build_exo.py
index a2513b03..236f1ca8 100644
--- a/scripts/build_exo.py
+++ b/scripts/build_exo.py
@@ -21,11 +21,7 @@ def run():
     "--macos-app-name=exo",
     "--macos-app-mode=gui",
     "--macos-app-version=0.0.1",
-    "--include-module=exo.inference.mlx.models.llama",
-    "--include-module=exo.inference.mlx.models.deepseek_v2",
-    "--include-module=exo.inference.mlx.models.base",
-    "--include-module=exo.inference.mlx.models.llava",
-    "--include-module=exo.inference.mlx.models.qwen2",
+    "--include-module=exo.inference.mlx.models.*",
     "--include-distribution-meta=mlx",
     "--include-module=mlx._reprlib_fix",
     "--include-module=mlx._os_warning",
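
Note on the relocated shutdown() helper: it now lives in exo/helpers.py and takes
the server it stops as an explicit parameter, since the module-level `server`
global it previously closed over only exists in exo/main.py. The call sites are
outside these hunks, so the wiring below is a minimal sketch of how main.py can
register the helper against POSIX signals, not exo's actual code:

    import asyncio
    import signal

    from exo.helpers import shutdown

    def install_signal_handlers(loop: asyncio.AbstractEventLoop, server) -> None:
      # Route SIGINT and SIGTERM through the shared shutdown coroutine; the
      # s=s default binds the current signal value into each lambda.
      for s in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop, server)))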
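
For reference, the get_repo_root() change makes every model resolve next to the
frozen executable, not just Qwen2.5-0.5B-Instruct-4bit. A standalone sketch of
the patched branch, with a hypothetical macOS bundle path for illustration:

    import sys
    from pathlib import Path

    def frozen_repo_root(repo_id: str) -> Path:
      # Mirrors the is_frozen() branch of get_repo_root(): "/" in the repo ID
      # becomes "--" and the snapshot directory sits beside the binary.
      sanitized = str(repo_id).replace("/", "--")
      return Path(sys.argv[0]).parent/f"models--{sanitized}"

    # With sys.argv[0] == "/Applications/exo.app/Contents/MacOS/exo":
    #   frozen_repo_root("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
    #   -> .../Contents/MacOS/models--mlx-community--Qwen2.5-0.5B-Instruct-4bit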
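
The build script change assumes Nuitka expands the trailing wildcard in
--include-module=exo.inference.mlx.models.*, which also means future model
modules are picked up without editing the script. If the installed Nuitka
version does not accept patterns there, the documented --include-package flag
covers the same set of submodules:

    "--include-package=exo.inference.mlx.models",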