mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #664 from exo-explore/resumedownload
resumable downloads with integrity checks
This commit is contained in:
@@ -276,9 +276,8 @@ class ChatGPTAPI:
|
||||
try:
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
|
||||
await response.prepare(request)
|
||||
downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
|
||||
for (path, d) in downloads:
|
||||
model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
|
||||
async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
|
||||
model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
@@ -11,7 +11,7 @@ import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Union, Tuple, Dict, List
|
||||
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import asyncio
|
||||
@@ -19,7 +19,8 @@ import json
|
||||
import traceback
|
||||
import shutil
|
||||
import tempfile
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
import hashlib
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
|
||||
|
||||
def exo_home() -> Path:
|
||||
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
|
||||
@@ -69,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
||||
async def fetch_file_list(repo_id, revision, path=""):
|
||||
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
|
||||
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
|
||||
file_list = await fetch_file_list(repo_id, revision)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
|
||||
return file_list
|
||||
|
||||
@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
|
||||
async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
@@ -90,29 +100,55 @@ async def fetch_file_list(repo_id, revision, path=""):
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
||||
async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
|
||||
hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
|
||||
if type == "sha1":
|
||||
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
|
||||
hash.update(header)
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
while chunk := await f.read(1024 * 1024):
|
||||
hash.update(chunk)
|
||||
return hash.hexdigest()
|
||||
|
||||
async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.head(url, headers=headers) as r:
|
||||
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
|
||||
@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
|
||||
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
temp_file_name = None
|
||||
try:
|
||||
if (target_dir/path).exists(): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, path)
|
||||
if await aios.path.exists(target_dir/path): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
length, remote_hash = await file_meta(repo_id, revision, path)
|
||||
partial_path = target_dir/f"{path}.partial"
|
||||
resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
|
||||
n_read = resume_byte_pos or 0
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
||||
assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
|
||||
length = int(r.headers.get('content-length', 0))
|
||||
n_read = 0
|
||||
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
|
||||
temp_file_name = temp_file.name
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
|
||||
await aios.rename(temp_file.name, target_dir/path)
|
||||
return target_dir/path
|
||||
finally:
|
||||
if temp_file_name: # attempt to delete tmp file if it still exists
|
||||
try: await aios.unlink(temp_file_name)
|
||||
except: pass
|
||||
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
|
||||
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
|
||||
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
|
||||
|
||||
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
|
||||
integrity = final_hash == remote_hash
|
||||
if not integrity:
|
||||
try: await aios.remove(partial_path)
|
||||
except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
|
||||
raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
|
||||
await aios.rename(partial_path, target_dir/path)
|
||||
return target_dir/path
|
||||
|
||||
|
||||
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
|
||||
@@ -126,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
|
||||
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
|
||||
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
|
||||
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
|
||||
index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
|
||||
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
@@ -140,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
|
||||
if DEBUG >= 1: traceback.print_exc()
|
||||
return ["*"]
|
||||
|
||||
async def get_downloaded_size(path: Path) -> int:
|
||||
partial_path = path.with_suffix(path.suffix + ".partial")
|
||||
if await aios.path.exists(path): return (await aios.stat(path)).st_size
|
||||
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
|
||||
return 0
|
||||
|
||||
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
||||
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
||||
@@ -154,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
||||
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list(repo_id, revision)
|
||||
file_list = await fetch_file_list_with_cache(repo_id, revision)
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {}
|
||||
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
||||
@@ -166,7 +208,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
||||
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
||||
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
|
||||
downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
@@ -199,8 +241,9 @@ class SingletonShardDownloader(ShardDownloader):
|
||||
finally:
|
||||
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
|
||||
yield path, status
|
||||
|
||||
class CachedShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
@@ -220,8 +263,9 @@ class CachedShardDownloader(ShardDownloader):
|
||||
self.cache[(inference_engine_name, shard)] = target_dir
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
|
||||
yield path, status
|
||||
|
||||
class NewShardDownloader(ShardDownloader):
|
||||
def __init__(self):
|
||||
@@ -235,9 +279,12 @@ class NewShardDownloader(ShardDownloader):
|
||||
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
|
||||
downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
|
||||
if DEBUG >= 6: print("Downloaded shards:", downloads)
|
||||
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
|
||||
return [d for d in downloads if not isinstance(d, Exception)]
|
||||
tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
path, progress = await task
|
||||
yield (path, progress)
|
||||
except Exception as e:
|
||||
print("Error downloading shard:", e)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, Dict
|
||||
from typing import Optional, Tuple, Dict, AsyncIterator
|
||||
from pathlib import Path
|
||||
from exo.inference.shard import Shard
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
"""Get the download status of shards.
|
||||
|
||||
Returns:
|
||||
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return AsyncCallbackSystem()
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
|
||||
return None
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
if False: yield
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from exo.download.new_shard_download import download_shard, NewShardDownloader
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
async def test_new_shard_download():
|
||||
shard_downloader = NewShardDownloader()
|
||||
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
|
||||
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
|
||||
download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
|
||||
print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
|
||||
async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
|
||||
print("Shard download status:", path, shard_status)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_new_shard_download())
|
||||
|
||||
Reference in New Issue
Block a user