pr suggestions fix

josh
2024-11-18 23:02:03 -08:00
parent 0ac1b87f35
commit e991438e72
5 changed files with 25 additions and 34 deletions

View File

@@ -13,7 +13,7 @@ import signal
 import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
@@ -148,11 +148,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
-
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
@@ -193,7 +188,6 @@ class ChatGPTAPI:
   async def handle_quit(self, request):
     print("Received quit signal")
-    from exo.main import shutdown
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
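
With shutdown now exported from exo.helpers (see the import hunk above), the handler no longer needs the deferred local import that dodged the circular dependency on exo.main. A minimal sketch of the resulting quit handler; how shutdown is actually invoked is not shown in this diff, so the signal value and loop retrieval below are assumptions:

import asyncio
import signal
from aiohttp import web
from exo.helpers import shutdown

async def handle_quit(self, request):
  print("Received quit signal")
  # Flush the acknowledgement to the client before tearing anything down.
  response = web.json_response({"detail": "Quit signal received"}, status=200)
  await response.prepare(request)
  await response.write_eof()
  # Assumption: call the shared helper with an explicit signal and the running loop.
  await shutdown(signal.SIGINT, asyncio.get_running_loop())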

View File

@@ -10,7 +10,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -18,11 +18,6 @@ from aiofiles import os as aios
 T = TypeVar("T")
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or ('__compiled__' in globals())
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -105,7 +100,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
-  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+  if is_frozen():
     repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
     return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
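
The last hunk drops the hardcoded Qwen2.5-0.5B-Instruct-4bit check: in a frozen build, every repo now resolves to a models-- directory beside the executable instead of the Hugging Face cache. A sketch of the two roots this logic produces, using an illustrative repo ID and assuming get_hf_home() defaults to the standard ~/.cache/huggingface location:

import sys
from pathlib import Path

repo_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"  # illustrative; any repo ID works now
sanitized = str(repo_id).replace("/", "--")

# Frozen build: model directory lives next to the binary.
frozen_root = Path(sys.argv[0]).parent / f"models--{sanitized}"
# Regular run: standard HF cache layout (assumed default location).
cache_root = Path.home() / ".cache" / "huggingface" / "hub" / f"models--{sanitized}"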

View File

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
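
One caveat worth flagging: shutdown was moved verbatim, so await server.stop() still references a server global that exo.helpers does not define, and the signal parameter shadows the stdlib module. A self-contained variant would take the server explicitly; a sketch of that (my suggestion, not part of this commit):

import asyncio

async def shutdown(sig, loop, server=None):
  """Cancel outstanding tasks, stop the server if one was passed, then stop the loop."""
  print(f"Received exit signal {sig.name}...")
  tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
  print(f"Cancelling {len(tasks)} outstanding tasks")
  for task in tasks:
    task.cancel()
  await asyncio.gather(*tasks, return_exceptions=True)
  if server is not None:
    await server.stop()
  loop.stop()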

View File

@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
@@ -163,20 +163,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
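
With the local definition deleted, main.py takes shutdown from the shared exo.helpers import above. The diff does not show how it is registered, but the usual asyncio pattern for hooking a coroutine like this to SIGINT/SIGTERM looks as follows (registration code assumed, not part of the hunk):

import asyncio
import signal

loop = asyncio.new_event_loop()
for s in (signal.SIGINT, signal.SIGTERM):
  # Late-bound default captures each signal; the handler schedules the coroutine on the loop.
  loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop)))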

View File

@@ -21,11 +21,7 @@ def run():
     "--macos-app-name=exo",
     "--macos-app-mode=gui",
     "--macos-app-version=0.0.1",
-    "--include-module=exo.inference.mlx.models.llama",
-    "--include-module=exo.inference.mlx.models.deepseek_v2",
-    "--include-module=exo.inference.mlx.models.base",
-    "--include-module=exo.inference.mlx.models.llava",
-    "--include-module=exo.inference.mlx.models.qwen2",
+    "--include-module=exo.inference.mlx.models.*",
     "--include-distribution-meta=mlx",
     "--include-module=mlx._reprlib_fix",
     "--include-module=mlx._os_warning",