mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
retry fetch_file_list also
This commit is contained in:
@@ -69,6 +69,7 @@ async def seed_models(seed_dir: Union[str, Path]):
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
||||
async def fetch_file_list(repo_id, revision, path=""):
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
@@ -151,8 +152,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
||||
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
||||
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
|
||||
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
||||
speed = downloaded_this_session / (time.time() - start_time)
|
||||
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
|
||||
speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
|
||||
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
|
||||
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
||||
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
||||
|
||||
Reference in New Issue
Block a user