diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index b741fd8c..35ae31dc 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -69,6 +69,7 @@ async def seed_models(seed_dir: Union[str, Path]): print(f"Error seeding model {path} to {dest_path}") traceback.print_exc() +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5)) async def fetch_file_list(repo_id, revision, path=""): api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" url = f"{api_url}/{path}" if path else api_url @@ -151,8 +152,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time() downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes - speed = downloaded_this_session / (time.time() - start_time) - eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) + speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0 + eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0) file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time) on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)) if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")