beautiful download

Alex Cheema
2025-02-01 17:29:19 +00:00
parent 7034ee0fcb
commit 2c0d17c336
4 changed files with 48 additions and 29 deletions

View File: exo/api/chatgpt_api.py

@@ -276,9 +276,8 @@ class ChatGPTAPI:
     try:
       response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
-      for (path, d) in downloads:
-        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+      async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
+        model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response
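
Note: the hunk above replaces a one-shot call, which returned the full status list only after every repo had been scanned, with an async generator, so each model's progress event is flushed to the SSE stream as soon as it is computed. A minimal sketch of the same pattern, with a hypothetical status_events() generator and "/status" route standing in for the shard downloader:

# Minimal sketch: streaming an async generator over server-sent events with aiohttp.
# `status_events` and the "/status" route are hypothetical stand-ins, not exo's API.
import json
import asyncio
from aiohttp import web

async def status_events():
  for i in range(3):
    await asyncio.sleep(0.1)  # simulate per-repo scan latency
    yield {"model": f"model-{i}", "download_percentage": 100.0 * (i + 1) / 3}

async def handle_status(request: web.Request) -> web.StreamResponse:
  response = web.StreamResponse(status=200, headers={'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache'})
  await response.prepare(request)
  async for event in status_events():
    await response.write(f"data: {json.dumps(event)}\n\n".encode())  # one SSE frame per yielded event
  await response.write(b"data: [DONE]\n\n")
  return response

app = web.Application()
app.router.add_get("/status", handle_status)
# web.run_app(app)  # uncomment to serve on :8080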

View File: exo/download/new_shard_download.py

@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
 import time
 from datetime import timedelta
 import asyncio
@@ -20,7 +20,7 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed
+from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type

 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -70,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
       print(f"Error seeding model {path} to {dest_path}")
       traceback.print_exc()

+async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
+  cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
+  if await aios.path.exists(cache_file):
+    async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
+  file_list = await fetch_file_list(repo_id, revision)
+  await aios.makedirs(cache_file.parent, exist_ok=True)
+  async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
+  return file_list
+
 @retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id, revision, path=""):
+async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
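
The new fetch_file_list_with_cache is a read-through disk cache: the first call hits the HF API, persists the JSON listing under the exo tmp dir, and later calls are served from disk. One caveat: nothing invalidates the cache, so for a mutable revision like "main" a stale listing persists until the tmp dir is cleared. The same pattern as a reusable decorator, a sketch built around a hypothetical json_disk_cache helper (not part of this commit):

# Sketch: a generic read-through JSON disk cache. cache_path_fn maps call
# arguments to a cache file path; everything here is illustrative only.
import json, functools
from pathlib import Path
import aiofiles
import aiofiles.os as aios

def json_disk_cache(cache_path_fn):
  def wrap(fn):
    @functools.wraps(fn)
    async def inner(*args, **kwargs):
      cache_file = cache_path_fn(*args, **kwargs)
      if await aios.path.exists(cache_file):  # cache hit: read from disk
        async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
      result = await fn(*args, **kwargs)      # cache miss: fetch, then persist
      await aios.makedirs(cache_file.parent, exist_ok=True)
      async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(result))
      return result
    return inner
  return wrap

# usage: key the cache by repo and revision, as the commit does
@json_disk_cache(lambda repo_id, revision="main": Path("/tmp/exo")/f"{repo_id.replace('/', '--')}--{revision}--file_list.json")
async def fetch_file_list(repo_id: str, revision: str = "main") -> list:
  return [{"path": "model.safetensors", "size": 123}]  # stand-in for the real API call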
@@ -106,14 +115,14 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   headers = await get_auth_headers()
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
     async with session.head(url, headers=headers) as r:
-      content_length = int(r.headers.get('content-length', 0))
+      content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
       etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
       assert content_length > 0, f"No content length for {url}"
       assert etag is not None, f"No remote hash for {url}"
       if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
       return content_length, etag

-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
@@ -124,9 +133,10 @@ async def download_file(repo_id: str, revision: str, path: str, target_dir: Path
   url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
   headers = await get_auth_headers()
   if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
-  n_read = 0
+  n_read = resume_byte_pos or 0
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
     async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+      if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
       assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
       async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
         while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
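
Two fixes land in this hunk. n_read now starts at resume_byte_pos, so progress callbacks on a resumed download report absolute bytes instead of restarting from zero. And an HTTP 404 is converted into FileNotFoundError, which the retry_if_not_exception_type(FileNotFoundError) added to the decorator excludes from retries: a genuinely missing file fails immediately instead of burning thirty one-second retry attempts. A minimal sketch of that retry split, with a hypothetical fetch coroutine:

# Sketch: retry transient failures, but fail fast on FileNotFoundError
# (hypothetical fetch; mirrors the decorator change above).
import aiohttp
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type

@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
async def fetch(url: str) -> bytes:
  async with aiohttp.ClientSession() as session:
    async with session.get(url) as r:
      if r.status == 404: raise FileNotFoundError(f"File not found: {url}")  # not retried
      r.raise_for_status()  # network and 5xx errors propagate and are retried
      return await r.read()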
@@ -152,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
   return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)

 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
-  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
+  target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
   index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
@@ -166,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
     if DEBUG >= 1: traceback.print_exc()
     return ["*"]

+async def get_downloaded_size(path: Path) -> int:
+  partial_path = path.with_suffix(path.suffix + ".partial")
+  if await aios.path.exists(path): return (await aios.stat(path)).st_size
+  if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
+  return 0
+
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
@@ -180,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

   all_start_time = time.time()
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_cache(repo_id, revision)
   filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
   file_progress: Dict[str, RepoFileProgressEvent] = {}
   def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
@@ -192,7 +208,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
     on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
     if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
   for file in filtered_file_list:
-    downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+    downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
   semaphore = asyncio.Semaphore(max_parallel_downloads)
@@ -225,8 +241,9 @@ class SingletonShardDownloader(ShardDownloader):
     finally:
       if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]

-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status

 class CachedShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
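
The wrapper methods in this commit change from returning an awaited list to re-yielding from the inner generator. The async-for-then-yield shape is not incidental: `yield from` is a syntax error inside `async def`, so explicit re-yielding is the only way for an async generator to delegate. A self-contained sketch:

# Sketch: delegating to an inner async generator. `yield from` cannot be
# used in `async def`, so re-yielding with `async for` is the idiom.
import asyncio
from typing import AsyncIterator

async def inner() -> AsyncIterator[int]:
  for i in range(3):
    yield i

async def outer() -> AsyncIterator[int]:
  async for item in inner():  # delegate by re-yielding each item
    yield item

async def main():
  async for item in outer():
    print(item)  # prints 0, 1, 2

asyncio.run(main())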
@@ -246,8 +263,9 @@ class CachedShardDownloader(ShardDownloader):
       self.cache[(inference_engine_name, shard)] = target_dir
       return target_dir

-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status

 class NewShardDownloader(ShardDownloader):
   def __init__(self):
@@ -261,9 +279,12 @@ class NewShardDownloader(ShardDownloader):
     target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
     return target_dir

-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
-    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
-    if DEBUG >= 6: print("Downloaded shards:", downloads)
-    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
-    return [d for d in downloads if not isinstance(d, Exception)]
+    tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
+    for task in asyncio.as_completed(tasks):
+      try:
+        path, progress = await task
+        yield (path, progress)
+      except Exception as e:
+        print("Error downloading shard:", e)
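
This is the core of the change: instead of asyncio.gather(..., return_exceptions=True) building the complete list before anything is returned, asyncio.as_completed lets the method yield each shard's status the moment its probe finishes, and the per-task try/except means one broken repo is logged and skipped rather than delaying or discarding the others. The pattern in isolation, with a hypothetical check coroutine:

# Sketch of the as_completed pattern above: yield results in completion
# order, isolating failures per task (check() is a hypothetical probe).
import asyncio
from typing import AsyncIterator

async def check(model_id: str) -> str:
  if model_id == "bad": raise RuntimeError("missing repo")
  await asyncio.sleep(0.1)
  return f"{model_id}: ok"

async def statuses(model_ids: list[str]) -> AsyncIterator[str]:
  tasks = [check(m) for m in model_ids]
  for task in asyncio.as_completed(tasks):
    try:
      yield await task  # fastest results stream out first
    except Exception as e:
      print("Error checking model:", e)  # one failure does not stop the rest

async def main():
  async for line in statuses(["a", "bad", "b"]):
    print(line)

asyncio.run(main())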

View File: exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Dict, AsyncIterator
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass

   @abstractmethod
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.

     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()

-  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
-    return None
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    if False: yield
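
The `if False: yield` in NoopShardDownloader looks odd but is the standard idiom for an empty async generator: the mere presence of `yield` in the body makes the function an async generator, so callers can `async for` over it and get zero items. The old `return None` made it a plain coroutine, and iterating that with `async for` would raise a TypeError. A sketch:

# Sketch: an unreachable `yield` still marks the function as an async
# generator; iterating it produces no items instead of raising.
import asyncio
from typing import AsyncIterator

async def empty() -> AsyncIterator[int]:
  if False: yield  # never runs, but makes this an async generator

async def main():
  async for _ in empty():
    print("never reached")
  print("done")  # the loop body is simply skipped

asyncio.run(main())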

View File: exo/download/test_new_shard_download.py

@@ -1,14 +1,13 @@
-from exo.download.new_shard_download import download_shard, NewShardDownloader
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
-from pathlib import Path
 import asyncio

 async def test_new_shard_download():
   shard_downloader = NewShardDownloader()
   shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
   await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
-  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
-  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+  async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
+    print("Shard download status:", path, shard_status)

 if __name__ == "__main__":
   asyncio.run(test_new_shard_download())