pr suggestion
@@ -101,9 +101,6 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
-  if is_frozen():
-    exec_root = Path(sys.argv[0]).parent
-    asyncio.run(move_models_to_hf)
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

 async def move_models_to_hf():
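Note (editor's aside, not part of the diff): get_repo_root keeps producing the standard Hugging Face cache layout, so any models seeded on startup must end up under models--{org}--{repo} directories. A minimal sketch of that mapping, using a hypothetical repo ID purely for illustration:

from pathlib import Path

def repo_root_for(hf_home: Path, repo_id: str) -> Path:
  # Same sanitization as get_repo_root above: "org/model" -> "models--org--model"
  sanitized_repo_id = str(repo_id).replace("/", "--")
  return hf_home/"hub"/f"models--{sanitized_repo_id}"

# repo_root_for(Path.home()/".cache"/"huggingface", "mlx-community/Llama-3.2-1B-Instruct-4bit")
# -> ~/.cache/huggingface/hub/models--mlx-community--Llama-3.2-1B-Instruct-4bit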
@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown, move_models_to_hf
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
@@ -36,6 +36,7 @@ parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--model-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
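Note (editor's aside, not part of the diff): with this flag in place, a node would presumably be started with something like

exo --model-seed-dir /path/to/seeded-models

where /path/to/seeded-models is a hypothetical directory holding pre-downloaded model folders to be moved into the Hugging Face cache on startup.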
@@ -130,6 +131,11 @@ node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )

+if args.model_seed_dir is not None:
+  try:
+    await move_models_to_hf()
+  except Exception as e:
+    print(f"Error moving models to .cache/huggingface: {e}")

 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
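Note (editor's aside, not part of the diff): move_models_to_hf is only imported and awaited here; its body is not shown in this commit. It is called with no arguments, so the real helper presumably locates the seed directory itself (e.g. from the parsed args or an environment variable). A rough sketch of what such a helper could do, with an explicit seed_dir parameter added only to keep the example self-contained; everything below is an assumption, not the actual exo.helpers implementation:

import os
import shutil
from pathlib import Path

async def move_models_to_hf(seed_dir: str) -> None:
  # Hypothetical sketch: copy seeded model folders into the Hugging Face cache.
  # Assumes seed_dir already contains directories named like the HF cache
  # ("models--org--repo"); the real exo helper may lay things out differently.
  hf_home = Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
  hf_hub = hf_home/"hub"
  hf_hub.mkdir(parents=True, exist_ok=True)
  for src in Path(seed_dir).iterdir():
    if not src.is_dir() or not src.name.startswith("models--"):
      continue  # skip anything that does not look like a cached repo
    dest = hf_hub/src.name
    if not dest.exists():
      shutil.copytree(src, dest)  # leave existing cache entries untouched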
@@ -53,7 +53,7 @@ def run():
     "--linux-icon=docs/exo-rounded.png"
   ])
   try:
-    # subprocess.run(command, check=True)
+    subprocess.run(command, check=True)
     print("Build completed!")
   except subprocess.CalledProcessError as e:
     print(f"An error occurred: {e}")