separate hf_helpers into its own module, add an extra/ dir with a download_hf script, and unify downloading so tinygrad uses the same method as mlx and the model formats stay interoperable

Alex Cheema
2024-08-05 23:03:17 +01:00
parent 9014efae86
commit 545a486ed3
10 changed files with 189 additions and 278 deletions
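For orientation, a minimal sketch (not part of the diff) of the unified flow this commit introduces, assuming the engine constructors and hf_helpers types shown in the hunks below:

import asyncio
from exo.inference.hf_helpers import HFRepoProgressEvent
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

async def on_progress(event: HFRepoProgressEvent) -> None:
    # Both engines now report the same event type, so a single callback serves either.
    print(f"{event.completed_files}/{event.total_files} files, "
          f"{event.downloaded_bytes}/{event.total_bytes} bytes, ETA {event.overall_eta}")

engine = MLXDynamicShardInferenceEngine(progress_callback=on_progress)
# ...or equivalently, since both engines now share the same download path:
engine = TinygradDynamicShardInferenceEngine(progress_callback=on_progress)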


@@ -25,11 +25,11 @@ shard_mappings = {
},
"llama-3-8b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
},
### mistral
"mistral-nemo": {
@@ -76,14 +76,6 @@ class ChatCompletionRequest:
}
def resolve_tinygrad_tokenizer(model_id: str):
if model_id == "llama3-8b-sfr":
return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
elif model_id == "llama3-70b-sfr":
return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
else:
raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
async def resolve_tokenizer(model_id: str):
try:
@@ -111,15 +103,6 @@ async def resolve_tokenizer(model_id: str):
if DEBUG >= 2: print(traceback.format_exc())
try:
if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
return resolve_tinygrad_tokenizer(model_id)
except Exception as e:
if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
import traceback
if DEBUG >= 2: print(traceback.format_exc())
if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer


@@ -1,36 +1,17 @@
import asyncio
import aiohttp
import os
import argparse
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, List, TypeVar, Union
from typing import Generator, Iterable, TypeVar, TypedDict
from dataclasses import dataclass
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG
T = TypeVar("T")
DEFAULT_ALLOW_PATTERNS = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
"*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
]
def filter_repo_objects(
items: Iterable[T],
*,
@@ -117,7 +98,35 @@ async def fetch_file_list(session, repo_id, revision, path=""):
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):
@dataclass
class HFRepoFileProgressEvent:
file_path: str
downloaded: int
total: int
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
@dataclass
class HFRepoProgressEvent:
completed_files: int
total_files: int
downloaded_bytes: int
total_bytes: int
overall_eta: timedelta
file_progress: Dict[str, HFRepoFileProgressEvent]
HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
reraise=True
)
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
local_path = os.path.join(save_directory, file_path)
@@ -125,64 +134,72 @@ async def download_file(session, repo_id, revision, file_path, save_directory, p
os.makedirs(os.path.dirname(local_path), exist_ok=True)
# Check if file already exists and get its size
if os.path.exists(local_path):
local_file_size = os.path.getsize(local_path)
else:
local_file_size = 0
local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
headers = get_auth_headers()
if use_range_request:
headers["Range"] = f"bytes={local_file_size}-"
headers = {"Range": f"bytes={local_file_size}-"}
headers.update(get_auth_headers())
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = local_file_size
mode = 'ab' if use_range_request else 'wb'
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
return
if response.status == 200:
# File doesn't support range requests, start from beginning
# File doesn't support range requests or we're not using them, start from beginning
mode = 'wb'
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = 0
elif response.status == 206:
# Partial content, resume download
mode = 'ab'
content_range = response.headers.get('Content-Range')
total_size = int(content_range.split('/')[-1])
downloaded_size = local_file_size
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
elif response.status == 416:
# Range not satisfiable, get the actual file size
if response.headers.get('Content-Type', '').startswith('text/html'):
content = await response.text()
print(f"Response content (HTML):\n{content}")
else:
print(response)
print("Return header: ", response.headers)
print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
if local_file_size == total_size:
print(f"File already fully downloaded: {file_path}")
return
else:
# Start the download from the beginning
mode = 'wb'
downloaded_size = 0
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
if progress_callback:
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
return
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
else:
print(f"Failed to download {file_path}: {response.status}")
return
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
if downloaded_size == total_size:
print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
return
DOWNLOAD_CHUNK_SIZE = 32768
start_time = datetime.now()
new_downloaded_size = 0
with open(local_path, mode) as f:
async for chunk in response.content.iter_chunked(8192):
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
f.write(chunk)
new_downloaded_size += len(chunk)
if progress_callback:
downloaded_size += len(chunk)
if progress_callback and total_size:
elapsed_time = (datetime.now() - start_time).total_seconds()
speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
print(f"Downloaded: {file_path}")
speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
remaining_size = total_size - downloaded_size
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
status = "in_progress" if downloaded_size < total_size else "complete"
await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")
async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
repo_root = get_repo_root(repo_id)
refs_dir = repo_root / "refs"
snapshots_dir = repo_root / "snapshots"
@@ -197,7 +214,7 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
headers = get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info: {response.status}")
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']
@@ -215,68 +232,32 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
completed_files = 0
total_bytes = sum(file["size"] for file in filtered_file_list)
downloaded_bytes = 0
new_downloaded_bytes = 0
file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
start_time = datetime.now()
async def download_with_progress(file_info):
nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress
nonlocal completed_files, downloaded_bytes, file_progress
async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
file_progress[path].update({
'status': 'in_progress',
'downloaded': file_downloaded,
'total': file_total,
'speed': speed,
'eta': file_eta
})
async def file_progress_callback(event: HFRepoFileProgressEvent):
nonlocal downloaded_bytes, file_progress
downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
file_progress[event.file_path] = event
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
completed_files += 1
file_progress[file_info["path"]]['status'] = 'complete'
file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
await asyncio.gather(*tasks)
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
print(f"Estimated time remaining: {overall_eta}")
print("File Progress:")
for file_path, progress in file_progress.items():
status_icon = {
'not_started': '',
'in_progress': '🔵',
'complete': ''
}[progress['status']]
eta_str = str(progress.get('eta', 'N/A'))
print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
print("\n")
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")
args = parser.parse_args()
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
return snapshot_dir
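The download_file changes above boil down to a resumable HTTP Range request with a fallback to a full re-download. A stripped-down sketch of that pattern (not part of the commit; the URL and path are placeholders):

import asyncio
import os
import aiohttp

async def resume_download(url: str, local_path: str) -> None:
    # Ask the server to continue from whatever is already on disk.
    existing = os.path.getsize(local_path) if os.path.exists(local_path) else 0
    headers = {"Range": f"bytes={existing}-"}
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as resp:
            if resp.status == 206:        # partial content: append to the local file
                mode = "ab"
            elif resp.status == 200:      # range not honoured: start from scratch
                mode = "wb"
            else:
                resp.raise_for_status()
                return
            with open(local_path, mode) as f:
                async for chunk in resp.content.iter_chunked(32768):
                    f.write(chunk)

# Example (placeholder URL):
# asyncio.run(resume_download("https://huggingface.co/<repo>/resolve/main/config.json", "/tmp/config.json"))

download_file layers the tenacity retry decorator and the per-file progress events on top of this core.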


@@ -1,9 +1,9 @@
import numpy as np
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, Coroutine, Any
from abc import ABC, abstractmethod
from .shard import Shard
from exo.inference.hf_helpers import HFRepoProgressEvent
class InferenceEngine(ABC):
@abstractmethod
@@ -15,5 +15,5 @@ class InferenceEngine(ABC):
pass
@abstractmethod
def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
pass


@@ -5,12 +5,13 @@ from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard, get_image_from_str
from ..shard import Shard
from typing import Optional, Callable
from exo.inference.hf_helpers import HFRepoProgressCallback
class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, on_download_progress: Callable[[int, int], None] = None):
def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
self.shard = None
self.on_download_progress = on_download_progress
self.progress_callback = progress_callback
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
@@ -33,9 +34,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
if self.shard == shard:
return
model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
self.shard = shard
def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
self.on_download_progress = on_download_progress
def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
self.progress_callback = progress_callback


@@ -12,10 +12,7 @@ from typing import Optional, Tuple, Union, List, Callable
from PIL import Image
from io import BytesIO
import base64
import os
import concurrent.futures
from exo import DEBUG
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
@@ -28,6 +25,8 @@ from transformers import AutoProcessor
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers
from exo import DEBUG
from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
from ..shard import Shard
@@ -164,52 +163,6 @@ def load_model_shard(
return model
async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
while True:
try:
await asyncio.sleep(0.1)
current_size = sum(os.path.getsize(os.path.join(root, file))
for root, _, files in os.walk(dir)
for file in files)
progress = min(current_size / total_size * 100, 100)
if print_progress:
print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
if on_progress:
on_progress(current_size, total_size)
if progress >= 100:
if print_progress:
print("\nDownload complete!")
break
except Exception as e:
print(f"Error monitoring progress: {e}")
async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
with concurrent.futures.ThreadPoolExecutor() as pool:
return await asyncio.get_event_loop().run_in_executor(
pool,
partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
)
async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
# os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
# os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
total_size = await get_repo_size(repo_id)
# Create tasks for download and progress checking
download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
# Wait for both tasks to complete
result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
return result[0] # Return the result from download_task
repo_id_safetensors_layers = {
"mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
"model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
@@ -313,7 +266,7 @@ def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
@@ -329,7 +282,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
if not model_path.exists():
try:
model_path = Path(
await download_async_with_progress(
await download_all_files(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
@@ -339,7 +292,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
"*.tiktoken",
"*.txt",
] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
on_progress=on_download_progress,
progress_callback=progress_callback,
)
)
except RepositoryNotFoundError:
@@ -360,7 +313,7 @@ async def load_shard(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
on_download_progress: Callable[[int, int], None] = None,
progress_callback: Optional[HFRepoProgressCallback] = None,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -383,7 +336,7 @@ async def load_shard(
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)
model = load_model_shard(model_path, shard, lazy, model_config)
if adapter_path is not None:
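Because get_model_path now delegates to download_all_files, the shard-aware filtering happens purely through allow_patterns. A hedged sketch of that composition (shard values illustrative, not part of the commit):

import asyncio
from exo.inference.shard import Shard
from exo.inference.hf_helpers import download_all_files
from exo.inference.mlx.sharded_utils import get_safetensors_allow_patterns

shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
              start_layer=0, end_layer=15, n_layers=32)

# Config/tokenizer files plus only the safetensors files covering this shard's layers.
allow_patterns = ["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"]
allow_patterns += get_safetensors_allow_patterns(shard.model_id, shard)

snapshot_dir = asyncio.run(download_all_files(shard.model_id, "main", allow_patterns=allow_patterns))
print(snapshot_dir)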


@@ -1,3 +1,4 @@
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
@@ -40,17 +41,17 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
assert np.array_equal(next_resp_full, resp4)
asyncio.run(
test_inference_engine(
MLXDynamicShardInferenceEngine(),
MLXDynamicShardInferenceEngine(),
"mlx-community/Meta-Llama-3-8B-Instruct-4bit",
)
)
# asyncio.run(
# test_inference_engine(
# MLXDynamicShardInferenceEngine(),
# MLXDynamicShardInferenceEngine(),
# "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
# )
# )
# TODO: Need more memory or a smaller model
# asyncio.run(test_inference_engine(
# TinygradDynamicShardInferenceEngine(),
# TinygradDynamicShardInferenceEngine(),
# "llama3-8b-sfr",
# ))
asyncio.run(test_inference_engine(
TinygradDynamicShardInferenceEngine(),
TinygradDynamicShardInferenceEngine(),
"llama3-8b-sfr",
))


@@ -1,9 +1,7 @@
import asyncio
from functools import partial
from pathlib import Path
from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, Coroutine, Any
import json
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -12,7 +10,7 @@ from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
import numpy as np
import os
from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root
MODEL_PARAMS = {
"8B": {
@@ -46,15 +44,6 @@ MODEL_PARAMS = {
# **** helper functions ****
async def fetch_async(
url: str,
name: Optional[Union[Path, str]] = None,
subdir: Optional[str] = None,
allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
) -> Path:
func = partial(fetch, url, name, subdir, allow_caching)
return await asyncio.get_event_loop().run_in_executor(None, func)
def concat_weights(models, device=None):
def convert(name) -> Tensor:
@@ -159,8 +148,9 @@ def prefill(model, toks, start_pos=0):
class TinygradDynamicShardInferenceEngine(InferenceEngine):
def __init__(self):
def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
self.shard = None
self.progress_callback = progress_callback
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
# TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -199,62 +189,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
if self.shard == shard:
return
model_path = Path(shard.model_id)
models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
model_path = models_dir / shard.model_id
size = "8B"
if Path(model_path / "tokenizer_config.json").exists():
model = model_path
else:
if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
if shard.model_id.lower().find("llama3-8b-sfr") != -1:
num_files = 4
for i in range(num_files):
await fetch_async(
f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
subdir=shard.model_id,
)
await fetch_async(
"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
"config.json",
subdir=shard.model_id,
)
model = await fetch_async(
"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
"model.safetensors.index.json",
subdir=shard.model_id,
)
await fetch_async(
"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
"special_tokens_map.json",
subdir=shard.model_id,
)
await fetch_async(
"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
"tokenizer.json",
subdir=shard.model_id,
)
await fetch_async(
"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
"tokenizer_config.json",
subdir=shard.model_id,
)
size = "8B"
elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
# fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
# fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
# fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
# fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
# fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
# model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
# size = "70B"
else:
raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
model = build_transformer(model_path, shard=shard, model_size=size)
model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
print(f"{model_path=}")
model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
@@ -262,5 +199,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
self.model = model
self.tokenizer = tokenizer
def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
pass
def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
self.progress_callback = progress_callback

extra/download_hf.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import argparse
import asyncio
from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent
DEFAULT_ALLOW_PATTERNS = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
"*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
]
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
async def progress_callback(event: HFRepoProgressEvent):
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
print(f"Estimated time remaining: {event.overall_eta}")
print("File Progress:")
for file_path, progress in event.file_progress.items():
status_icon = {
'not_started': '',
'in_progress': '🔵',
'complete': ''
}[progress.status]
eta_str = str(progress.eta)
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
print("\n")
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
args = parser.parse_args()
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
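Assuming it is run from the repository root, the script can be invoked as, for example, python extra/download_hf.py --repo-id mlx-community/Meta-Llama-3-8B-Instruct, printing per-file progress while downloading into the same snapshot layout the inference engines read from.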


@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
if args.prometheus_client_port:
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)
inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""


@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
install_requires = [
"aiohttp==3.9.5",
"aiohttp_cors==0.7.0",
"aiofiles==24.1.0",
"blobfile==2.1.1",
"grpcio==1.64.1",
"grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
"requests==2.32.3",
"rich==13.7.1",
"safetensors==0.4.3",
"tenacity==9.0.0",
"tiktoken==0.7.0",
"tokenizers==0.19.1",
"tqdm==4.66.4",