From 8ce0fe2bb391031b0085b46a5e9efb49b9bfe6f8 Mon Sep 17 00:00:00 2001 From: josh Date: Tue, 19 Nov 2024 00:59:33 -0800 Subject: [PATCH] pr suggestion --- exo/download/hf/hf_helpers.py | 3 --- exo/main.py | 7 +++++++ scripts/build_exo.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index cacd5b1b..a07a060e 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -101,9 +101,6 @@ async def get_auth_headers(): def get_repo_root(repo_id: str) -> Path: """Get the root directory for a given repo ID in the Hugging Face cache.""" sanitized_repo_id = str(repo_id).replace("/", "--") - if is_frozen(): - exec_root = Path(sys.argv[0]).parent - asyncio.run(move_models_to_hf) return get_hf_home()/"hub"/f"models--{sanitized_repo_id}" async def move_models_to_hf(): diff --git a/exo/main.py b/exo/main.py index 35cc4f9c..d63643eb 100644 --- a/exo/main.py +++ b/exo/main.py @@ -20,7 +20,8 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe from exo.api import ChatGPTAPI from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader from exo.download.hf.hf_shard_download import HFShardDownloader from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown +from exo.download.hf.hf_helpers import move_models_to_hf from exo.inference.shard import Shard from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.inference.dummy_inference_engine import DummyInferenceEngine @@ -36,6 +36,7 @@ parser.add_argument("model_name", nargs="?", help="Model name to run") parser.add_argument("--node-id", type=str, default=None, help="Node ID") parser.add_argument("--node-host", type=str, 
default="0.0.0.0", help="Node host") parser.add_argument("--node-port", type=int, default=None, help="Node port") +parser.add_argument("--model-seed-dir", type=str, default=None, help="Model seed directory") parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery") parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download") parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download") @@ -130,6 +131,11 @@ node.on_token.register("update_topology_viz").on_next( lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None ) +if args.model_seed_dir is not None: + try: + asyncio.run(move_models_to_hf()) + except Exception as e: + print(f"Error moving models to .cache/huggingface: {e}") def preemptively_start_download(request_id: str, opaque_status: str): try: diff --git a/scripts/build_exo.py b/scripts/build_exo.py index 87402521..d857d76a 100644 --- a/scripts/build_exo.py +++ b/scripts/build_exo.py @@ -53,7 +53,7 @@ def run(): "--linux-icon=docs/exo-rounded.png" ]) try: - subprocess.run(command, check=True) + subprocess.run(command, check=True) print("Build completed!") except subprocess.CalledProcessError as e: print(f"An error occurred: {e}")