make max_parallel_downloads configurable, increase download chunk size to 8MB

This commit is contained in:
Alex Cheema
2025-02-14 21:26:41 +00:00
parent b4e6f8acad
commit 477e3a5e4c
2 changed files with 10 additions and 9 deletions

View File

@@ -112,7 +112,7 @@ async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
hash.update(header)
async with aiofiles.open(path, 'rb') as f:
while chunk := await f.read(1024 * 1024):
while chunk := await f.read(8 * 1024 * 1024):
hash.update(chunk)
return hash.hexdigest()
@@ -154,7 +154,7 @@ async def _download_file(repo_id: str, revision: str, path: str, target_dir: Pat
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
integrity = final_hash == remote_hash
@@ -197,7 +197,7 @@ async def get_downloaded_size(path: Path) -> int:
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
return 0
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
repo_id = get_repo(shard.model_id, inference_engine_classname)
revision = "main"
@@ -238,8 +238,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
else:
return target_dir, final_repo_progress
def new_shard_downloader() -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))
class SingletonShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
@@ -283,7 +283,8 @@ class CachedShardDownloader(ShardDownloader):
yield path, status
class NewShardDownloader(ShardDownloader):
def __init__(self):
def __init__(self, max_parallel_downloads: int = 8):
self.max_parallel_downloads = max_parallel_downloads
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@property
@@ -291,7 +292,7 @@ class NewShardDownloader(ShardDownloader):
return self._on_progress
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
return target_dir
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:

View File

@@ -72,7 +72,7 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -99,7 +99,7 @@ print_yellow_exo()
system_info = get_system_info()
print(f"Detected system: {system_info}")
shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")