mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
resumable downloads with integrity checks
@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
 import time
 from datetime import timedelta
 import asyncio
@@ -19,7 +19,8 @@ import json
 import traceback
 import shutil
 import tempfile
-from tenacity import retry, stop_after_attempt, wait_exponential
+import hashlib
+from tenacity import retry, stop_after_attempt, wait_fixed
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
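Worth flagging for reviewers: the retry policy changes from 5 attempts with exponential backoff to 30 attempts at a fixed 1-second interval, trading a longer worst case for more tolerance of brief network drops during long downloads. A minimal sketch of what the two tenacity policies mean (flaky_fetch_old and flaky_fetch_new are hypothetical stand-ins, not part of this diff):

import random
from tenacity import retry, stop_after_attempt, wait_exponential, wait_fixed

# Old policy: give up after 5 attempts, with waits that roughly double each time.
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
def flaky_fetch_old():
  if random.random() < 0.7: raise ConnectionError("transient failure")
  return "ok"

# New policy: give up after 30 attempts, always waiting 1 second between tries.
@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
def flaky_fetch_new():
  if random.random() < 0.7: raise ConnectionError("transient failure")
  return "ok"

print(flaky_fetch_new())  # retries internally until success or 30 failures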
@@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""):
       else:
         raise Exception(f"Failed to fetch file list: {response.status}")
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
+  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
+  if type == "sha1":
+    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
+    hash.update(header)
+  async with aiofiles.open(path, 'rb') as f:
+    while chunk := await f.read(1024 * 1024):
+      hash.update(chunk)
+  return hash.hexdigest()
+
+async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
+  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+  headers = await get_auth_headers()
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+    async with session.head(url, headers=headers) as r:
+      content_length = int(r.headers.get('content-length', 0))
+      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
+      assert content_length > 0, f"No content length for {url}"
+      assert etag is not None, f"No remote hash for {url}"
+      if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
+      return content_length, etag
+
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
-  temp_file_name = None
-  try:
-    if (target_dir/path).exists(): return target_dir/path
-    await aios.makedirs((target_dir/path).parent, exist_ok=True)
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, path)
+  if await aios.path.exists(target_dir/path): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  length, remote_hash = await file_meta(repo_id, revision, path)
+  partial_path = target_dir/f"{path}.partial"
+  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
+  if resume_byte_pos != length:
+    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
     headers = await get_auth_headers()
+    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
+    n_read = 0
     async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
       async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-        assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-        length = int(r.headers.get('content-length', 0))
-        n_read = 0
-        async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-          temp_file_name = temp_file.name
-          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-          await aios.rename(temp_file.name, target_dir/path)
-          return target_dir/path
-  finally:
-    if temp_file_name: # attempt to delete tmp file if it still exists
-      try: await aios.unlink(temp_file_name)
-      except: pass
+        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
+        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
+          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+
+  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
+  integrity = final_hash == remote_hash
+  if not integrity:
+    try: await aios.remove(partial_path)
+    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
+    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
+  await aios.rename(partial_path, target_dir/path)
+  return target_dir/path
 
 
 def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
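The resume path itself is standard HTTP range semantics: when a .partial file already holds some bytes, the request carries Range: bytes=<offset>- and a server that honours it replies 206 Partial Content with only the remainder, which is then appended in 'ab' mode. A synchronous sketch of the same flow, using requests rather than aiohttp and a hypothetical URL:

import os
import requests

def resume_download(url: str, dest: str) -> None:
  partial = dest + ".partial"
  pos = os.path.getsize(partial) if os.path.exists(partial) else 0
  headers = {'Range': f'bytes={pos}-'} if pos else {}
  with requests.get(url, headers=headers, stream=True, timeout=60) as r:
    assert r.status_code in (200, 206), f"unexpected status {r.status_code}"
    mode = 'ab' if r.status_code == 206 else 'wb'  # 200 means the server ignored the Range; start over
    with open(partial, mode) as f:
      for chunk in r.iter_content(1024 * 1024):
        f.write(chunk)
  os.rename(partial, dest)  # publish only once the bytes are all on disk

resume_download("https://example.com/files/model.safetensors", "model.safetensors")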
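Finally, the new flow publishes atomically: the .partial file is renamed into place only after the hash check passes, so an interrupted or corrupted download can never masquerade as a completed file. A hypothetical call site, assuming this module's download_file is importable and the repo id and file path are reachable through get_hf_endpoint():

import asyncio
from pathlib import Path

async def main():
  def on_progress(done: int, total: int):
    print(f"\r{done}/{total} bytes ({done/total:.1%})", end="")
  # Example arguments only; any Hub repo file works.
  path = await download_file("mlx-community/Llama-3.2-1B-Instruct-4bit", "main",
                             "config.json", Path("/tmp/exo-test"), on_progress)
  print(f"\nsaved to {path}")

asyncio.run(main())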