Merge pull request #664 from exo-explore/resumedownload

resumable downloads with integrity checks
Alex Cheema
2025-02-01 18:34:36 +00:00
committed by GitHub
4 changed files with 91 additions and 46 deletions

View File

@@ -276,9 +276,8 @@ class ChatGPTAPI:
     try:
       response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
-      for (path, d) in downloads:
-        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+      async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
+        model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response
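Note: each iteration now writes one SSE "data:" event per model as its status arrives, followed by a "[DONE]" sentinel. A minimal sketch of a client consuming this stream — the endpoint path and port are assumptions for illustration, not taken from this diff:

  import aiohttp, asyncio, json

  async def watch_progress(url: str = "http://localhost:52415/v1/download/progress"):  # hypothetical URL
    async with aiohttp.ClientSession() as session:
      async with session.get(url) as r:
        async for raw in r.content:  # aiohttp's StreamReader iterates the body line by line
          line = raw.decode().strip()
          if not line.startswith("data: "): continue
          payload = line[len("data: "):]
          if payload == "[DONE]": break  # sentinel written after the async-for loop above
          print(json.loads(payload))  # {model_id: {"downloaded": ..., "download_percentage": ...}}

  asyncio.run(watch_progress())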

View File

@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
 import time
 from datetime import timedelta
 import asyncio
@@ -19,7 +19,8 @@ import json
 import traceback
 import shutil
 import tempfile
-from tenacity import retry, stop_after_attempt, wait_exponential
+import hashlib
+from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
       print(f"Error seeding model {path} to {dest_path}")
       traceback.print_exc()
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
-async def fetch_file_list(repo_id, revision, path=""):
+async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
+  cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
+  if await aios.path.exists(cache_file):
+    async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
+  file_list = await fetch_file_list(repo_id, revision)
+  await aios.makedirs(cache_file.parent, exist_ok=True)
+  async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
+  return file_list
+
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
+async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -90,29 +100,55 @@ async def fetch_file_list(repo_id, revision, path=""):
   else:
     raise Exception(f"Failed to fetch file list: {response.status}")
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
+  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
+  if type == "sha1":
+    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
+    hash.update(header)
+  async with aiofiles.open(path, 'rb') as f:
+    while chunk := await f.read(1024 * 1024):
+      hash.update(chunk)
+  return hash.hexdigest()
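Note: the "blob <size>\0" header is what makes the sha1 branch a git blob hash rather than a plain SHA-1 of the contents; for files not stored in LFS, the Hugging Face ETag is the git blob id, so the local hash must be computed the way git hash-object computes it. A minimal synchronous sketch of the same construction:

  import hashlib

  def git_blob_sha1(data: bytes) -> str:
    h = hashlib.sha1()
    h.update(f"blob {len(data)}\0".encode())  # same header calc_hash prepends
    h.update(data)
    return h.hexdigest()

  # matches: printf 'hello\n' | git hash-object --stdin
  assert git_blob_sha1(b"hello\n") == "ce013625030ba8dba906f756967f9e9ca394464a"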
+async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
+  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+  headers = await get_auth_headers()
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+    async with session.head(url, headers=headers) as r:
+      content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
+      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
+      assert content_length > 0, f"No content length for {url}"
+      assert etag is not None, f"No remote hash for {url}"
+      if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
+      return content_length, etag
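Note: for LFS-backed files, the /resolve/ URL answers with a redirect whose X-Linked-ETag and x-linked-size headers describe the real object (a 64-hex sha256); small non-LFS files report a git-sha1 ETag directly. aiohttp's session.head() does not follow redirects by default, so the linked headers are read off that first response. A standalone probe — the URL is a public repo chosen for illustration:

  import aiohttp, asyncio

  async def probe(url: str):
    async with aiohttp.ClientSession() as session:
      async with session.head(url) as r:  # allow_redirects=False is head()'s default
        print(r.status, r.headers.get('X-Linked-ETag') or r.headers.get('ETag'),
              r.headers.get('x-linked-size') or r.headers.get('content-length'))

  asyncio.run(probe("https://huggingface.co/gpt2/resolve/main/model.safetensors"))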
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
-  temp_file_name = None
-  try:
-    if (target_dir/path).exists(): return target_dir/path
-    await aios.makedirs((target_dir/path).parent, exist_ok=True)
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, path)
+  if await aios.path.exists(target_dir/path): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  length, remote_hash = await file_meta(repo_id, revision, path)
+  partial_path = target_dir/f"{path}.partial"
+  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
+  if resume_byte_pos != length:
+    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
     headers = await get_auth_headers()
+    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
+    n_read = resume_byte_pos or 0
     async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
       async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-        assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-        length = int(r.headers.get('content-length', 0))
-        n_read = 0
-        async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-          temp_file_name = temp_file.name
-          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-          await aios.rename(temp_file.name, target_dir/path)
-        return target_dir/path
-  finally:
-    if temp_file_name:  # attempt to delete tmp file if it still exists
-      try: await aios.unlink(temp_file_name)
-      except: pass
+        if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
+        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
+        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
+          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
+  integrity = final_hash == remote_hash
+  if not integrity:
+    try: await aios.remove(partial_path)
+    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
+    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
+  await aios.rename(partial_path, target_dir/path)
+  return target_dir/path
 
 def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
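Note: resumption hinges on the Range request header: a server honoring it answers 206 Partial Content with only the missing tail, while one ignoring it answers 200 with the full body. Both statuses are accepted above; if a 200 re-sends bytes that were already appended, the final hash comparison fails and the .partial file is deleted, so a corrupted resume cannot be persisted. A quick probe of the behavior (URL and offset chosen for illustration):

  import aiohttp, asyncio

  async def probe_resume(url: str, offset: int = 100):
    async with aiohttp.ClientSession() as session:
      async with session.get(url, headers={'Range': f'bytes={offset}-'}) as r:
        print(r.status)  # 206 if the server honors Range requests
        print(r.headers.get('Content-Range'))  # e.g. "bytes 100-664/665"

  asyncio.run(probe_resume("https://huggingface.co/gpt2/resolve/main/config.json"))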
@@ -126,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
   return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
-  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
+  target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
   index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
@@ -140,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str):
     if DEBUG >= 1: traceback.print_exc()
     return ["*"]
 
+async def get_downloaded_size(path: Path) -> int:
+  partial_path = path.with_suffix(path.suffix + ".partial")
+  if await aios.path.exists(path): return (await aios.stat(path)).st_size
+  if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
+  return 0
+
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
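Note: counting .partial bytes means progress reported after a restart reflects what is already on disk. The two partial-path spellings agree: download_file appends ".partial" to the whole relative path string, while get_downloaded_size uses with_suffix, which replaces only the final suffix — for typical weight files both yield the same name:

  from pathlib import Path

  p = Path("model-00001-of-00002.safetensors")
  print(p.with_suffix(p.suffix + ".partial"))  # model-00001-of-00002.safetensors.partial
  print(Path(f"{p}.partial"))                  # model-00001-of-00002.safetensors.partial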
@@ -154,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
 
   all_start_time = time.time()
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_cache(repo_id, revision)
   filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
   file_progress: Dict[str, RepoFileProgressEvent] = {}
   def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
@@ -166,7 +208,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
     on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
     if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
   for file in filtered_file_list:
-    downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+    downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
   semaphore = asyncio.Semaphore(max_parallel_downloads)
@@ -199,8 +241,9 @@ class SingletonShardDownloader(ShardDownloader):
     finally:
       if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class CachedShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -220,8 +263,9 @@ class CachedShardDownloader(ShardDownloader):
       self.cache[(inference_engine_name, shard)] = target_dir
       return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class NewShardDownloader(ShardDownloader):
   def __init__(self):
@@ -235,9 +279,12 @@ class NewShardDownloader(ShardDownloader):
     target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
     return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
-    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
-    if DEBUG >= 6: print("Downloaded shards:", downloads)
-    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
-    return [d for d in downloads if not isinstance(d, Exception)]
+    tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
+    for task in asyncio.as_completed(tasks):
+      try:
+        path, progress = await task
+        yield (path, progress)
+      except Exception as e:
+        print("Error downloading shard:", e)

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Dict, AsyncIterator
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass
 
   @abstractmethod
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.
 
     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
-    return None
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    if False: yield
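Note: "if False: yield" is the standard idiom for an async generator that yields nothing — the mere presence of yield in the body makes the function an async generator, so every ShardDownloader can be consumed uniformly with async for. A minimal sketch:

  import asyncio
  from typing import AsyncIterator

  async def empty() -> AsyncIterator[int]:
    if False: yield  # never runs, but makes this an async generator

  async def main():
    async for _ in empty():
      print("never reached")
    print("done")  # the loop body is simply skipped

  asyncio.run(main())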

View File

@@ -1,14 +1,13 @@
-from exo.download.new_shard_download import download_shard, NewShardDownloader
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
 from pathlib import Path
 import asyncio
 
 async def test_new_shard_download():
   shard_downloader = NewShardDownloader()
   shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
   await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
-  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
-  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+  async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
+    print("Shard download status:", path, shard_status)
 
 if __name__ == "__main__":
   asyncio.run(test_new_shard_download())