mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #640 from exo-explore/simpledownload
Simple download
This commit is contained in:
@@ -27,7 +27,7 @@ commands:
|
||||
fi
|
||||
|
||||
# Start first instance
|
||||
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
--node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
|
||||
--chatgpt-api-response-timeout 900 --disable-tui > output1.log &
|
||||
PID1=$!
|
||||
@@ -35,7 +35,7 @@ commands:
|
||||
TAIL1=$!
|
||||
|
||||
# Start second instance
|
||||
HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
--node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
|
||||
--chatgpt-api-response-timeout 900 --disable-tui > output2.log &
|
||||
PID2=$!
|
||||
|
||||
1
.github/workflows/bench_job.yml
vendored
1
.github/workflows/bench_job.yml
vendored
@@ -128,6 +128,7 @@ jobs:
|
||||
--interface-type-filter="${{ inputs.network_interface }}" \
|
||||
--disable-tui \
|
||||
--max-generate-tokens 250 \
|
||||
--chatgpt-api-response-timeout 900 \
|
||||
--chatgpt-api-port 52415 > output1.log 2>&1 &
|
||||
PID1=$!
|
||||
|
||||
|
||||
@@ -212,9 +212,9 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?"
|
||||
|
||||
### Model Storage
|
||||
|
||||
Models by default are stored in `~/.cache/huggingface/hub`.
|
||||
Models by default are stored in `~/.cache/exo/downloads`.
|
||||
|
||||
You can set a different model storage location by setting the `HF_HOME` env var.
|
||||
You can set a different model storage location by setting the `EXO_HOME` env var.
|
||||
|
||||
## Debugging
|
||||
|
||||
|
||||
@@ -11,30 +11,27 @@ import aiohttp_cors
|
||||
import traceback
|
||||
import signal
|
||||
from exo import DEBUG, VERSION
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.orchestration import Node
|
||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
||||
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
|
||||
from typing import Callable, Optional
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import platform
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import delete_model
|
||||
import tempfile
|
||||
from exo.apputil import create_animation_mp4
|
||||
from collections import defaultdict
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
import tempfile
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
import shutil
|
||||
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
|
||||
from exo.apputil import create_animation_mp4
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
|
||||
@@ -277,41 +274,12 @@ class ChatGPTAPI:
|
||||
|
||||
async def handle_model_support(self, request):
|
||||
try:
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
})
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
|
||||
await response.prepare(request)
|
||||
|
||||
async def process_model(model_name, pretty):
|
||||
if model_name in model_cards:
|
||||
model_info = model_cards[model_name]
|
||||
|
||||
if self.inference_engine_classname in model_info.get("repo", {}):
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if shard:
|
||||
downloader = HFShardDownloader(quick_check=True)
|
||||
downloader.current_shard = shard
|
||||
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
status = await downloader.get_shard_download_status()
|
||||
|
||||
download_percentage = status.get("overall") if status else None
|
||||
total_size = status.get("total_size") if status else None
|
||||
total_downloaded = status.get("total_downloaded") if status else False
|
||||
|
||||
model_data = {
|
||||
model_name: {
|
||||
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
|
||||
"total_downloaded": total_downloaded
|
||||
}
|
||||
}
|
||||
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
|
||||
# Process all models in parallel
|
||||
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
|
||||
|
||||
downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
|
||||
for (path, d) in downloads:
|
||||
model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
@@ -348,6 +316,7 @@ class ChatGPTAPI:
|
||||
progress_data = {}
|
||||
for node_id, progress_event in self.node.node_download_progress.items():
|
||||
if isinstance(progress_event, RepoProgressEvent):
|
||||
if progress_event.status != "in_progress": continue
|
||||
progress_data[node_id] = progress_event.to_dict()
|
||||
else:
|
||||
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
|
||||
@@ -574,46 +543,19 @@ class ChatGPTAPI:
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
async def handle_delete_model(self, request):
|
||||
model_id = request.match_info.get('model_name')
|
||||
try:
|
||||
model_name = request.match_info.get('model_name')
|
||||
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
|
||||
|
||||
if not model_name or model_name not in model_cards:
|
||||
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
|
||||
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if not shard:
|
||||
return web.json_response({"detail": "Could not build shard for model"}, status=400)
|
||||
|
||||
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
|
||||
|
||||
# Get the HF cache directory using the helper function
|
||||
hf_home = get_hf_home()
|
||||
cache_dir = get_repo_root(repo_id)
|
||||
|
||||
if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
|
||||
|
||||
if os.path.exists(cache_dir):
|
||||
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
|
||||
try:
|
||||
shutil.rmtree(cache_dir)
|
||||
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
|
||||
except Exception as e:
|
||||
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
|
||||
else:
|
||||
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
|
||||
|
||||
if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
|
||||
else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
|
||||
except Exception as e:
|
||||
print(f"Error in handle_delete_model: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_initial_models(self, request):
|
||||
model_data = {}
|
||||
for model_name, pretty in pretty_name.items():
|
||||
model_data[model_name] = {
|
||||
"name": pretty,
|
||||
for model_id in get_supported_models([[self.inference_engine_classname]]):
|
||||
model_data[model_id] = {
|
||||
"name": get_pretty_name(model_id),
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
@@ -659,7 +601,7 @@ class ChatGPTAPI:
|
||||
model_name = data.get("model")
|
||||
if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
|
||||
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
shard = build_full_shard(model_name, self.inference_engine_classname)
|
||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Callable, Coroutine, Any, Literal
|
||||
from exo.inference.shard import Shard
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
@@ -14,11 +15,12 @@ class RepoFileProgressEvent:
|
||||
speed: int
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
|
||||
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
|
||||
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -29,6 +31,7 @@ class RepoFileProgressEvent:
|
||||
|
||||
@dataclass
|
||||
class RepoProgressEvent:
|
||||
shard: Shard
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
completed_files: int
|
||||
@@ -43,7 +46,7 @@ class RepoProgressEvent:
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
|
||||
"shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
|
||||
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
|
||||
"file_progress": {k: v.to_dict()
|
||||
for k, v in self.file_progress.items()}, "status": self.status
|
||||
@@ -53,6 +56,7 @@ class RepoProgressEvent:
|
||||
def from_dict(cls, data):
|
||||
if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
|
||||
if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
|
||||
if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
|
||||
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@@ -1,36 +1,16 @@
|
||||
import aiofiles.os as aios
|
||||
from typing import Union
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional, Dict, List, Union
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable, TypeVar, TypedDict
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
from exo.helpers import DEBUG, is_frozen
|
||||
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
|
||||
from typing import Generator, Iterable, TypeVar
|
||||
from exo.helpers import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
import aiofiles
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
|
||||
refs_dir = get_repo_root(repo_id)/"refs"
|
||||
refs_file = refs_dir/revision
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
|
||||
return snapshot_dir
|
||||
return None
|
||||
|
||||
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
@@ -48,14 +28,12 @@ def filter_repo_objects(
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
|
||||
if key is None:
|
||||
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
|
||||
key = _identity
|
||||
|
||||
for item in items:
|
||||
@@ -66,22 +44,18 @@ def filter_repo_objects(
|
||||
continue
|
||||
yield item
|
||||
|
||||
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
|
||||
|
||||
def get_hf_endpoint() -> str:
|
||||
return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
|
||||
|
||||
|
||||
def get_hf_home() -> Path:
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
|
||||
|
||||
|
||||
async def get_hf_token():
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home()/"token"
|
||||
@@ -90,7 +64,6 @@ async def get_hf_token():
|
||||
return (await f.read()).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def get_auth_headers():
|
||||
"""Get authentication headers if a token is available."""
|
||||
token = await get_hf_token()
|
||||
@@ -98,325 +71,6 @@ async def get_auth_headers():
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
return {}
|
||||
|
||||
|
||||
def get_repo_root(repo_id: str) -> Path:
|
||||
"""Get the root directory for a given repo ID in the Hugging Face cache."""
|
||||
sanitized_repo_id = str(repo_id).replace("/", "--")
|
||||
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
|
||||
|
||||
async def move_models_to_hf(seed_dir: Union[str, Path]):
|
||||
"""Move model in resources folder of app to .cache/huggingface/hub"""
|
||||
source_dir = Path(seed_dir)
|
||||
dest_dir = get_hf_home()/"hub"
|
||||
await aios.makedirs(dest_dir, exist_ok=True)
|
||||
for path in source_dir.iterdir():
|
||||
if path.is_dir() and path.name.startswith("models--"):
|
||||
dest_path = dest_dir / path.name
|
||||
if await aios.path.exists(dest_path):
|
||||
print('Skipping moving model to .cache directory')
|
||||
else:
|
||||
try:
|
||||
await aios.rename(str(path), str(dest_path))
|
||||
except Exception as e:
|
||||
print(f'Error moving model to .cache: {e}')
|
||||
|
||||
|
||||
|
||||
async def fetch_file_list(session, repo_id, revision, path=""):
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
|
||||
)
|
||||
async def download_file(
|
||||
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
|
||||
):
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
local_path = os.path.join(save_directory, file_path)
|
||||
|
||||
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
|
||||
# Check if file already exists and get its size
|
||||
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
|
||||
|
||||
headers = await get_auth_headers()
|
||||
if use_range_request:
|
||||
headers["Range"] = f"bytes={local_file_size}-"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded_size = local_file_size
|
||||
downloaded_this_session = 0
|
||||
mode = 'ab' if use_range_request else 'wb'
|
||||
percentage = await get_file_download_percentage(
|
||||
session,
|
||||
repo_id,
|
||||
revision,
|
||||
file_path,
|
||||
Path(save_directory)
|
||||
)
|
||||
|
||||
if percentage == 100:
|
||||
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
if response.status == 200:
|
||||
# File doesn't support range requests or we're not using them, start from beginning
|
||||
mode = 'wb'
|
||||
downloaded_size = 0
|
||||
elif response.status == 206:
|
||||
# Partial content, resume download
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable, get the actual file size
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
if downloaded_size == total_size:
|
||||
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
else:
|
||||
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
|
||||
|
||||
if downloaded_size == total_size:
|
||||
print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE = 32768
|
||||
start_time = datetime.now()
|
||||
async with aiofiles.open(local_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||
await f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
downloaded_this_session += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_size = total_size - downloaded_size
|
||||
eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
|
||||
status = "in_progress" if downloaded_size < total_size else "complete"
|
||||
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
||||
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
||||
|
||||
|
||||
async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
refs_dir = repo_root/"refs"
|
||||
refs_file = refs_dir/revision
|
||||
|
||||
# Check if we have a cached commit hash
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
|
||||
return commit_hash
|
||||
|
||||
# Fetch the commit hash for the given revision
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(api_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
|
||||
revision_info = await response.json()
|
||||
commit_hash = revision_info['sha']
|
||||
|
||||
# Cache the commit hash
|
||||
await aios.makedirs(refs_dir, exist_ok=True)
|
||||
async with aiofiles.open(refs_file, 'w') as f:
|
||||
await f.write(commit_hash)
|
||||
|
||||
return commit_hash
|
||||
|
||||
|
||||
async def download_repo_files(
|
||||
repo_id: str,
|
||||
revision: str = "main",
|
||||
progress_callback: Optional[RepoProgressCallback] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_parallel_downloads: int = 4
|
||||
) -> Path:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
cachedreqs_dir = repo_root/"cachedreqs"
|
||||
|
||||
# Ensure directories exist
|
||||
await aios.makedirs(snapshots_dir, exist_ok=True)
|
||||
await aios.makedirs(cachedreqs_dir, exist_ok=True)
|
||||
|
||||
# Resolve revision to commit hash
|
||||
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
|
||||
|
||||
# Set up the snapshot directory
|
||||
snapshot_dir = snapshots_dir/commit_hash
|
||||
await aios.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
# Set up the cached file list directory
|
||||
cached_file_list_dir = cachedreqs_dir/commit_hash
|
||||
await aios.makedirs(cached_file_list_dir, exist_ok=True)
|
||||
cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Check if we have a cached file list
|
||||
if await aios.path.exists(cached_file_list_path):
|
||||
async with aiofiles.open(cached_file_list_path, 'r') as f:
|
||||
file_list = json.loads(await f.read())
|
||||
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
|
||||
else:
|
||||
file_list = await fetch_file_list(session, repo_id, revision)
|
||||
# Cache the file list
|
||||
async with aiofiles.open(cached_file_list_path, 'w') as f:
|
||||
await f.write(json.dumps(file_list))
|
||||
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
||||
|
||||
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
|
||||
if model_index_exists:
|
||||
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
|
||||
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
||||
total_files = len(filtered_file_list)
|
||||
total_bytes = sum(file["size"] for file in filtered_file_list)
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {
|
||||
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
|
||||
for file in filtered_file_list
|
||||
}
|
||||
start_time = datetime.now()
|
||||
|
||||
async def download_with_progress(file_info, progress_state):
|
||||
local_path = snapshot_dir/file_info["path"]
|
||||
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
|
||||
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
|
||||
progress_state['completed_files'] += 1
|
||||
progress_state['downloaded_bytes'] += file_info["size"]
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
async def file_progress_callback(event: RepoFileProgressEvent):
|
||||
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
|
||||
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
|
||||
file_progress[event.file_path] = event
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
||||
progress_state['completed_files'] += 1
|
||||
file_progress[
|
||||
file_info["path"]
|
||||
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
|
||||
async def download_with_semaphore(file_info):
|
||||
async with semaphore:
|
||||
await download_with_progress(file_info, progress_state)
|
||||
|
||||
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return snapshot_dir
|
||||
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Retrieve the weight map from the model.safetensors.index.json file.
|
||||
|
||||
Args:
|
||||
repo_id (str): The Hugging Face repository ID.
|
||||
revision (str): The revision of the repository to use.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
|
||||
"""
|
||||
|
||||
# Download the index file
|
||||
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
|
||||
|
||||
# Check if the file exists
|
||||
repo_root = get_repo_root(repo_id)
|
||||
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
|
||||
snapshot_dir = repo_root/"snapshots"/commit_hash
|
||||
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
|
||||
|
||||
if index_file:
|
||||
index_file_path = snapshot_dir/index_file
|
||||
if await aios.path.exists(index_file_path):
|
||||
async with aiofiles.open(index_file_path, 'r') as f:
|
||||
index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
# This is a simple example and might need to be adjusted based on the actual naming convention
|
||||
parts = tensor_name.split('.')
|
||||
@@ -425,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
return int(part)
|
||||
return None
|
||||
|
||||
|
||||
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
|
||||
shard_specific_patterns = set()
|
||||
@@ -443,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
async def get_file_download_percentage(
|
||||
session: aiohttp.ClientSession,
|
||||
repo_id: str,
|
||||
revision: str,
|
||||
file_path: str,
|
||||
snapshot_dir: Path,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the download percentage for a file by comparing local and remote sizes.
|
||||
"""
|
||||
try:
|
||||
local_path = snapshot_dir / file_path
|
||||
if not await aios.path.exists(local_path):
|
||||
return 0
|
||||
|
||||
# Get local file size first
|
||||
local_size = await aios.path.getsize(local_path)
|
||||
if local_size == 0:
|
||||
return 0
|
||||
|
||||
# Check remote size
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
headers = await get_auth_headers()
|
||||
|
||||
# Use HEAD request with redirect following for all files
|
||||
async with session.head(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to get remote file info for {file_path}: {response.status}")
|
||||
return 0
|
||||
|
||||
remote_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
if remote_size == 0:
|
||||
if DEBUG >= 2:
|
||||
print(f"Remote size is 0 for {file_path}")
|
||||
return 0
|
||||
|
||||
# Only return 100% if sizes match exactly
|
||||
if local_size == remote_size:
|
||||
return 100.0
|
||||
|
||||
# Calculate percentage based on sizes
|
||||
return (local_size / remote_size) * 100 if remote_size > 0 else 0
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Error checking file download status for {file_path}: {e}")
|
||||
return 0
|
||||
|
||||
async def has_hf_home_read_access() -> bool:
|
||||
hf_home = get_hf_home()
|
||||
try: return await aios.access(hf_home, os.R_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def has_hf_home_write_access() -> bool:
|
||||
hf_home = get_hf_home()
|
||||
try: return await aios.access(hf_home, os.W_OK)
|
||||
except OSError: return False
|
||||
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional, Union
|
||||
from exo.inference.shard import Shard
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.hf.hf_helpers import (
|
||||
download_repo_files, RepoProgressEvent, get_weight_map,
|
||||
get_allow_patterns, get_repo_root, fetch_file_list,
|
||||
get_local_snapshot_dir, get_file_download_percentage,
|
||||
filter_repo_objects
|
||||
)
|
||||
from exo.helpers import AsyncCallbackSystem, DEBUG
|
||||
from exo.models import model_cards, get_repo
|
||||
import aiohttp
|
||||
from aiofiles import os as aios
|
||||
|
||||
|
||||
class HFShardDownloader(ShardDownloader):
|
||||
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
|
||||
self.quick_check = quick_check
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
self.completed_downloads: Dict[Shard, Path] = {}
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
self.current_shard: Optional[Shard] = None
|
||||
self.current_repo_id: Optional[str] = None
|
||||
self.revision: str = "main"
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
self.current_shard = shard
|
||||
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
|
||||
repo_name = get_repo(shard.model_id, inference_engine_name)
|
||||
if shard in self.completed_downloads:
|
||||
return self.completed_downloads[shard]
|
||||
if self.quick_check:
|
||||
repo_root = get_repo_root(repo_name)
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
if snapshots_dir.exists():
|
||||
visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
|
||||
if visible_dirs:
|
||||
most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
|
||||
return most_recent_dir
|
||||
|
||||
# If a download on this shard is already in progress, keep that one
|
||||
for active_shard in self.active_downloads:
|
||||
if active_shard == shard:
|
||||
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
|
||||
return await self.active_downloads[shard]
|
||||
|
||||
# Cancel any downloads for this model_id on a different shard
|
||||
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
|
||||
for active_shard in existing_active_shards:
|
||||
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
|
||||
task = self.active_downloads[active_shard]
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass # This is expected when cancelling a task
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
|
||||
traceback.print_exc()
|
||||
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
|
||||
|
||||
# Start new download
|
||||
download_task = asyncio.create_task(self._download_shard(shard, repo_name))
|
||||
self.active_downloads[shard] = download_task
|
||||
try:
|
||||
path = await download_task
|
||||
self.completed_downloads[shard] = path
|
||||
return path
|
||||
finally:
|
||||
# Ensure the task is removed even if an exception occurs
|
||||
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
|
||||
if shard in self.active_downloads:
|
||||
self.active_downloads.pop(shard)
|
||||
|
||||
async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
|
||||
async def wrapped_progress_callback(event: RepoProgressEvent):
|
||||
self._on_progress.trigger_all(shard, event)
|
||||
|
||||
weight_map = await get_weight_map(repo_name)
|
||||
allow_patterns = get_allow_patterns(weight_map, shard)
|
||||
|
||||
return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self._on_progress
|
||||
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
|
||||
if not self.current_shard or not self.current_repo_id:
|
||||
if DEBUG >= 2:
|
||||
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# If no snapshot directory exists, return None - no need to check remote files
|
||||
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
|
||||
if not snapshot_dir:
|
||||
if DEBUG >= 2:
|
||||
print(f"No snapshot directory found for {self.current_repo_id}")
|
||||
return None
|
||||
|
||||
if not await aios.path.exists(snapshot_dir/"model_index.json"):
|
||||
# Get the weight map to know what files we need
|
||||
weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
||||
if not weight_map:
|
||||
if DEBUG >= 2:
|
||||
print(f"No weight map found for {self.current_repo_id}")
|
||||
return None
|
||||
|
||||
# Get all files needed for this shard
|
||||
patterns = get_allow_patterns(weight_map, self.current_shard)
|
||||
else:
|
||||
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
|
||||
|
||||
|
||||
# Check download status for all relevant files
|
||||
status = {}
|
||||
total_bytes = 0
|
||||
downloaded_bytes = 0
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
||||
relevant_files = list(
|
||||
filter_repo_objects(
|
||||
file_list, allow_patterns=patterns, key=lambda x: x["path"]))
|
||||
|
||||
for file in relevant_files:
|
||||
file_size = file["size"]
|
||||
total_bytes += file_size
|
||||
|
||||
percentage = await get_file_download_percentage(
|
||||
session,
|
||||
self.current_repo_id,
|
||||
self.revision,
|
||||
file["path"],
|
||||
snapshot_dir,
|
||||
)
|
||||
status[file["path"]] = percentage
|
||||
downloaded_bytes += (file_size * (percentage / 100))
|
||||
|
||||
# Add overall progress weighted by file size
|
||||
if total_bytes > 0:
|
||||
status["overall"] = (downloaded_bytes / total_bytes) * 100
|
||||
else:
|
||||
status["overall"] = 0
|
||||
|
||||
# Add total size in bytes
|
||||
status["total_size"] = total_bytes
|
||||
if status["overall"] != 100:
|
||||
status["total_downloaded"] = downloaded_bytes
|
||||
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(f"Download calculation for {self.current_repo_id}:")
|
||||
print(f"Total bytes: {total_bytes}")
|
||||
print(f"Downloaded bytes: {downloaded_bytes}")
|
||||
if DEBUG >= 3:
|
||||
for file in relevant_files:
|
||||
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 3:
|
||||
print(f"Error getting shard download status: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
226
exo/download/new_shard_download.py
Normal file
226
exo/download/new_shard_download.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from exo.inference.shard import Shard
|
||||
from exo.models import get_repo
|
||||
from pathlib import Path
|
||||
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
|
||||
from exo.helpers import AsyncCallbackSystem, DEBUG
|
||||
from exo.models import get_supported_models, build_full_shard
|
||||
import os
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Union, Tuple, Dict, List
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
def exo_home() -> Path:
|
||||
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
|
||||
|
||||
def exo_tmp() -> Path:
|
||||
return Path(tempfile.gettempdir())/"exo"
|
||||
|
||||
async def ensure_exo_tmp() -> Path:
|
||||
await aios.makedirs(exo_tmp(), exist_ok=True)
|
||||
return exo_tmp()
|
||||
|
||||
async def has_exo_home_read_access() -> bool:
|
||||
try: return await aios.access(exo_home(), os.R_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def has_exo_home_write_access() -> bool:
|
||||
try: return await aios.access(exo_home(), os.W_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def ensure_downloads_dir() -> Path:
|
||||
downloads_dir = exo_home()/"downloads"
|
||||
await aios.makedirs(downloads_dir, exist_ok=True)
|
||||
return downloads_dir
|
||||
|
||||
async def delete_model(model_id: str, inference_engine_name: str) -> bool:
|
||||
repo_id = get_repo(model_id, inference_engine_name)
|
||||
model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|
||||
if not await aios.path.exists(model_dir): return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
return True
|
||||
|
||||
async def seed_models(seed_dir: Union[str, Path]):
|
||||
"""Move model in resources folder of app to .cache/huggingface/hub"""
|
||||
source_dir = Path(seed_dir)
|
||||
dest_dir = await ensure_downloads_dir()
|
||||
for path in source_dir.iterdir():
|
||||
if path.is_dir() and path.name.startswith("models--"):
|
||||
dest_path = dest_dir/path.name
|
||||
if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
|
||||
else:
|
||||
try: await aios.rename(str(path), str(dest_path))
|
||||
except:
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def fetch_file_list(session, repo_id, revision, path=""):
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
if (target_dir/path).exists(): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, path)
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
||||
assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
|
||||
length = int(r.headers.get('content-length', 0))
|
||||
n_read = 0
|
||||
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
|
||||
await aios.rename(temp_file.name, target_dir/path)
|
||||
return target_dir/path
|
||||
|
||||
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
|
||||
all_total_bytes = sum([p.total for p in file_progress.values()])
|
||||
all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
|
||||
all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()])
|
||||
elapsed_time = time.time() - all_start_time
|
||||
all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
||||
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
|
||||
status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
|
||||
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
|
||||
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session:
|
||||
index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
|
||||
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
|
||||
async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
|
||||
try:
|
||||
weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
except:
|
||||
if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
|
||||
if DEBUG >= 1: traceback.print_exc()
|
||||
return ["*"]
|
||||
|
||||
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
||||
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
||||
revision = "main"
|
||||
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|
||||
if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
|
||||
|
||||
allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
file_list = await fetch_file_list(session, repo_id, revision)
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {}
|
||||
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
||||
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
|
||||
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
||||
speed = downloaded_this_session / (time.time() - start_time)
|
||||
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
|
||||
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
||||
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
async def download_with_semaphore(file):
|
||||
async with semaphore:
|
||||
await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
|
||||
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
|
||||
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
|
||||
on_progress.trigger_all(shard, final_repo_progress)
|
||||
return target_dir, final_repo_progress
|
||||
|
||||
def new_shard_downloader() -> ShardDownloader:
|
||||
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
|
||||
|
||||
class SingletonShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self.shard_downloader.on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
|
||||
try: return await self.active_downloads[shard]
|
||||
finally:
|
||||
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
|
||||
|
||||
class CachedShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.cache: Dict[tuple[str, Shard], Path] = {}
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self.shard_downloader.on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
if (inference_engine_name, shard) in self.cache:
|
||||
if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
|
||||
return self.cache[(inference_engine_name, shard)]
|
||||
if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
|
||||
target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
|
||||
self.cache[(inference_engine_name, shard)] = target_dir
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
|
||||
|
||||
class NewShardDownloader(ShardDownloader):
|
||||
def __init__(self):
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self._on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
|
||||
downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
|
||||
if DEBUG >= 6: print("Downloaded shards:", downloads)
|
||||
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
|
||||
return [d for d in downloads if not isinstance(d, Exception)]
|
||||
|
||||
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
"""Get the download status of shards.
|
||||
|
||||
Returns:
|
||||
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return AsyncCallbackSystem()
|
||||
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
|
||||
return None
|
||||
|
||||
15
exo/download/test_new_shard_download.py
Normal file
15
exo/download/test_new_shard_download.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from exo.download.new_shard_download import download_shard, NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
async def test_new_shard_download():
|
||||
shard_downloader = NewShardDownloader()
|
||||
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
|
||||
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
|
||||
download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
|
||||
print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_new_shard_download())
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
import numpy as np
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.models import build_base_shard
|
||||
from collections import deque
|
||||
@@ -10,7 +10,7 @@ from statistics import mean, median
|
||||
|
||||
async def test_non_blocking():
|
||||
# Setup
|
||||
shard_downloader = HFShardDownloader()
|
||||
shard_downloader = NewShardDownloader()
|
||||
engine = MLXDynamicShardInferenceEngine(shard_downloader)
|
||||
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
|
||||
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.helpers import DEBUG
|
||||
import os
|
||||
@@ -44,13 +44,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
||||
assert np.array_equal(next_resp_full, resp4)
|
||||
|
||||
|
||||
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
|
||||
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))
|
||||
|
||||
if os.getenv("RUN_TINYGRAD", default="0") == "1":
|
||||
import tinygrad
|
||||
import os
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
||||
asyncio.run(
|
||||
test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
|
||||
)
|
||||
asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import traceback
|
||||
from aiofiles import os as aios
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from aiofiles import os as aios
|
||||
from typing import Union
|
||||
from transformers import AutoTokenizer, AutoProcessor
|
||||
import numpy as np
|
||||
from exo.download.hf.hf_helpers import get_local_snapshot_dir
|
||||
from exo.helpers import DEBUG
|
||||
from exo.download.new_shard_download import ensure_downloads_dir
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
@@ -24,25 +23,25 @@ class DummyTokenizer:
|
||||
return "dummy" * len(tokens)
|
||||
|
||||
|
||||
async def resolve_tokenizer(model_id: str):
|
||||
if model_id == "dummy":
|
||||
async def resolve_tokenizer(repo_id: Union[str, PathLike]):
|
||||
if repo_id == "dummy":
|
||||
return DummyTokenizer()
|
||||
local_path = await get_local_snapshot_dir(model_id)
|
||||
local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
|
||||
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
|
||||
try:
|
||||
if local_path and await aios.path.exists(local_path):
|
||||
if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
|
||||
if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
|
||||
return await _resolve_tokenizer(local_path)
|
||||
except:
|
||||
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
|
||||
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
|
||||
if DEBUG >= 5: traceback.print_exc()
|
||||
return await _resolve_tokenizer(model_id)
|
||||
return await _resolve_tokenizer(repo_id)
|
||||
|
||||
|
||||
async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
|
||||
async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
|
||||
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
|
||||
if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
|
||||
processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
|
||||
if not hasattr(processor, 'eos_token_id'):
|
||||
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
|
||||
if not hasattr(processor, 'encode'):
|
||||
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
|
||||
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
|
||||
return processor
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
|
||||
if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
|
||||
return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
|
||||
if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
|
||||
return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
|
||||
if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
|
||||
raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")
|
||||
|
||||
25
exo/main.py
25
exo/main.py
@@ -23,19 +23,18 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
from exo.api import ChatGPTAPI
|
||||
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.download.shard_download import ShardDownloader, NoopShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
|
||||
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.models import build_base_shard, get_repo
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
|
||||
import uvloop
|
||||
from contextlib import asynccontextmanager
|
||||
import concurrent.futures
|
||||
import socket
|
||||
import resource
|
||||
import psutil
|
||||
|
||||
@@ -117,8 +116,7 @@ print_yellow_exo()
|
||||
system_info = get_system_info()
|
||||
print(f"Detected system: {system_info}")
|
||||
|
||||
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
|
||||
max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
||||
print(f"Inference engine name after selection: {inference_engine_name}")
|
||||
|
||||
@@ -175,10 +173,10 @@ node = Node(
|
||||
None,
|
||||
inference_engine,
|
||||
discovery,
|
||||
shard_downloader,
|
||||
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
|
||||
max_generate_tokens=args.max_generate_tokens,
|
||||
topology_viz=topology_viz,
|
||||
shard_downloader=shard_downloader,
|
||||
default_sample_temperature=args.default_temp
|
||||
)
|
||||
server = GRPCServer(node, args.node_host, args.node_port)
|
||||
@@ -223,6 +221,7 @@ last_broadcast_time = 0
|
||||
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
||||
global last_broadcast_time
|
||||
current_time = time.time()
|
||||
if event.status == "not_started": return
|
||||
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
|
||||
last_broadcast_time = current_time
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
@@ -322,13 +321,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Check HuggingFace directory permissions
|
||||
hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
|
||||
if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
|
||||
# Check exo directory permissions
|
||||
home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
|
||||
if DEBUG >= 1: print(f"exo home directory: {home}")
|
||||
print(f"{has_read=}, {has_write=}")
|
||||
if not has_read or not has_write:
|
||||
print(f"""
|
||||
WARNING: Limited permissions for model storage directory: {hf_home}.
|
||||
WARNING: Limited permissions for exo home directory: {home}.
|
||||
This may prevent model downloads from working correctly.
|
||||
{"❌ No read access" if not has_read else ""}
|
||||
{"❌ No write access" if not has_write else ""}
|
||||
@@ -337,9 +336,9 @@ async def main():
|
||||
if not args.models_seed_dir is None:
|
||||
try:
|
||||
models_seed_dir = clean_path(args.models_seed_dir)
|
||||
await move_models_to_hf(models_seed_dir)
|
||||
await seed_models(models_seed_dir)
|
||||
except Exception as e:
|
||||
print(f"Error moving models to .cache/huggingface: {e}")
|
||||
print(f"Error seeding models: {e}")
|
||||
|
||||
def restore_cursor():
|
||||
if platform.system() != "Windows":
|
||||
|
||||
@@ -175,8 +175,11 @@ pretty_name = {
|
||||
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
||||
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
||||
"deepseek-v3": "Deepseek V3",
|
||||
"deepseek-v3-3bit": "Deepseek V3 (3-bit)",
|
||||
"deepseek-r1": "Deepseek R1",
|
||||
"deepseek-r1-3bit": "Deepseek R1 (3-bit)",
|
||||
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
||||
"qwen-2.5-0.5b": "Qwen 2.5 0.5B",
|
||||
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
|
||||
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
||||
"qwen-2.5-3b": "Qwen 2.5 3B",
|
||||
@@ -232,6 +235,9 @@ pretty_name = {
|
||||
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
||||
return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
|
||||
|
||||
def get_pretty_name(model_id: str) -> Optional[str]:
|
||||
return pretty_name.get(model_id, None)
|
||||
|
||||
def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
repo = get_repo(model_id, inference_engine_classname)
|
||||
n_layers = model_cards.get(model_id, {}).get("layers", 0)
|
||||
@@ -239,7 +245,12 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
|
||||
return None
|
||||
return Shard(model_id, 0, 0, n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
|
||||
def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
base_shard = build_base_shard(model_id, inference_engine_classname)
|
||||
if base_shard is None: return None
|
||||
return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
|
||||
if not supported_inference_engine_lists:
|
||||
return list(model_cards.keys())
|
||||
|
||||
|
||||
@@ -13,9 +13,9 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
|
||||
from exo import DEBUG
|
||||
from exo.helpers import AsyncCallbackSystem
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
|
||||
class Node:
|
||||
def __init__(
|
||||
@@ -24,16 +24,17 @@ class Node:
|
||||
server: Server,
|
||||
inference_engine: InferenceEngine,
|
||||
discovery: Discovery,
|
||||
shard_downloader: ShardDownloader,
|
||||
partitioning_strategy: PartitioningStrategy = None,
|
||||
max_generate_tokens: int = 1024,
|
||||
default_sample_temperature: float = 0.0,
|
||||
topology_viz: Optional[TopologyViz] = None,
|
||||
shard_downloader: Optional[HFShardDownloader] = None,
|
||||
):
|
||||
self.id = _id
|
||||
self.inference_engine = inference_engine
|
||||
self.server = server
|
||||
self.discovery = discovery
|
||||
self.shard_downloader = shard_downloader
|
||||
self.partitioning_strategy = partitioning_strategy
|
||||
self.peers: List[PeerHandle] = {}
|
||||
self.topology: Topology = Topology()
|
||||
@@ -52,7 +53,6 @@ class Node:
|
||||
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
|
||||
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
|
||||
self.topology_inference_engines_pool: List[List[str]] = []
|
||||
self.shard_downloader = shard_downloader
|
||||
self.outstanding_requests = {}
|
||||
|
||||
async def start(self, wait_for_peers: int = 0) -> None:
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from .node import Node
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
from exo.download.shard_download import NoopShardDownloader
|
||||
|
||||
class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
@@ -22,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
mock_peer2.id.return_value = "peer2"
|
||||
self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
|
||||
|
||||
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
|
||||
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await self.node.start()
|
||||
|
||||
@@ -843,4 +843,4 @@ main {
|
||||
font-size: 0.8em;
|
||||
color: var(--secondary-color-transparent);
|
||||
font-family: monospace;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => {
|
||||
while (true) {
|
||||
try {
|
||||
await this.populateSelector();
|
||||
// Wait 5 seconds before next poll
|
||||
await new Promise(resolve => setTimeout(resolve, 5000));
|
||||
// Wait 15 seconds before next poll
|
||||
await new Promise(resolve => setTimeout(resolve, 15000));
|
||||
} catch (error) {
|
||||
console.error('Model polling error:', error);
|
||||
// If there's an error, wait before retrying
|
||||
await new Promise(resolve => setTimeout(resolve, 5000));
|
||||
await new Promise(resolve => setTimeout(resolve, 15000));
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -637,6 +637,9 @@ document.addEventListener("alpine:init", () => {
|
||||
const vizElement = this.$refs.topologyViz;
|
||||
vizElement.innerHTML = ''; // Clear existing visualization
|
||||
|
||||
// Helper function to truncate node ID
|
||||
const truncateNodeId = (id) => id.substring(0, 8);
|
||||
|
||||
// Create nodes from object
|
||||
Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
|
||||
const nodeElement = document.createElement('div');
|
||||
@@ -647,14 +650,14 @@ document.addEventListener("alpine:init", () => {
|
||||
const peerConnectionsHtml = peerConnections.map(peer => `
|
||||
<div class="peer-connection">
|
||||
<i class="fas fa-arrow-right"></i>
|
||||
<span>To ${peer.to_id}: ${peer.description}</span>
|
||||
<span>To ${truncateNodeId(peer.to_id)}: ${peer.description}</span>
|
||||
</div>
|
||||
`).join('');
|
||||
|
||||
nodeElement.innerHTML = `
|
||||
<div class="node-info">
|
||||
<span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
|
||||
<span>${node.model}</span>
|
||||
<span>${node.model} [${truncateNodeId(nodeId)}]</span>
|
||||
</div>
|
||||
<div class="node-details">
|
||||
<span>${node.chip}</span>
|
||||
|
||||
@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||
from exo.topology.partitioning_strategy import Partition
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
|
||||
|
||||
def create_hf_repo_progress_event(
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
||||
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.partitioning_strategy import Partition
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
|
||||
from rich.console import Console, Group
|
||||
from rich.text import Text
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
|
||||
|
||||
DEFAULT_ALLOW_PATTERNS = [
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.safetensors",
|
||||
]
|
||||
# Always ignore `.git` and `.cache/huggingface` folders in commits
|
||||
DEFAULT_IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".git/*",
|
||||
"*/.git",
|
||||
"**/.git/**",
|
||||
".cache/huggingface",
|
||||
".cache/huggingface/*",
|
||||
"*/.cache/huggingface",
|
||||
"**/.cache/huggingface/**",
|
||||
]
|
||||
|
||||
|
||||
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
|
||||
async def progress_callback(event: RepoProgressEvent):
|
||||
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
|
||||
print(f"Estimated time remaining: {event.overall_eta}")
|
||||
print("File Progress:")
|
||||
for file_path, progress in event.file_progress.items():
|
||||
status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
|
||||
eta_str = str(progress.eta)
|
||||
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
|
||||
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
|
||||
print("\n")
|
||||
|
||||
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
|
||||
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
|
||||
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
|
||||
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
|
||||
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
|
||||
@@ -1,26 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the project root to the Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import asyncio
|
||||
from exo.download.hf.hf_helpers import get_weight_map
|
||||
|
||||
async def test_get_weight_map():
|
||||
repo_ids = [
|
||||
"mlx-community/quantized-gemma-2b",
|
||||
"mlx-community/Meta-Llama-3.1-8B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-70B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-405B-4bit",
|
||||
]
|
||||
for repo_id in repo_ids:
|
||||
weight_map = await get_weight_map(repo_id)
|
||||
assert weight_map is not None, "Weight map should not be None"
|
||||
assert isinstance(weight_map, dict), "Weight map should be a dictionary"
|
||||
assert len(weight_map) > 0, "Weight map should not be empty"
|
||||
print(f"OK: {repo_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_get_weight_map())
|
||||
Reference in New Issue
Block a user