Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)
separate hf_helpers, make extra dir with download_hf script, unify downloading so tinygrad uses the same method as mlx and interoperable model formats
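The hunks below remove the per-engine download paths and route both the MLX and tinygrad engines through the shared hf_helpers module, so downloads, caching and progress reporting behave the same for both. As a rough sketch of how the pieces fit together after this change (the exact call sites are in the hunks below; the setup around them here is illustrative only, and assumes the exo package is importable):

import asyncio
from exo.inference.hf_helpers import HFRepoProgressEvent
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

async def on_progress(event: HFRepoProgressEvent):
  # One aggregated event per update, covering every file in the repo download.
  print(f"{event.completed_files}/{event.total_files} files, "
        f"{event.downloaded_bytes}/{event.total_bytes} bytes, ETA {event.overall_eta}")

engine = TinygradDynamicShardInferenceEngine(progress_callback=on_progress)
# Later, ensuring a shard triggers download_all_files(shard.model_id, progress_callback=...),
# so both engines report download progress through the same event type.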
@@ -25,11 +25,11 @@ shard_mappings = {
  },
  "llama-3-8b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
  },
  "llama-3-70b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
  },
  ### mistral
  "mistral-nemo": {
@@ -76,14 +76,6 @@ class ChatCompletionRequest:
    }


def resolve_tinygrad_tokenizer(model_id: str):
  if model_id == "llama3-8b-sfr":
    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
  elif model_id == "llama3-70b-sfr":
    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
  else:
    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")


async def resolve_tokenizer(model_id: str):
  try:
@@ -111,15 +103,6 @@ async def resolve_tokenizer(model_id: str):

    if DEBUG >= 2: print(traceback.format_exc())

  try:
    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
    return resolve_tinygrad_tokenizer(model_id)
  except Exception as e:
    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
    import traceback

    if DEBUG >= 2: print(traceback.format_exc())

  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer

@@ -1,36 +1,17 @@
import asyncio
import aiohttp
import os
import argparse
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, List, TypeVar, Union
from typing import Generator, Iterable, TypeVar, TypedDict
from dataclasses import dataclass
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG

T = TypeVar("T")

DEFAULT_ALLOW_PATTERNS = [
  "*.json",
  "*.py",
  "tokenizer.model",
  "*.tiktoken",
  "*.txt",
  "*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
  ".git",
  ".git/*",
  "*/.git",
  "**/.git/**",
  ".cache/huggingface",
  ".cache/huggingface/*",
  "*/.cache/huggingface",
  "**/.cache/huggingface/**",
]

def filter_repo_objects(
  items: Iterable[T],
  *,
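The hunk above is truncated mid-signature, so filter_repo_objects itself is not shown in full. A minimal standalone sketch of the same fnmatch-based allow/ignore idea (not the exact implementation, just the technique the allow/ignore pattern lists rely on):

from fnmatch import fnmatch
from typing import Iterable, List, Optional

def match_any(path: str, patterns: Optional[List[str]]) -> bool:
  return patterns is not None and any(fnmatch(path, p) for p in patterns)

def filter_paths(paths: Iterable[str], allow_patterns=None, ignore_patterns=None):
  for path in paths:
    if allow_patterns is not None and not match_any(path, allow_patterns):
      continue  # not on the allow-list
    if match_any(path, ignore_patterns):
      continue  # explicitly ignored (e.g. .git, .cache/huggingface)
    yield path

# Example: keep config/safetensors files, drop git internals.
print(list(filter_paths(
  ["config.json", "model.safetensors", ".git/HEAD", "README.md"],
  allow_patterns=["*.json", "*.safetensors"],
  ignore_patterns=[".git", ".git/*"],
)))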
@@ -117,7 +98,35 @@ async def fetch_file_list(session, repo_id, revision, path=""):
    else:
      raise Exception(f"Failed to fetch file list: {response.status}")

async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):

@dataclass
class HFRepoFileProgressEvent:
  file_path: str
  downloaded: int
  total: int
  speed: float
  eta: timedelta
  status: Literal["not_started", "in_progress", "complete"]

@dataclass
class HFRepoProgressEvent:
  completed_files: int
  total_files: int
  downloaded_bytes: int
  total_bytes: int
  overall_eta: timedelta
  file_progress: Dict[str, HFRepoFileProgressEvent]

HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]

@retry(
  stop=stop_after_attempt(5),
  wait=wait_exponential(multiplier=1, min=4, max=60),
  retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
  reraise=True
)
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
  url = urljoin(base_url, file_path)
  local_path = os.path.join(save_directory, file_path)
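Before the body of download_file continues in the next hunk, note the @retry decorator introduced above: it comes from tenacity and wraps the coroutine so transient network failures are retried with exponential backoff before the original exception is re-raised. A self-contained sketch of the same pattern on a toy function (not exo code):

import asyncio
import aiohttp
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

@retry(
  stop=stop_after_attempt(5),                          # give up after 5 tries
  wait=wait_exponential(multiplier=1, min=4, max=60),  # back off 4s, 8s, 16s, ... capped at 60s
  retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
  reraise=True,                                        # surface the last exception unchanged
)
async def fetch_text(url: str) -> str:
  async with aiohttp.ClientSession() as session:
    async with session.get(url, raise_for_status=True) as response:
      return await response.text()

# asyncio.run(fetch_text("https://huggingface.co"))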
@@ -125,64 +134,72 @@ async def download_file(session, repo_id, revision, file_path, save_directory, p
  os.makedirs(os.path.dirname(local_path), exist_ok=True)

  # Check if file already exists and get its size
  if os.path.exists(local_path):
    local_file_size = os.path.getsize(local_path)
  else:
    local_file_size = 0
  local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0

  headers = get_auth_headers()
  if use_range_request:
    headers["Range"] = f"bytes={local_file_size}-"

  headers = {"Range": f"bytes={local_file_size}-"}
  headers.update(get_auth_headers())
  async with session.get(url, headers=headers) as response:
    total_size = int(response.headers.get('Content-Length', 0))
    downloaded_size = local_file_size
    mode = 'ab' if use_range_request else 'wb'
    if downloaded_size == total_size:
      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
      if progress_callback:
        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
      return

    if response.status == 200:
      # File doesn't support range requests, start from beginning
      # File doesn't support range requests or we're not using them, start from beginning
      mode = 'wb'
      total_size = int(response.headers.get('Content-Length', 0))
      downloaded_size = 0
    elif response.status == 206:
      # Partial content, resume download
      mode = 'ab'
      content_range = response.headers.get('Content-Range')
      total_size = int(content_range.split('/')[-1])
      downloaded_size = local_file_size
      content_range = response.headers.get('Content-Range', '')
      try:
        total_size = int(content_range.split('/')[-1])
      except ValueError:
        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
    elif response.status == 416:
      # Range not satisfiable, get the actual file size
      if response.headers.get('Content-Type', '').startswith('text/html'):
        content = await response.text()
        print(f"Response content (HTML):\n{content}")
      else:
        print(response)
      print("Return header: ", response.headers)
      print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
      total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
      if local_file_size == total_size:
        print(f"File already fully downloaded: {file_path}")
        return
      else:
        # Start the download from the beginning
        mode = 'wb'
        downloaded_size = 0
      content_range = response.headers.get('Content-Range', '')
      try:
        total_size = int(content_range.split('/')[-1])
        if downloaded_size == total_size:
          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
          if progress_callback:
            await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
          return
      except ValueError:
        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
    else:
      print(f"Failed to download {file_path}: {response.status}")
      return
      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")

    if downloaded_size == total_size:
      print(f"File already downloaded: {file_path}")
      if progress_callback:
        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
      return

    DOWNLOAD_CHUNK_SIZE = 32768
    start_time = datetime.now()
    new_downloaded_size = 0
    with open(local_path, mode) as f:
      async for chunk in response.content.iter_chunked(8192):
      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
        f.write(chunk)
        new_downloaded_size += len(chunk)
        if progress_callback:
        downloaded_size += len(chunk)
        if progress_callback and total_size:
          elapsed_time = (datetime.now() - start_time).total_seconds()
          speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
          eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
          await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
    print(f"Downloaded: {file_path}")
          speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
          remaining_size = total_size - downloaded_size
          eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
          status = "in_progress" if downloaded_size < total_size else "complete"
          await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
    if DEBUG >= 2: print(f"Downloaded: {file_path}")

async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
  repo_root = get_repo_root(repo_id)
  refs_dir = repo_root / "refs"
  snapshots_dir = repo_root / "snapshots"
@@ -197,7 +214,7 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
    headers = get_auth_headers()
    async with session.get(api_url, headers=headers) as response:
      if response.status != 200:
        raise Exception(f"Failed to fetch revision info: {response.status}")
        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
      revision_info = await response.json()
      commit_hash = revision_info['sha']

@@ -215,68 +232,32 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
  completed_files = 0
  total_bytes = sum(file["size"] for file in filtered_file_list)
  downloaded_bytes = 0
  new_downloaded_bytes = 0
  file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
  file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
  start_time = datetime.now()

  async def download_with_progress(file_info):
    nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress
    nonlocal completed_files, downloaded_bytes, file_progress

    async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
      nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
      new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
      downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
      file_progress[path].update({
        'status': 'in_progress',
        'downloaded': file_downloaded,
        'total': file_total,
        'speed': speed,
        'eta': file_eta
      })
    async def file_progress_callback(event: HFRepoFileProgressEvent):
      nonlocal downloaded_bytes, file_progress
      downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
      file_progress[event.file_path] = event
      if progress_callback:
        elapsed_time = (datetime.now() - start_time).total_seconds()
        overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
        overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
        overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
        await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
        await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))

    await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
    completed_files += 1
    file_progress[file_info["path"]]['status'] = 'complete'
    file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
    if progress_callback:
      elapsed_time = (datetime.now() - start_time).total_seconds()
      overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
      overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
      overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
      await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
      await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))

  tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
  await asyncio.gather(*tasks)

async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
  async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
    print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
    print(f"Estimated time remaining: {overall_eta}")
    print("File Progress:")
    for file_path, progress in file_progress.items():
      status_icon = {
        'not_started': '⚪',
        'in_progress': '🔵',
        'complete': '✅'
      }[progress['status']]
      eta_str = str(progress.get('eta', 'N/A'))
      print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
            f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
    print("\n")

  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
  parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
  parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
  parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")

  args = parser.parse_args()

  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
  return snapshot_dir
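The resumable download above hinges on the HTTP Range header: an existing partial file is continued when the server answers 206 Partial Content, and rewritten from scratch on a plain 200. A stripped-down sketch of that decision (illustrative only; it omits the retry, 416 handling and progress events of the real download_file):

import os
import aiohttp

async def resume_download(session: aiohttp.ClientSession, url: str, local_path: str) -> int:
  local_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
  headers = {"Range": f"bytes={local_size}-"} if local_size else {}
  async with session.get(url, headers=headers) as response:
    if response.status == 206:
      mode = "ab"  # server honoured the range: append to what we already have
      total = int(response.headers["Content-Range"].split("/")[-1])
    else:
      mode = "wb"  # no range support (or nothing local yet): start over
      total = int(response.headers.get("Content-Length", 0))
    with open(local_path, mode) as f:
      async for chunk in response.content.iter_chunked(32768):
        f.write(chunk)
  return total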
@@ -1,9 +1,9 @@
import numpy as np

from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, Coroutine, Any
from abc import ABC, abstractmethod
from .shard import Shard

from exo.inference.hf_helpers import HFRepoProgressEvent

class InferenceEngine(ABC):
  @abstractmethod
@@ -15,5 +15,5 @@ class InferenceEngine(ABC):
    pass

  @abstractmethod
  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
    pass
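The interface change replaces the old (current, total) byte counter with a typed async callback. A minimal hypothetical implementation of the new hook (the class name here is made up; only set_progress_callback and HFRepoProgressEvent come from the diff):

from typing import Any, Callable, Coroutine, Optional
from exo.inference.hf_helpers import HFRepoProgressEvent

class MyEngine:  # stand-in for an InferenceEngine subclass
  def __init__(self, progress_callback: Optional[Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]] = None):
    self.progress_callback = progress_callback

  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
    # Store the coroutine function; the download path awaits it with repo-level progress events.
    self.progress_callback = progress_callback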
@@ -5,12 +5,13 @@ from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard, get_image_from_str
from ..shard import Shard
from typing import Optional, Callable
from exo.inference.hf_helpers import HFRepoProgressCallback


class MLXDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
    self.shard = None
    self.on_download_progress = on_download_progress
    self.progress_callback = progress_callback

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    await self.ensure_shard(shard)
@@ -33,9 +34,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
    if self.shard == shard:
      return

    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
    self.shard = shard

  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
    self.on_download_progress = on_download_progress
  def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
    self.progress_callback = progress_callback
@@ -12,10 +12,7 @@ from typing import Optional, Tuple, Union, List, Callable
from PIL import Image
from io import BytesIO
import base64
import os
import concurrent.futures

from exo import DEBUG
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
@@ -28,6 +25,8 @@ from transformers import AutoProcessor
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from exo import DEBUG
from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
from ..shard import Shard
@@ -164,52 +163,6 @@ def load_model_shard(
  return model


async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
  return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)

async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
  while True:
    try:
      await asyncio.sleep(0.1)
      current_size = sum(os.path.getsize(os.path.join(root, file))
                         for root, _, files in os.walk(dir)
                         for file in files)
      progress = min(current_size / total_size * 100, 100)
      if print_progress:
        print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
      if on_progress:
        on_progress(current_size, total_size)
      if progress >= 100:
        if print_progress:
          print("\nDownload complete!")
        break
    except Exception as e:
      print(f"Error monitoring progress: {e}")

async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
  with concurrent.futures.ThreadPoolExecutor() as pool:
    return await asyncio.get_event_loop().run_in_executor(
      pool,
      partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
    )

async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'

  total_size = await get_repo_size(repo_id)

  # Create tasks for download and progress checking
  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))

  # Wait for both tasks to complete
  result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
  return result[0]  # Return the result from download_task

repo_id_safetensors_layers = {
  "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
    "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
@@ -313,7 +266,7 @@ def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):

  return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]

async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
  """
  Ensures the model is available locally. If the path does not exist locally,
  it is downloaded from the Hugging Face Hub.
@@ -329,7 +282,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
  if not model_path.exists():
    try:
      model_path = Path(
        await download_async_with_progress(
        await download_all_files(
          repo_id=path_or_hf_repo,
          revision=revision,
          allow_patterns=[
@@ -339,7 +292,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
            "*.tiktoken",
            "*.txt",
          ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
          on_progress=on_download_progress,
          progress_callback=progress_callback,
        )
      )
    except RepositoryNotFoundError:
@@ -360,7 +313,7 @@ async def load_shard(
  model_config={},
  adapter_path: Optional[str] = None,
  lazy: bool = False,
  on_download_progress: Callable[[int, int], None] = None,
  progress_callback: Optional[HFRepoProgressCallback] = None,
) -> Tuple[nn.Module, TokenizerWrapper]:
  """
  Load the model and tokenizer from a given path or a huggingface repository.
@@ -383,7 +336,7 @@ async def load_shard(
    FileNotFoundError: If config file or safetensors are not found.
    ValueError: If model class or args class are not found.
  """
  model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
  model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)

  model = load_model_shard(model_path, shard, lazy, model_config)
  if adapter_path is not None:

@@ -1,3 +1,4 @@
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
@@ -40,17 +41,17 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
  assert np.array_equal(next_resp_full, resp4)


asyncio.run(
  test_inference_engine(
    MLXDynamicShardInferenceEngine(),
    MLXDynamicShardInferenceEngine(),
    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
  )
)
# asyncio.run(
# test_inference_engine(
# MLXDynamicShardInferenceEngine(),
# MLXDynamicShardInferenceEngine(),
# "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
# )
# )

# TODO: Need more memory or a smaller model
# asyncio.run(test_inference_engine(
# TinygradDynamicShardInferenceEngine(),
# TinygradDynamicShardInferenceEngine(),
# "llama3-8b-sfr",
# ))
asyncio.run(test_inference_engine(
  TinygradDynamicShardInferenceEngine(),
  TinygradDynamicShardInferenceEngine(),
  "llama3-8b-sfr",
))

@@ -1,9 +1,7 @@
import asyncio
from functools import partial
from pathlib import Path
from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, Coroutine, Any
import json
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -12,7 +10,7 @@ from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
import numpy as np
import os
from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root

MODEL_PARAMS = {
  "8B": {
@@ -46,15 +44,6 @@ MODEL_PARAMS = {


# **** helper functions ****
async def fetch_async(
  url: str,
  name: Optional[Union[Path, str]] = None,
  subdir: Optional[str] = None,
  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
) -> Path:
  func = partial(fetch, url, name, subdir, allow_caching)
  return await asyncio.get_event_loop().run_in_executor(None, func)


def concat_weights(models, device=None):
  def convert(name) -> Tensor:
@@ -159,8 +148,9 @@ def prefill(model, toks, start_pos=0):


class TinygradDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self):
  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
    self.shard = None
    self.progress_callback = progress_callback

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -199,62 +189,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
    if self.shard == shard:
      return

    model_path = Path(shard.model_id)
    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
    model_path = models_dir / shard.model_id
    size = "8B"
    if Path(model_path / "tokenizer_config.json").exists():
      model = model_path
    else:

      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
        num_files = 4
        for i in range(num_files):
          await fetch_async(
            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
            subdir=shard.model_id,
          )
        await fetch_async(
          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
          "config.json",
          subdir=shard.model_id,
        )
        model = await fetch_async(
          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
          "model.safetensors.index.json",
          subdir=shard.model_id,
        )
        await fetch_async(
          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
          "special_tokens_map.json",
          subdir=shard.model_id,
        )
        await fetch_async(
          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
          "tokenizer.json",
          subdir=shard.model_id,
        )
        await fetch_async(
          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
          "tokenizer_config.json",
          subdir=shard.model_id,
        )
        size = "8B"
      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
        # size = "70B"
      else:
        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")

    model = build_transformer(model_path, shard=shard, model_size=size)
    model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
    print(f"{model_path=}")
    model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))

@@ -262,5 +199,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
    self.model = model
    self.tokenizer = tokenizer

  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
    pass
  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
    self.progress_callback = progress_callback
extra/download_hf.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import argparse
import asyncio
from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent

DEFAULT_ALLOW_PATTERNS = [
  "*.json",
  "*.py",
  "tokenizer.model",
  "*.tiktoken",
  "*.txt",
  "*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
  ".git",
  ".git/*",
  "*/.git",
  "**/.git/**",
  ".cache/huggingface",
  ".cache/huggingface/*",
  "*/.cache/huggingface",
  "**/.cache/huggingface/**",
]

async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
  async def progress_callback(event: HFRepoProgressEvent):
    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
    print(f"Estimated time remaining: {event.overall_eta}")
    print("File Progress:")
    for file_path, progress in event.file_progress.items():
      status_icon = {
        'not_started': '⚪',
        'in_progress': '🔵',
        'complete': '✅'
      }[progress.status]
      eta_str = str(progress.eta)
      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
    print("\n")

  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")

  args = parser.parse_args()

  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
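With this script in place, a repository can be fetched outside of exo itself, for example (assuming the exo package is installed so exo.inference.hf_helpers resolves):

python extra/download_hf.py --repo-id mlx-community/Meta-Llama-3-8B-Instruct --revision main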
main.py
@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
if args.prometheus_client_port:
  from exo.stats.metrics import start_metrics_server
  start_metrics_server(node, args.prometheus_client_port)
inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))

async def shutdown(signal, loop):
  """Gracefully shutdown the server and close the asyncio loop."""
setup.py
@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
install_requires = [
  "aiohttp==3.9.5",
  "aiohttp_cors==0.7.0",
  "aiofiles==24.1.0",
  "blobfile==2.1.1",
  "grpcio==1.64.1",
  "grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
  "requests==2.32.3",
  "rich==13.7.1",
  "safetensors==0.4.3",
  "tenacity==9.0.0",
  "tiktoken==0.7.0",
  "tokenizers==0.19.1",
  "tqdm==4.66.4",