remove tenacity dependency, implement simple retry logic instead
@@ -20,7 +20,6 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type

 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -74,13 +73,20 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li
   cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
   if await aios.path.exists(cache_file):
     async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_retry(repo_id, revision)
   await aios.makedirs(cache_file.parent, exist_ok=True)
   async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
   return file_list

-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _fetch_file_list(repo_id, revision, path)
+    except Exception as e:
+      if attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
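
The new loop also trades tenacity's wait_fixed(1) for exponential backoff: delays run 0.1 s, 0.2 s, 0.4 s, ... and cap at 8 s, so the 29 possible sleeps total about 12.7 s + 22 × 8 s ≈ 189 s in the worst case before the last attempt's exception propagates (versus a flat 29 s under wait_fixed(1)). Since the same loop reappears in download_file_with_retry below, a shared helper would be a natural follow-up refactor. A minimal sketch, using only the standard library; retry_async and should_retry are illustrative names, not part of this commit:

import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

async def retry_async(
  fn: Callable[[], Awaitable[T]],
  n_attempts: int = 30,
  should_retry: Callable[[Exception], bool] = lambda e: True,
) -> T:
  # Retry fn up to n_attempts times, re-raising on the last failure or
  # whenever should_retry rejects the exception.
  for attempt in range(n_attempts):
    try:
      return await fn()
    except Exception as e:
      if not should_retry(e) or attempt == n_attempts - 1: raise e
      # Exponential backoff: 0.1 s, 0.2 s, 0.4 s, ... capped at 8 s per wait.
      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
  raise AssertionError("unreachable: the loop always returns or raises")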
@@ -94,7 +100,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "")
           if item["type"] == "file":
             files.append({"path": item["path"], "size": item["size"]})
           elif item["type"] == "directory":
-            subfiles = await fetch_file_list(repo_id, revision, item["path"])
+            subfiles = await _fetch_file_list(repo_id, revision, item["path"])
             files.extend(subfiles)
         return files
       else:
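
Note that the recursive descent into directories now goes through _fetch_file_list directly, so only the top-level call is wrapped in retries: a transient failure anywhere in the tree propagates up and the whole listing is retried from the root. The traversal flattens the repository tree into a list of path/size records, shaped like this (hypothetical repo contents, names and sizes invented for illustration):

files = [
  {"path": "config.json", "size": 687},
  {"path": "model-00001-of-00002.safetensors", "size": 9976570520},
  {"path": "model-00002-of-00002.safetensors", "size": 3500000000},
]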
@@ -122,8 +128,15 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
   return content_length, etag

-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
-async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
+    except Exception as e:
+      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
   length, remote_hash = await file_meta(repo_id, revision, path)
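
The isinstance check reproduces tenacity's retry_if_not_exception_type(FileNotFoundError): a file that does not exist on the hub fails fast instead of burning 30 attempts, while every other exception is retried with backoff. A self-contained sketch of that behavior, with missing_file as a stand-in for _download_file (both helper names are invented here):

import asyncio

async def missing_file():
  # Stand-in for _download_file hitting a 404.
  raise FileNotFoundError("404: not on the hub")

async def with_retry(fn, n_attempts=30):
  for attempt in range(n_attempts):
    try: return await fn()
    except Exception as e:
      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))

async def main():
  try:
    await with_retry(missing_file)
  except FileNotFoundError as e:
    # Raised on the first attempt: no sleeping, no retries.
    print(f"failed fast: {e}")

asyncio.run(main())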
@@ -163,7 +176,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog

 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
   target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
-  index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+  index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
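
For context, model.safetensors.index.json is the standard Hugging Face index for sharded safetensors checkpoints: its weight_map object maps each tensor name to the shard file that contains it, and that mapping is what get_weight_map returns. An abbreviated example of the file's shape (tensor and file names illustrative):

{
  "metadata": {"total_size": 16060522496},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}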
@@ -214,7 +227,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   semaphore = asyncio.Semaphore(max_parallel_downloads)
   async def download_with_semaphore(file):
     async with semaphore:
-      await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+      await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
   if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
   final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
   on_progress.trigger_all(shard, final_repo_progress)
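
download_shard fans out one task per file but gates the actual transfer behind an asyncio.Semaphore, so at most max_parallel_downloads downloads run concurrently and each file's retries back off independently without stalling the others. The same pattern in isolation (bounded_gather is an illustrative name, not part of exo):

import asyncio

async def bounded_gather(coros, limit=4):
  # Drive all coroutines with gather, but let only `limit` run at a time.
  sem = asyncio.Semaphore(limit)
  async def gated(coro):
    async with sem:
      return await coro
  return await asyncio.gather(*(gated(c) for c in coros))

async def demo():
  async def job(i):
    await asyncio.sleep(0.1)
    return i
  print(await bounded_gather([job(i) for i in range(10)], limit=3))

asyncio.run(demo())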