mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #702 from exo-explore/alwayslogdownloaderror
make max_parallel_downloads configurable, increase download chunk size to 8MB
This commit is contained in:
@@ -112,7 +112,7 @@ async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str
|
||||
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
|
||||
hash.update(header)
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
while chunk := await f.read(1024 * 1024):
|
||||
while chunk := await f.read(8 * 1024 * 1024):
|
||||
hash.update(chunk)
|
||||
return hash.hexdigest()
|
||||
|
||||
@@ -154,7 +154,7 @@ async def _download_file(repo_id: str, revision: str, path: str, target_dir: Pat
|
||||
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
|
||||
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
|
||||
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
|
||||
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
|
||||
while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
|
||||
|
||||
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
|
||||
integrity = final_hash == remote_hash
|
||||
@@ -197,7 +197,7 @@ async def get_downloaded_size(path: Path) -> int:
|
||||
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
|
||||
return 0
|
||||
|
||||
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
||||
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
||||
revision = "main"
|
||||
@@ -238,8 +238,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
||||
else:
|
||||
return target_dir, final_repo_progress
|
||||
|
||||
def new_shard_downloader() -> ShardDownloader:
|
||||
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
|
||||
def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))
|
||||
|
||||
class SingletonShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
@@ -283,7 +283,8 @@ class CachedShardDownloader(ShardDownloader):
|
||||
yield path, status
|
||||
|
||||
class NewShardDownloader(ShardDownloader):
|
||||
def __init__(self):
|
||||
def __init__(self, max_parallel_downloads: int = 8):
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
|
||||
@property
|
||||
@@ -291,7 +292,7 @@ class NewShardDownloader(ShardDownloader):
|
||||
return self._on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
|
||||
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
|
||||
@@ -72,7 +72,7 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
|
||||
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
|
||||
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
|
||||
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
|
||||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
|
||||
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
|
||||
@@ -99,7 +99,7 @@ print_yellow_exo()
|
||||
system_info = get_system_info()
|
||||
print(f"Detected system: {system_info}")
|
||||
|
||||
shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
||||
print(f"Inference engine name after selection: {inference_engine_name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user