Merge pull request #660 from exo-explore/robustdownload

cleanup tmp files on failed download
This commit is contained in:
Alex Cheema
2025-01-30 20:25:15 +00:00
committed by GitHub

View File

@@ -75,7 +75,7 @@ async def fetch_file_list(repo_id, revision, path=""):
url = f"{api_url}/{path}" if path else api_url url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers() headers = await get_auth_headers()
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=1800, sock_connect=60)) as session: async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)) as session:
async with session.get(url, headers=headers) as response: async with session.get(url, headers=headers) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()
@@ -84,7 +84,7 @@ async def fetch_file_list(repo_id, revision, path=""):
if item["type"] == "file": if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]}) files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory": elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) subfiles = await fetch_file_list(repo_id, revision, item["path"])
files.extend(subfiles) files.extend(subfiles)
return files return files
else: else:
@@ -92,20 +92,28 @@ async def fetch_file_list(repo_id, revision, path=""):
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5)) @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
if (target_dir/path).exists(): return target_dir/path temp_file_name = None
await aios.makedirs((target_dir/path).parent, exist_ok=True) try:
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" if (target_dir/path).exists(): return target_dir/path
url = urljoin(base_url, path) await aios.makedirs((target_dir/path).parent, exist_ok=True)
headers = await get_auth_headers() base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session: url = urljoin(base_url, path)
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r: headers = await get_auth_headers()
assert r.status == 200, f"Failed to download {path} from {url}: {r.status}" async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
length = int(r.headers.get('content-length', 0)) async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
n_read = 0 assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: length = int(r.headers.get('content-length', 0))
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length) n_read = 0
await aios.rename(temp_file.name, target_dir/path) async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
return target_dir/path temp_file_name = temp_file.name
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
await aios.rename(temp_file.name, target_dir/path)
return target_dir/path
finally:
if temp_file_name: # attempt to delete tmp file if it still exists
try: await aios.unlink(temp_file_name)
except: pass
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
all_total_bytes = sum([p.total for p in file_progress.values()]) all_total_bytes = sum([p.total for p in file_progress.values()])
@@ -161,17 +169,17 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0 downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time()) file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
semaphore = asyncio.Semaphore(max_parallel_downloads) semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file): async def download_with_semaphore(file):
async with semaphore: async with semaphore:
await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list]) if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time) final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress) on_progress.trigger_all(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None): if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
return target_dir/gguf["path"], final_repo_progress return target_dir/gguf["path"], final_repo_progress
else: else:
return target_dir, final_repo_progress return target_dir, final_repo_progress
def new_shard_downloader() -> ShardDownloader: def new_shard_downloader() -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader())) return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
@@ -233,4 +241,3 @@ class NewShardDownloader(ShardDownloader):
if DEBUG >= 6: print("Downloaded shards:", downloads) if DEBUG >= 6: print("Downloaded shards:", downloads)
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)]) if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
return [d for d in downloads if not isinstance(d, Exception)] return [d for d in downloads if not isinstance(d, Exception)]