Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)
Merge pull request #675 from exo-explore/rmtenacity
remove tenacity dependency, implement simple retry logic instead
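The diff below swaps tenacity's declarative decorators for hand-rolled loops. Note that the wait policy changes too: tenacity was configured with a fixed 1-second wait (wait_fixed(1)), while the new loops sleep with exponential backoff of 0.1 * 2**attempt seconds, capped at 8 seconds. A minimal generic sketch of the pattern (my illustration, not code from this commit):

import asyncio

async def with_retry(coro_factory, n_attempts=30, no_retry=()):
  # Call coro_factory up to n_attempts times, re-raising on the last failure.
  # Exceptions listed in no_retry fail fast (cf. retry_if_not_exception_type).
  for attempt in range(n_attempts):
    try:
      return await coro_factory()
    except Exception as e:
      if isinstance(e, no_retry) or attempt == n_attempts - 1:
        raise
      # Exponential backoff: 0.1s, 0.2s, 0.4s, ..., capped at 8s.
      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))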
@@ -20,7 +20,6 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -74,13 +73,20 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li
   cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
   if await aios.path.exists(cache_file):
     async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_retry(repo_id, revision)
   await aios.makedirs(cache_file.parent, exist_ok=True)
   async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
   return file_list
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _fetch_file_list(repo_id, revision, path)
+    except Exception as e:
+      if attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
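With 30 attempts there are at most 29 sleeps, so by my arithmetic fetch_file_list_with_retry spends at most about 189 seconds sleeping before the final failure:

delays = [min(8, 0.1 * (2 ** a)) for a in range(29)]  # sleeps between 30 attempts
# 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, then 8.0 for attempts 7 through 28
print(sum(delays))  # ≈ 188.7 seconds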
@@ -94,7 +100,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "")
       if item["type"] == "file":
         files.append({"path": item["path"], "size": item["size"]})
       elif item["type"] == "directory":
-        subfiles = await fetch_file_list(repo_id, revision, item["path"])
+        subfiles = await _fetch_file_list(repo_id, revision, item["path"])
         files.extend(subfiles)
     return files
   else:
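One behavioural detail of this hunk: directory recursion now calls the undecorated _fetch_file_list, so the retry loop wraps only the top-level fetch_file_list_with_retry call. Under tenacity, every recursive fetch_file_list invocation carried its own 30-attempt retry; now a persistent failure deep in the tree propagates up and restarts the whole traversal.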
@@ -122,8 +128,15 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
   return content_length, etag
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
-async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
+    except Exception as e:
+      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
   length, remote_hash = await file_meta(repo_id, revision, path)
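The isinstance(e, FileNotFoundError) check preserves tenacity's retry_if_not_exception_type(FileNotFoundError) behaviour: a genuinely missing file fails fast rather than burning 30 attempts. A hypothetical usage sketch (repo id and paths invented for illustration):

import asyncio
from pathlib import Path

async def main():
  # Transient errors are retried with capped exponential backoff;
  # FileNotFoundError propagates immediately.
  local = await download_file_with_retry("some-org/some-model", "main", "config.json", Path("/tmp/exo-demo"))
  print(local)

asyncio.run(main())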
@@ -163,7 +176,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
   target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
-  index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+  index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
 
@@ -214,7 +227,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   semaphore = asyncio.Semaphore(max_parallel_downloads)
   async def download_with_semaphore(file):
     async with semaphore:
-      await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+      await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
   if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
   final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
   on_progress.trigger_all(shard, final_repo_progress)
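download_shard already bounds concurrency with asyncio.Semaphore(max_parallel_downloads); with retries moved inside download_file_with_retry, each file retries independently while holding its semaphore slot. A minimal sketch of that bounded-gather pattern (illustrative names, not from the commit):

import asyncio

async def bounded_gather(factories, limit):
  # Run coroutine factories with at most `limit` in flight at once.
  sem = asyncio.Semaphore(limit)
  async def run(factory):
    async with sem:
      return await factory()
  return await asyncio.gather(*[run(f) for f in factories])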