mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
resumable downloads with integrity checks
This commit is contained in:
@@ -11,7 +11,7 @@ import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Union, Tuple, Dict, List
|
||||
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import asyncio
|
||||
@@ -19,7 +19,8 @@ import json
|
||||
import traceback
|
||||
import shutil
|
||||
import tempfile
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
import hashlib
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
def exo_home() -> Path:
|
||||
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
|
||||
@@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]):
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
||||
@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
|
||||
async def fetch_file_list(repo_id, revision, path=""):
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
@@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""):
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
||||
async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
|
||||
hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
|
||||
if type == "sha1":
|
||||
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
|
||||
hash.update(header)
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
while chunk := await f.read(1024 * 1024):
|
||||
hash.update(chunk)
|
||||
return hash.hexdigest()
|
||||
|
||||
async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.head(url, headers=headers) as r:
|
||||
content_length = int(r.headers.get('content-length', 0))
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
|
||||
@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
|
||||
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
temp_file_name = None
|
||||
try:
|
||||
if (target_dir/path).exists(): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, path)
|
||||
if await aios.path.exists(target_dir/path): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
length, remote_hash = await file_meta(repo_id, revision, path)
|
||||
partial_path = target_dir/f"{path}.partial"
|
||||
resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
|
||||
n_read = 0
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
||||
assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
|
||||
length = int(r.headers.get('content-length', 0))
|
||||
n_read = 0
|
||||
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
|
||||
temp_file_name = temp_file.name
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
|
||||
await aios.rename(temp_file.name, target_dir/path)
|
||||
return target_dir/path
|
||||
finally:
|
||||
if temp_file_name: # attempt to delete tmp file if it still exists
|
||||
try: await aios.unlink(temp_file_name)
|
||||
except: pass
|
||||
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
|
||||
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
|
||||
|
||||
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
|
||||
integrity = final_hash == remote_hash
|
||||
if not integrity:
|
||||
try: await aios.remove(partial_path)
|
||||
except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
|
||||
raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
|
||||
await aios.rename(partial_path, target_dir/path)
|
||||
return target_dir/path
|
||||
|
||||
|
||||
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
|
||||
|
||||
Reference in New Issue
Block a user