diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 0e8957ae..c692946b 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -11,7 +11,7 @@ import aiofiles.os as aios import aiohttp import aiofiles from urllib.parse import urljoin -from typing import Callable, Union, Tuple, Dict, List +from typing import Callable, Union, Tuple, Dict, List, Optional, Literal import time from datetime import timedelta import asyncio @@ -19,7 +19,8 @@ import json import traceback import shutil import tempfile -from tenacity import retry, stop_after_attempt, wait_exponential +import hashlib +from tenacity import retry, stop_after_attempt, wait_fixed def exo_home() -> Path: return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) @@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]): print(f"Error seeding model {path} to {dest_path}") traceback.print_exc() -@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5)) +@retry(stop=stop_after_attempt(30), wait=wait_fixed(1)) async def fetch_file_list(repo_id, revision, path=""): api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" url = f"{api_url}/{path}" if path else api_url @@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""): else: raise Exception(f"Failed to fetch file list: {response.status}") -@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5)) +async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str: + hash = hashlib.sha1() if type == "sha1" else hashlib.sha256() + if type == "sha1": + header = f"blob {(await aios.stat(path)).st_size}\0".encode() + hash.update(header) + async with aiofiles.open(path, 'rb') as f: + while chunk := await f.read(1024 * 1024): + hash.update(chunk) + return hash.hexdigest() + +async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]: + url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) + headers = await get_auth_headers() + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session: + async with session.head(url, headers=headers) as r: + content_length = int(r.headers.get('content-length', 0)) + etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag') + assert content_length > 0, f"No content length for {url}" + assert etag is not None, f"No remote hash for {url}" + if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1] + return content_length, etag + +@retry(stop=stop_after_attempt(30), wait=wait_fixed(1)) async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: - temp_file_name = None - try: - if (target_dir/path).exists(): return target_dir/path - await aios.makedirs((target_dir/path).parent, exist_ok=True) - base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" - url = urljoin(base_url, path) + if await aios.path.exists(target_dir/path): return target_dir/path + await aios.makedirs((target_dir/path).parent, exist_ok=True) + length, remote_hash = await file_meta(repo_id, revision, path) + partial_path = target_dir/f"{path}.partial" + resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None + if resume_byte_pos != length: + url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) headers = await get_auth_headers() + if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-' + n_read = 0 async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session: async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r: - assert r.status == 200, f"Failed to download {path} from {url}: {r.status}" - length = int(r.headers.get('content-length', 0)) - n_read = 0 - async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: - temp_file_name = temp_file.name - while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length) - await aios.rename(temp_file.name, target_dir/path) - return target_dir/path - finally: - if temp_file_name: # attempt to delete tmp file if it still exists - try: await aios.unlink(temp_file_name) - except: pass + assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}" + async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f: + while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length) + + final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1") + integrity = final_hash == remote_hash + if not integrity: + try: await aios.remove(partial_path) + except Exception as e: print(f"Error removing partial file {partial_path}: {e}") + raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}") + await aios.rename(partial_path, target_dir/path) + return target_dir/path def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: