remove a lot of hf bloat

Alex Cheema
2025-01-27 01:06:47 +00:00
parent b89495f444
commit 1df023023e
10 changed files with 76 additions and 651 deletions

View File

@@ -21,19 +21,17 @@ import numpy as np
import base64
from io import BytesIO
import platform
from exo.download.shard_download import RepoProgressEvent
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import ensure_downloads_dir, delete_model
import tempfile
from exo.apputil import create_animation_mp4
from collections import defaultdict
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
import tempfile
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
from collections import defaultdict
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
@@ -545,35 +543,13 @@ class ChatGPTAPI:
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
model_id = request.match_info.get('model_name')
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
if not model_name or model_name not in model_cards:
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response({"detail": "Could not build shard for model"}, status=400)
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
# Get the HF cache directory using the helper function
hf_home = get_hf_home()
cache_dir = get_repo_root(repo_id)
if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
if os.path.exists(cache_dir):
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
except Exception as e:
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
else:
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")

View File

@@ -1,36 +1,16 @@
import aiofiles.os as aios
from typing import Union
import asyncio
import aiohttp
import json
import os
import sys
import shutil
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
from typing import Callable, Optional, Dict, List, Union
from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, TypeVar, TypedDict
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG, is_frozen
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from typing import Generator, Iterable, TypeVar
from exo.helpers import DEBUG
from exo.inference.shard import Shard
import aiofiles
T = TypeVar("T")
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
refs_dir = get_repo_root(repo_id)/"refs"
refs_file = refs_dir/revision
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
return snapshot_dir
return None
def filter_repo_objects(
items: Iterable[T],
*,
@@ -48,14 +28,12 @@ def filter_repo_objects(
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity
for item in items:
@@ -66,22 +44,18 @@ def filter_repo_objects(
continue
yield item
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == "/":
return pattern + "*"
return pattern
def get_hf_endpoint() -> str:
return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
def get_hf_home() -> Path:
"""Get the Hugging Face home directory."""
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
async def get_hf_token():
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
return (await f.read()).strip()
return None
async def get_auth_headers():
"""Get authentication headers if a token is available."""
token = await get_hf_token()
@@ -98,324 +71,6 @@ async def get_auth_headers():
return {"Authorization": f"Bearer {token}"}
return {}
def get_repo_root(repo_id: str) -> Path:
"""Get the root directory for a given repo ID in the Hugging Face cache."""
sanitized_repo_id = str(repo_id).replace("/", "--")
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
async def move_models_to_hf(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = get_hf_home()/"hub"
await aios.makedirs(dest_dir, exist_ok=True)
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await aios.path.exists(dest_path):
print('Skipping moving model to .cache directory')
else:
try:
await aios.rename(str(path), str(dest_path))
except Exception as e:
print(f'Error moving model to .cache: {e}')
async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
@retry(
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
)
async def download_file(
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
):
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
local_path = os.path.join(save_directory, file_path)
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
# Check if file already exists and get its size
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
headers = await get_auth_headers()
if use_range_request:
headers["Range"] = f"bytes={local_file_size}-"
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = local_file_size
downloaded_this_session = 0
mode = 'ab' if use_range_request else 'wb'
percentage = await get_file_download_percentage(
session,
repo_id,
revision,
file_path,
Path(save_directory)
)
if percentage == 100:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
return
if response.status == 200:
# File doesn't support range requests or we're not using them, start from beginning
mode = 'wb'
downloaded_size = 0
elif response.status == 206:
# Partial content, resume download
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
elif response.status == 416:
# Range not satisfiable, get the actual file size
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
else:
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
if downloaded_size == total_size:
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
DOWNLOAD_CHUNK_SIZE = 32768
start_time = datetime.now()
async with aiofiles.open(local_path, mode) as f:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
await f.write(chunk)
downloaded_size += len(chunk)
downloaded_this_session += len(chunk)
if progress_callback and total_size:
elapsed_time = (datetime.now() - start_time).total_seconds()
speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
remaining_size = total_size - downloaded_size
eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
status = "in_progress" if downloaded_size < total_size else "complete"
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")
async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root/"refs"
refs_file = refs_dir/revision
# Check if we have a cached commit hash
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
return commit_hash
# Fetch the commit hash for the given revision
async with aiohttp.ClientSession() as session:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']
# Cache the commit hash
await aios.makedirs(refs_dir, exist_ok=True)
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)
return commit_hash
async def download_repo_files(
repo_id: str,
revision: str = "main",
progress_callback: Optional[RepoProgressCallback] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_parallel_downloads: int = 4
) -> Path:
repo_root = get_repo_root(repo_id)
snapshots_dir = repo_root/"snapshots"
cachedreqs_dir = repo_root/"cachedreqs"
# Ensure directories exist
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)
# Resolve revision to commit hash
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
# Set up the snapshot directory
snapshot_dir = snapshots_dir/commit_hash
await aios.makedirs(snapshot_dir, exist_ok=True)
# Set up the cached file list directory
cached_file_list_dir = cachedreqs_dir/commit_hash
await aios.makedirs(cached_file_list_dir, exist_ok=True)
cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
async with aiohttp.ClientSession() as session:
# Check if we have a cached file list
if await aios.path.exists(cached_file_list_path):
async with aiofiles.open(cached_file_list_path, 'r') as f:
file_list = json.loads(await f.read())
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
else:
file_list = await fetch_file_list(session, repo_id, revision)
# Cache the file list
async with aiofiles.open(cached_file_list_path, 'w') as f:
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
if model_index_exists:
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)
file_progress: Dict[str, RepoFileProgressEvent] = {
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
for file in filtered_file_list
}
start_time = datetime.now()
async def download_with_progress(file_info, progress_state):
local_path = snapshot_dir/file_info["path"]
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
progress_state['completed_files'] += 1
progress_state['downloaded_bytes'] += file_info["size"]
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
return
async def file_progress_callback(event: RepoFileProgressEvent):
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
file_progress[event.file_path] = event
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
progress_state['completed_files'] += 1
file_progress[
file_info["path"]
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file_info):
async with semaphore:
await download_with_progress(file_info, progress_state)
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
await asyncio.gather(*tasks)
return snapshot_dir
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
"""
Retrieve the weight map from the model.safetensors.index.json file.
Args:
repo_id (str): The Hugging Face repository ID.
revision (str): The revision of the repository to use.
Returns:
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
"""
# Download the index file
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
# Check if the file exists
repo_root = get_repo_root(repo_id)
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
snapshot_dir = repo_root/"snapshots"/commit_hash
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
if index_file:
index_file_path = snapshot_dir/index_file
if await aios.path.exists(index_file_path):
async with aiofiles.open(index_file_path, 'r') as f:
index_data = json.loads(await f.read())
return index_data.get("weight_map")
return None
def extract_layer_num(tensor_name: str) -> Optional[int]:
# This is a simple example and might need to be adjusted based on the actual naming convention
parts = tensor_name.split('.')
@@ -424,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
return int(part)
return None
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
shard_specific_patterns = set()
@@ -442,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
async def get_file_download_percentage(
session: aiohttp.ClientSession,
repo_id: str,
revision: str,
file_path: str,
snapshot_dir: Path,
) -> float:
"""
Calculate the download percentage for a file by comparing local and remote sizes.
"""
try:
local_path = snapshot_dir / file_path
if not await aios.path.exists(local_path):
return 0
# Get local file size first
local_size = await aios.path.getsize(local_path)
if local_size == 0:
return 0
# Check remote size
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
headers = await get_auth_headers()
# Use HEAD request with redirect following for all files
async with session.head(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
if DEBUG >= 2:
print(f"Failed to get remote file info for {file_path}: {response.status}")
return 0
remote_size = int(response.headers.get('Content-Length', 0))
if remote_size == 0:
if DEBUG >= 2:
print(f"Remote size is 0 for {file_path}")
return 0
# Only return 100% if sizes match exactly
if local_size == remote_size:
return 100.0
# Calculate percentage based on sizes
return (local_size / remote_size) * 100 if remote_size > 0 else 0
except Exception as e:
if DEBUG >= 2:
print(f"Error checking file download status for {file_path}: {e}")
return 0
async def has_hf_home_read_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.R_OK)
except OSError: return False
async def has_hf_home_write_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.W_OK)
except OSError: return False

View File

@@ -1,172 +0,0 @@
import asyncio
import traceback
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import (
download_repo_files, RepoProgressEvent, get_weight_map,
get_allow_patterns, get_repo_root, fetch_file_list,
get_local_snapshot_dir, get_file_download_percentage,
filter_repo_objects
)
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import model_cards, get_repo
import aiohttp
from aiofiles import os as aios
class HFShardDownloader(ShardDownloader):
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.quick_check = quick_check
self.max_parallel_downloads = max_parallel_downloads
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
self.current_shard: Optional[Shard] = None
self.current_repo_id: Optional[str] = None
self.revision: str = "main"
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
self.current_shard = shard
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
repo_name = get_repo(shard.model_id, inference_engine_name)
if shard in self.completed_downloads:
return self.completed_downloads[shard]
if self.quick_check:
repo_root = get_repo_root(repo_name)
snapshots_dir = repo_root/"snapshots"
if snapshots_dir.exists():
visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
if visible_dirs:
most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
return most_recent_dir
# If a download on this shard is already in progress, keep that one
for active_shard in self.active_downloads:
if active_shard == shard:
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
return await self.active_downloads[shard]
# Cancel any downloads for this model_id on a different shard
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
for active_shard in existing_active_shards:
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
task = self.active_downloads[active_shard]
task.cancel()
try:
await task
except asyncio.CancelledError:
pass # This is expected when cancelling a task
except Exception as e:
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
traceback.print_exc()
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
# Start new download
download_task = asyncio.create_task(self._download_shard(shard, repo_name))
self.active_downloads[shard] = download_task
try:
path = await download_task
self.completed_downloads[shard] = path
return path
finally:
# Ensure the task is removed even if an exception occurs
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
if shard in self.active_downloads:
self.active_downloads.pop(shard)
async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
async def wrapped_progress_callback(event: RepoProgressEvent):
self._on_progress.trigger_all(shard, event)
weight_map = await get_weight_map(repo_name)
allow_patterns = get_allow_patterns(weight_map, shard)
return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress
async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
if not self.current_shard or not self.current_repo_id:
if DEBUG >= 2:
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
return None
try:
# If no snapshot directory exists, return None - no need to check remote files
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
if not snapshot_dir:
if DEBUG >= 2:
print(f"No snapshot directory found for {self.current_repo_id}")
return None
if not await aios.path.exists(snapshot_dir/"model_index.json"):
# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
else:
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
# Check download status for all relevant files
status = {}
total_bytes = 0
downloaded_bytes = 0
async with aiohttp.ClientSession() as session:
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
relevant_files = list(
filter_repo_objects(
file_list, allow_patterns=patterns, key=lambda x: x["path"]))
for file in relevant_files:
file_size = file["size"]
total_bytes += file_size
percentage = await get_file_download_percentage(
session,
self.current_repo_id,
self.revision,
file["path"],
snapshot_dir,
)
status[file["path"]] = percentage
downloaded_bytes += (file_size * (percentage / 100))
# Add overall progress weighted by file size
if total_bytes > 0:
status["overall"] = (downloaded_bytes / total_bytes) * 100
else:
status["overall"] = 0
# Add total size in bytes
status["total_size"] = total_bytes
if status["overall"] != 100:
status["total_downloaded"] = downloaded_bytes
if DEBUG >= 2:
print(f"Download calculation for {self.current_repo_id}:")
print(f"Total bytes: {total_bytes}")
print(f"Downloaded bytes: {downloaded_bytes}")
if DEBUG >= 3:
for file in relevant_files:
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
return status
except Exception as e:
if DEBUG >= 3:
print(f"Error getting shard download status: {e}")
traceback.print_exc()
return None

View File

@@ -16,15 +16,46 @@ import time
from datetime import timedelta
import asyncio
import json
import traceback
import shutil
def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
async def has_exo_home_read_access() -> bool:
try: return await aios.access(exo_home(), os.R_OK)
except OSError: return False
async def has_exo_home_write_access() -> bool:
try: return await aios.access(exo_home(), os.W_OK)
except OSError: return False
async def ensure_downloads_dir() -> Path:
downloads_dir = exo_home()/"downloads"
await aios.makedirs(downloads_dir, exist_ok=True)
return downloads_dir
async def delete_model(model_id: str, inference_engine_name: str) -> bool:
repo_id = get_repo(model_id, inference_engine_name)
model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
if not await aios.path.exists(model_dir): return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
async def seed_models(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = await ensure_downloads_dir()
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir/path.name
if await aios.path.exists(dest_path): print('Skipping moving model: already present in exo downloads directory')
else:
try: await aios.rename(str(path), str(dest_path))
except Exception:
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()
async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
@@ -44,8 +75,9 @@ async def fetch_file_list(session, repo_id, revision, path=""):
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
if (target_dir/path).exists(): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, path)
headers = await get_auth_headers()
@@ -71,8 +103,7 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
async with aiohttp.ClientSession() as session:
index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f:
index_data = json.loads(await f.read())
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
return index_data.get("weight_map")
async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
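Taken together, the new helpers replace the HF cache hierarchy with one flat layout: everything lives under exo_home()/downloads (EXO_HOME defaults to ~/.cache/exo), with repo IDs flattened by replacing "/" with "--". A small sketch of combining them is below; the model ID and inference engine classname are placeholders, not values taken from this diff.

import asyncio
from exo.download.new_shard_download import exo_home, ensure_downloads_dir, delete_model

async def demo():
  downloads = await ensure_downloads_dir()  # EXO_HOME/downloads, created if missing
  print(f"exo home: {exo_home()}  downloads dir: {downloads}")
  # delete_model resolves the repo for the given engine and removes downloads/<repo id with "/" -> "-->
  # Both arguments below are placeholders for illustration.
  removed = await delete_model("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
  print("deleted" if removed else "nothing to delete")

asyncio.run(demo())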

View File

@@ -1,12 +1,11 @@
import traceback
from aiofiles import os as aios
from os import PathLike
from pathlib import Path
from aiofiles import os as aios
from typing import Union
from transformers import AutoTokenizer, AutoProcessor
import numpy as np
from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG
from exo.download.new_shard_download import ensure_downloads_dir
class DummyTokenizer:
@@ -24,25 +23,25 @@ class DummyTokenizer:
return "dummy" * len(tokens)
async def resolve_tokenizer(model_id: str):
if model_id == "dummy":
async def resolve_tokenizer(repo_id: Union[str, PathLike]):
if repo_id == "dummy":
return DummyTokenizer()
local_path = await get_local_snapshot_dir(model_id)
local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
try:
if local_path and await aios.path.exists(local_path):
if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
return await _resolve_tokenizer(local_path)
except:
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
if DEBUG >= 5: traceback.print_exc()
return await _resolve_tokenizer(model_id)
return await _resolve_tokenizer(repo_id)
async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
try:
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
if not hasattr(processor, 'eos_token_id'):
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
return processor
except Exception as e:
if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())
try:
if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
except Exception as e:
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())
raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")

View File

@@ -23,19 +23,18 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.new_shard_download import new_shard_downloader
from exo.download.shard_download import ShardDownloader, NoopShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
import uvloop
from contextlib import asynccontextmanager
import concurrent.futures
import socket
import resource
import psutil
@@ -321,13 +320,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
async def main():
loop = asyncio.get_running_loop()
# Check HuggingFace directory permissions
hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
# Check exo directory permissions
home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
if DEBUG >= 1: print(f"exo home directory: {home}")
print(f"{has_read=}, {has_write=}")
if not has_read or not has_write:
print(f"""
WARNING: Limited permissions for model storage directory: {hf_home}.
WARNING: Limited permissions for exo home directory: {home}.
This may prevent model downloads from working correctly.
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
@@ -336,9 +335,9 @@ async def main():
if not args.models_seed_dir is None:
try:
models_seed_dir = clean_path(args.models_seed_dir)
await move_models_to_hf(models_seed_dir)
await seed_models(models_seed_dir)
except Exception as e:
print(f"Error moving models to .cache/huggingface: {e}")
print(f"Error seeding models: {e}")
def restore_cursor():
if platform.system() != "Windows":

View File

@@ -13,7 +13,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
from exo import DEBUG
from exo.helpers import AsyncCallbackSystem
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.download.download_progress import RepoProgressEvent
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.download.shard_download import ShardDownloader

View File

@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
from exo.download.download_progress import RepoProgressEvent
def create_hf_repo_progress_event(

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
from exo.topology.topology import Topology
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.download.download_progress import RepoProgressEvent
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
from rich.console import Console, Group
from rich.text import Text