resumable downloads with integrity checks

Alex Cheema
2025-02-01 13:22:51 +00:00
parent 0bebf8dfde
commit 7034ee0fcb


@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
 import time
 from datetime import timedelta
 import asyncio
@@ -19,7 +19,8 @@ import json
 import traceback
 import shutil
 import tempfile
-from tenacity import retry, stop_after_attempt, wait_exponential
+import hashlib
+from tenacity import retry, stop_after_attempt, wait_fixed
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
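The retry policy changes here from exponential backoff over 5 attempts to a fixed 1-second wait over 30 attempts: shorter individual waits, but far more persistence against flaky connections. A minimal sketch of the two tenacity policies side by side (function names are illustrative only):

    from tenacity import retry, stop_after_attempt, wait_exponential, wait_fixed

    # before: give up after 5 attempts, backing off roughly 0.5s, 1s, 2s, 4s between tries
    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
    async def fetch_old(): ...

    # after: give up after 30 attempts, waiting a flat 1s between tries
    @retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
    async def fetch_new(): ...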
@@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""):
       else:
         raise Exception(f"Failed to fetch file list: {response.status}")
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
-async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
-  temp_file_name = None
-  try:
-    if (target_dir/path).exists(): return target_dir/path
-    await aios.makedirs((target_dir/path).parent, exist_ok=True)
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, path)
+async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
+  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
+  if type == "sha1":
+    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
+    hash.update(header)
+  async with aiofiles.open(path, 'rb') as f:
+    while chunk := await f.read(1024 * 1024):
+      hash.update(chunk)
+  return hash.hexdigest()
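In the sha1 case, calc_hash prepends a "blob <size>\0" header before hashing because, for non-LFS files, the Hub's ETag is a git blob SHA-1, which is computed over that header plus the contents rather than the raw bytes alone. A quick way to sanity-check this against git itself (a hedged sketch; the test file name is illustrative):

    import asyncio, subprocess
    from pathlib import Path

    async def check_blob_hash(path: Path):
        # calc_hash(path, type="sha1") should match `git hash-object <path>`:
        # both hash b"blob <size>\0" followed by the file contents.
        ours = await calc_hash(path, type="sha1")
        theirs = subprocess.run(["git", "hash-object", str(path)],
                                capture_output=True, text=True).stdout.strip()
        assert ours == theirs, f"{ours} != {theirs}"

    # asyncio.run(check_blob_hash(Path("config.json")))  # illustrative file name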
+async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
+  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
   headers = await get_auth_headers()
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
-    async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-      assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-      length = int(r.headers.get('content-length', 0))
+    async with session.head(url, headers=headers) as r:
+      content_length = int(r.headers.get('content-length', 0))
+      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
+      assert content_length > 0, f"No content length for {url}"
+      assert etag is not None, f"No remote hash for {url}"
+      if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
+      return content_length, etag
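file_meta fetches everything download_file needs before any bytes move: the expected size, used to decide whether a .partial file is already complete, and the remote hash taken from the ETag headers. A hedged usage sketch (the repo id and file name are illustrative, not part of this commit):

    import asyncio

    async def main():
        length, etag = await file_meta("mlx-community/Llama-3.2-1B-Instruct-4bit", "main", "model.safetensors")
        # a 64-hex-char etag (X-Linked-ETag on LFS files) is a sha256 of the payload;
        # a 40-hex-char etag is a git blob sha1 -- the same heuristic download_file uses
        hash_type = "sha256" if len(etag) == 64 else "sha1"
        print(f"{length} bytes, verify with {hash_type}")

    asyncio.run(main())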
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
+async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  if await aios.path.exists(target_dir/path): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  length, remote_hash = await file_meta(repo_id, revision, path)
+  partial_path = target_dir/f"{path}.partial"
+  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
+  if resume_byte_pos != length:
+    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+    headers = await get_auth_headers()
+    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
     n_read = 0
-    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-      temp_file_name = temp_file.name
-      while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-      await aios.rename(temp_file.name, target_dir/path)
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+      async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
+        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
+          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
+  integrity = final_hash == remote_hash
+  if not integrity:
+    try: await aios.remove(partial_path)
+    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
+    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
+  await aios.rename(partial_path, target_dir/path)
   return target_dir/path
-  finally:
-    if temp_file_name:  # attempt to delete tmp file if it still exists
-      try: await aios.unlink(temp_file_name)
-      except: pass
 
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
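
Putting the pieces together, a minimal driver for the new flow might look like the following (a hedged sketch; the repo id, file name, and target directory are illustrative, not part of this commit). Killing the process mid-transfer and re-running resumes from the .partial file via a Range request, and the rename to the final path only happens once the hash check passes:

    import asyncio
    from pathlib import Path

    def on_progress(n_read: int, total: int):
        print(f"\rdownloaded {n_read}/{total} bytes", end="")

    async def main():
        # illustrative arguments, not from this commit
        path = await download_file(
            "mlx-community/Llama-3.2-1B-Instruct-4bit", "main",
            "model.safetensors", Path("/tmp/models"), on_progress,
        )
        print(f"\nverified and saved to {path}")

    asyncio.run(main())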