resumable downloads with integrity checks

Author: Alex Cheema
Date:   2025-02-01 13:22:51 +00:00
parent 0bebf8dfde
commit 7034ee0fcb


@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
 import time
 from datetime import timedelta
 import asyncio
@@ -19,7 +19,8 @@ import json
 import traceback
 import shutil
 import tempfile
-from tenacity import retry, stop_after_attempt, wait_exponential
+import hashlib
+from tenacity import retry, stop_after_attempt, wait_fixed
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""):
       else:
         raise Exception(f"Failed to fetch file list: {response.status}")
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
+  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
+  if type == "sha1":
+    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
+    hash.update(header)
+  async with aiofiles.open(path, 'rb') as f:
+    while chunk := await f.read(1024 * 1024):
+      hash.update(chunk)
+  return hash.hexdigest()
+
+async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
+  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+  headers = await get_auth_headers()
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+    async with session.head(url, headers=headers) as r:
+      content_length = int(r.headers.get('content-length', 0))
+      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
+      assert content_length > 0, f"No content length for {url}"
+      assert etag is not None, f"No remote hash for {url}"
+      if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
+      return content_length, etag
+
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
-  temp_file_name = None
-  try:
-    if (target_dir/path).exists(): return target_dir/path
-    await aios.makedirs((target_dir/path).parent, exist_ok=True)
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, path)
-    headers = await get_auth_headers()
-    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
-      async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-        assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-        length = int(r.headers.get('content-length', 0))
-        n_read = 0
-        async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-          temp_file_name = temp_file.name
-          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-        await aios.rename(temp_file.name, target_dir/path)
-        return target_dir/path
-  finally:
-    if temp_file_name: # attempt to delete tmp file if it still exists
-      try: await aios.unlink(temp_file_name)
-      except: pass
+  if await aios.path.exists(target_dir/path): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  length, remote_hash = await file_meta(repo_id, revision, path)
+  partial_path = target_dir/f"{path}.partial"
+  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
+  if resume_byte_pos != length:
+    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+    headers = await get_auth_headers()
+    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
+    n_read = 0
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+      async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
+        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
+          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+
+  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
+  integrity = final_hash == remote_hash
+  if not integrity:
+    try: await aios.remove(partial_path)
+    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
+    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
+  await aios.rename(partial_path, target_dir/path)
+  return target_dir/path
 
 def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
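
A note on calc_hash: the sha1 branch reproduces git's blob object-id, i.e. sha1 over a "blob <size>\0" header followed by the file contents, which is what the Hugging Face Hub reports as the ETag for regular (non-LFS) files; LFS files are instead identified by a plain sha256 of the contents. One way to sanity-check the sha1 branch against git itself, as a sketch (assumes git is on PATH and calc_hash is importable):

    import asyncio, subprocess, tempfile
    from pathlib import Path

    async def main():
      with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"hello world\n")
      path = Path(tmp.name)
      ours = await calc_hash(path, type="sha1")
      # git computes the same blob object-id for the file
      theirs = subprocess.check_output(["git", "hash-object", str(path)]).decode().strip()
      assert ours == theirs, f"{ours} != {theirs}"

    asyncio.run(main())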
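file_meta gets both verification inputs from a single HEAD request: content-length (used as the progress total and the resume target) and the de-quoted ETag as the expected hash, preferring X-Linked-ETag, which the Hub sets for LFS files. The hash length then selects the algorithm; a sketch, with a hypothetical repo and file:

    import asyncio

    async def main():
      size, remote_hash = await file_meta("some-org/some-model", "main", "config.json")  # hypothetical repo/file
      algo = "sha256" if len(remote_hash) == 64 else "sha1"  # 64 hex chars -> LFS sha256, 40 -> git blob sha1
      print(f"{size} bytes, expect {algo} {remote_hash}")

    asyncio.run(main())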
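Put together, download_file now streams into <path>.partial, resumes via an HTTP Range request when a partial file survives an interrupted run (hence accepting status 206 alongside 200), and verifies the finished file against the remote hash before renaming it into place; on a mismatch the partial is deleted and the exception lets @retry start a clean attempt. A minimal usage sketch (repo, file, and target directory are hypothetical):

    import asyncio
    from pathlib import Path

    async def main():
      def on_progress(n_read: int, total: int):  # called after every 1 MiB chunk
        print(f"\r{n_read}/{total} bytes", end="")
      # killing and re-running this picks up model.safetensors.partial and
      # requests only the remaining byte range instead of starting over
      local = await download_file("some-org/some-model", "main", "model.safetensors",
                                  Path("/tmp/exo-downloads"), on_progress)
      print(f"\nverified: {local}")

    asyncio.run(main())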