reformat with yapf format.py

This commit is contained in:
Alex Cheema
2024-08-22 14:05:43 +01:00
parent 2e27076665
commit ea70c9fb76
48 changed files with 1810 additions and 1789 deletions

View File

@@ -13,8 +13,8 @@ import argparse
import uuid
models = {
"mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
"mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
}
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
@@ -29,60 +29,53 @@ tokenizer = load_tokenizer(model_path, tokenizer_config)
# "localhost:8080",
# DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
# )
peer2 = GRPCPeerHandle(
"node2",
"localhost:8081",
DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
)
peer2 = GRPCPeerHandle("node2", "localhost:8081", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
shard = models[path_or_hf_repo]
request_id = str(uuid.uuid4())
async def run_prompt(prompt: str):
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if (hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None):
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
await peer2.connect()
await peer2.connect()
try:
await peer2.send_prompt(shard, prompt, request_id)
except Exception as e:
print(e)
import time
# poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
previous_length = 0
n_tokens = 0
start_time = time.perf_counter()
while True:
try:
await peer2.send_prompt(shard, prompt, request_id)
result, is_finished = await peer2.get_inference_result(request_id)
except Exception as e:
print(e)
continue
await asyncio.sleep(0.1)
import time
# poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
previous_length = 0
n_tokens = 0
start_time = time.perf_counter()
while True:
try:
result, is_finished = await peer2.get_inference_result(request_id)
except Exception as e:
continue
await asyncio.sleep(0.1)
# Print the updated string in place
updated_string = tokenizer.decode(result)
n_tokens = len(result)
print(updated_string[previous_length:], end='', flush=True)
previous_length = len(updated_string)
# Print the updated string in place
updated_string = tokenizer.decode(result)
n_tokens = len(result)
print(updated_string[previous_length:], end='', flush=True)
previous_length = len(updated_string)
if is_finished:
print("\nDone")
break
end_time = time.perf_counter()
print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
if is_finished:
print("\nDone")
break
end_time = time.perf_counter()
print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run prompt")
parser.add_argument("--prompt", type=str, help="The prompt to run")
args = parser.parse_args()
parser = argparse.ArgumentParser(description="Run prompt")
parser.add_argument("--prompt", type=str, help="The prompt to run")
args = parser.parse_args()
asyncio.run(run_prompt(args.prompt))
asyncio.run(run_prompt(args.prompt))

View File

@@ -16,29 +16,27 @@ from exo.orchestration import Node
from exo.models import model_base_shards
from typing import Callable
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
self.role = role
self.content = content
def to_dict(self):
return {
"role": self.role,
"content": self.content
}
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
self.role = role
self.content = content
def to_dict(self):
return {"role": self.role, "content": self.content}
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
self.model = model
self.messages = messages
self.temperature = temperature
def to_dict(self):
return {
"model": self.model,
"messages": [message.to_dict() for message in self.messages],
"temperature": self.temperature
}
def __init__(self, model: str, messages: List[Message], temperature: float):
self.model = model
self.messages = messages
self.temperature = temperature
def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
def generate_completion(
chat_request: ChatCompletionRequest,
@@ -56,14 +54,12 @@ def generate_completion(
"created": int(time.time()),
"model": chat_request.model,
"system_fingerprint": f"exo_{VERSION}",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": tokenizer.decode(tokens)},
"logprobs": None,
"finish_reason": finish_reason,
}
],
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": tokenizer.decode(tokens)},
"logprobs": None,
"finish_reason": finish_reason,
}],
}
if not stream:
@@ -86,37 +82,38 @@ def generate_completion(
def remap_messages(messages: List[Message]) -> List[Message]:
remapped_messages = []
last_image = None
for message in messages:
if not isinstance(message.content, list):
remapped_messages.append(message)
continue
remapped_messages = []
last_image = None
for message in messages:
if not isinstance(message.content, list):
remapped_messages.append(message)
continue
remapped_content = []
for content in message.content:
if isinstance(content, dict):
if content.get("type") in ["image_url", "image"]:
image_url = content.get("image_url", {}).get("url") or content.get("image")
if image_url:
last_image = {"type": "image", "image": image_url}
remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
else:
remapped_content.append(content)
else:
remapped_content.append(content)
remapped_messages.append(Message(role=message.role, content=remapped_content))
remapped_content = []
for content in message.content:
if isinstance(content, dict):
if content.get("type") in ["image_url", "image"]:
image_url = content.get("image_url", {}).get("url") or content.get("image")
if image_url:
last_image = {"type": "image", "image": image_url}
remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
else:
remapped_content.append(content)
else:
remapped_content.append(content)
remapped_messages.append(Message(role=message.role, content=remapped_content))
if last_image:
# Replace the last image placeholder with the actual image content
for message in reversed(remapped_messages):
for i, content in enumerate(message.content):
if isinstance(content, dict):
if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
message.content[i] = last_image
return remapped_messages
if last_image:
# Replace the last image placeholder with the actual image content
for message in reversed(remapped_messages):
for i, content in enumerate(message.content):
if isinstance(content, dict):
if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
message.content[i] = last_image
return remapped_messages
return remapped_messages
return remapped_messages
def build_prompt(tokenizer, _messages: List[Message]):
messages = remap_messages(_messages)
@@ -149,13 +146,17 @@ def parse_chat_request(data: dict):
data.get("temperature", 0.0),
)
class PromptSession:
def __init__(self, request_id: str, timestamp: int, prompt: str):
self.request_id = request_id
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
@@ -182,6 +183,7 @@ class ChatGPTAPI:
self.app.middlewares.append(self.log_request)
async def log_request(self, app, handler):
async def middleware(request):
if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
return await handler(request)
@@ -268,7 +270,8 @@ class ChatGPTAPI:
self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if is_finished:

View File

@@ -2,81 +2,67 @@ from typing import Dict, Callable, Coroutine, Any, Literal
from dataclasses import dataclass
from datetime import timedelta
@dataclass
class RepoFileProgressEvent:
repo_id: str
repo_revision: str
file_path: str
downloaded: int
downloaded_this_session: int
total: int
speed: int
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
repo_id: str
repo_revision: str
file_path: str
downloaded: int
downloaded_this_session: int
total: int
speed: int
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
def to_dict(self):
return {
"repo_id": self.repo_id,
"repo_revision": self.repo_revision,
"file_path": self.file_path,
"downloaded": self.downloaded,
"downloaded_this_session": self.downloaded_this_session,
"total": self.total,
"speed": self.speed,
"eta": self.eta.total_seconds(),
"status": self.status
}
def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
}
@classmethod
def from_dict(cls, data):
# Convert eta from seconds back to timedelta
if 'eta' in data:
data['eta'] = timedelta(seconds=data['eta'])
return cls(**data)
@classmethod
def from_dict(cls, data):
# Convert eta from seconds back to timedelta
if 'eta' in data:
data['eta'] = timedelta(seconds=data['eta'])
return cls(**data)
@dataclass
class RepoProgressEvent:
repo_id: str
repo_revision: str
completed_files: int
total_files: int
downloaded_bytes: int
downloaded_bytes_this_session: int
total_bytes: int
overall_speed: int
overall_eta: timedelta
file_progress: Dict[str, RepoFileProgressEvent]
status: Literal["not_started", "in_progress", "complete"]
repo_id: str
repo_revision: str
completed_files: int
total_files: int
downloaded_bytes: int
downloaded_bytes_this_session: int
total_bytes: int
overall_speed: int
overall_eta: timedelta
file_progress: Dict[str, RepoFileProgressEvent]
status: Literal["not_started", "in_progress", "complete"]
def to_dict(self):
return {
"repo_id": self.repo_id,
"repo_revision": self.repo_revision,
"completed_files": self.completed_files,
"total_files": self.total_files,
"downloaded_bytes": self.downloaded_bytes,
"downloaded_bytes_this_session": self.downloaded_bytes_this_session,
"total_bytes": self.total_bytes,
"overall_speed": self.overall_speed,
"overall_eta": self.overall_eta.total_seconds(),
"file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
"status": self.status
}
def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
"file_progress": {k: v.to_dict()
for k, v in self.file_progress.items()}, "status": self.status
}
@classmethod
def from_dict(cls, data):
# Convert overall_eta from seconds back to timedelta
if 'overall_eta' in data:
data['overall_eta'] = timedelta(seconds=data['overall_eta'])
@classmethod
def from_dict(cls, data):
# Convert overall_eta from seconds back to timedelta
if 'overall_eta' in data:
data['overall_eta'] = timedelta(seconds=data['overall_eta'])
# Parse file_progress
if 'file_progress' in data:
data['file_progress'] = {
k: RepoFileProgressEvent.from_dict(v)
for k, v in data['file_progress'].items()
}
# Parse file_progress
if 'file_progress' in data:
data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
return cls(**data)
return cls(**data)
RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

View File

@@ -16,282 +16,322 @@ import aiofiles
from aiofiles import os as aios
T = TypeVar("T")
def filter_repo_objects(
items: Iterable[T],
*,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
key: Optional[Callable[[T], str]] = None,
items: Iterable[T],
*,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
key: Optional[Callable[[T], str]] = None,
) -> Generator[T, None, None]:
if isinstance(allow_patterns, str):
allow_patterns = [allow_patterns]
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if allow_patterns is not None:
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
if ignore_patterns is not None:
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if isinstance(allow_patterns, str):
allow_patterns = [allow_patterns]
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if allow_patterns is not None:
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
if ignore_patterns is not None:
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity
for item in items:
path = key(item)
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
continue
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
continue
yield item
for item in items:
path = key(item)
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
continue
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
continue
yield item
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == "/":
return pattern + "*"
return pattern
if pattern[-1] == "/":
return pattern + "*"
return pattern
def get_hf_home() -> Path:
"""Get the Hugging Face home directory."""
return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
"""Get the Hugging Face home directory."""
return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
async def get_hf_token():
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
token_path = get_hf_home() / "token"
if await aios.path.exists(token_path):
async with aiofiles.open(token_path, 'r') as f:
return (await f.read()).strip()
return None
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
token_path = get_hf_home() / "token"
if await aios.path.exists(token_path):
async with aiofiles.open(token_path, 'r') as f:
return (await f.read()).strip()
return None
async def get_auth_headers():
"""Get authentication headers if a token is available."""
token = await get_hf_token()
if token:
return {"Authorization": f"Bearer {token}"}
return {}
"""Get authentication headers if a token is available."""
token = await get_hf_token()
if token:
return {"Authorization": f"Bearer {token}"}
return {}
def get_repo_root(repo_id: str) -> Path:
"""Get the root directory for a given repo ID in the Hugging Face cache."""
sanitized_repo_id = repo_id.replace("/", "--")
return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
"""Get the root directory for a given repo ID in the Hugging Face cache."""
sanitized_repo_id = repo_id.replace("/", "--")
return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
headers = await get_auth_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
reraise=True
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
)
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
local_path = os.path.join(save_directory, file_path)
async def download_file(
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
):
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
local_path = os.path.join(save_directory, file_path)
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
# Check if file already exists and get its size
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
# Check if file already exists and get its size
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
headers = await get_auth_headers()
if use_range_request:
headers["Range"] = f"bytes={local_file_size}-"
headers = await get_auth_headers()
if use_range_request:
headers["Range"] = f"bytes={local_file_size}-"
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = local_file_size
downloaded_this_session = 0
mode = 'ab' if use_range_request else 'wb'
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = local_file_size
downloaded_this_session = 0
mode = 'ab' if use_range_request else 'wb'
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
if response.status == 200:
# File doesn't support range requests or we're not using them, start from beginning
mode = 'wb'
downloaded_size = 0
elif response.status == 206:
# Partial content, resume download
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
elif response.status == 416:
# Range not satisfiable, get the actual file size
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
if response.status == 200:
# File doesn't support range requests or we're not using them, start from beginning
mode = 'wb'
downloaded_size = 0
elif response.status == 206:
# Partial content, resume download
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
elif response.status == 416:
# Range not satisfiable, get the actual file size
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
else:
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
if downloaded_size == total_size:
print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
DOWNLOAD_CHUNK_SIZE = 32768
start_time = datetime.now()
async with aiofiles.open(local_path, mode) as f:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
await f.write(chunk)
downloaded_size += len(chunk)
downloaded_this_session += len(chunk)
if progress_callback and total_size:
elapsed_time = (datetime.now() - start_time).total_seconds()
speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
remaining_size = total_size - downloaded_size
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
status = "in_progress" if downloaded_size < total_size else "complete"
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")
async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root / "refs"
snapshots_dir = repo_root / "snapshots"
cachedreqs_dir = repo_root / "cachedreqs"
# Ensure directories exist
await aios.makedirs(refs_dir, exist_ok=True)
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)
# Check if we have a cached commit hash
refs_file = refs_dir / revision
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
else:
async with aiohttp.ClientSession() as session:
# Fetch the commit hash for the given revision
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
# Cache the commit hash
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)
if downloaded_size == total_size:
print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
# Set up the snapshot directory
snapshot_dir = snapshots_dir / commit_hash
await aios.makedirs(snapshot_dir, exist_ok=True)
DOWNLOAD_CHUNK_SIZE = 32768
start_time = datetime.now()
async with aiofiles.open(local_path, mode) as f:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
await f.write(chunk)
downloaded_size += len(chunk)
downloaded_this_session += len(chunk)
if progress_callback and total_size:
elapsed_time = (datetime.now() - start_time).total_seconds()
speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
remaining_size = total_size - downloaded_size
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
status = "in_progress" if downloaded_size < total_size else "complete"
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")
# Set up the cached file list directory
cached_file_list_dir = cachedreqs_dir / commit_hash
await aios.makedirs(cached_file_list_dir, exist_ok=True)
cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
async def download_repo_files(
repo_id: str,
revision: str = "main",
progress_callback: Optional[RepoProgressCallback] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_parallel_downloads: int = 4
) -> Path:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root / "refs"
snapshots_dir = repo_root / "snapshots"
cachedreqs_dir = repo_root / "cachedreqs"
# Ensure directories exist
await aios.makedirs(refs_dir, exist_ok=True)
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)
# Check if we have a cached commit hash
refs_file = refs_dir / revision
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
else:
async with aiohttp.ClientSession() as session:
# Check if we have a cached file list
if await aios.path.exists(cached_file_list_path):
async with aiofiles.open(cached_file_list_path, 'r') as f:
file_list = json.loads(await f.read())
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
else:
file_list = await fetch_file_list(session, repo_id, revision)
# Cache the file list
async with aiofiles.open(cached_file_list_path, 'w') as f:
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
# Fetch the commit hash for the given revision
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)
file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
start_time = datetime.now()
# Cache the commit hash
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)
async def download_with_progress(file_info, progress_state):
local_path = snapshot_dir / file_info["path"]
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
progress_state['completed_files'] += 1
progress_state['downloaded_bytes'] += file_info["size"]
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
return
# Set up the snapshot directory
snapshot_dir = snapshots_dir / commit_hash
await aios.makedirs(snapshot_dir, exist_ok=True)
async def file_progress_callback(event: RepoFileProgressEvent):
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
file_progress[event.file_path] = event
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
# Set up the cached file list directory
cached_file_list_dir = cachedreqs_dir / commit_hash
await aios.makedirs(cached_file_list_dir, exist_ok=True)
cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
progress_state['completed_files'] += 1
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
async with aiohttp.ClientSession() as session:
# Check if we have a cached file list
if await aios.path.exists(cached_file_list_path):
async with aiofiles.open(cached_file_list_path, 'r') as f:
file_list = json.loads(await f.read())
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
else:
file_list = await fetch_file_list(session, repo_id, revision)
# Cache the file list
async with aiofiles.open(cached_file_list_path, 'w') as f:
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)
file_progress: Dict[str, RepoFileProgressEvent] = {
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
for file in filtered_file_list
}
start_time = datetime.now()
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file_info):
async with semaphore:
await download_with_progress(file_info, progress_state)
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
await asyncio.gather(*tasks)
async def download_with_progress(file_info, progress_state):
local_path = snapshot_dir / file_info["path"]
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
progress_state['completed_files'] += 1
progress_state['downloaded_bytes'] += file_info["size"]
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
return
async def file_progress_callback(event: RepoFileProgressEvent):
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
file_progress[event.file_path] = event
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
progress_state['completed_files'] += 1
file_progress[
file_info["path"]
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file_info):
async with semaphore:
await download_with_progress(file_info, progress_state)
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
await asyncio.gather(*tasks)
return snapshot_dir
return snapshot_dir
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
"""
"""
Retrieve the weight map from the model.safetensors.index.json file.
Args:
@@ -302,55 +342,52 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
"""
# Download the index file
await download_repo_files(
repo_id=repo_id,
revision=revision,
allow_patterns="model.safetensors.index.json"
)
# Download the index file
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
# Check if the file exists
repo_root = get_repo_root(repo_id)
snapshot_dir = repo_root / "snapshots"
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
# Check if the file exists
repo_root = get_repo_root(repo_id)
snapshot_dir = repo_root / "snapshots"
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
if index_file:
index_file_path = snapshot_dir / index_file
if await aios.path.exists(index_file_path):
async with aiofiles.open(index_file_path, 'r') as f:
index_data = json.loads(await f.read())
return index_data.get("weight_map")
if index_file:
index_file_path = snapshot_dir / index_file
if await aios.path.exists(index_file_path):
async with aiofiles.open(index_file_path, 'r') as f:
index_data = json.loads(await f.read())
return index_data.get("weight_map")
return None
return None
def extract_layer_num(tensor_name: str) -> Optional[int]:
# This is a simple example and might need to be adjusted based on the actual naming convention
parts = tensor_name.split('.')
for part in parts:
if part.isdigit():
return int(part)
return None
# This is a simple example and might need to be adjusted based on the actual naming convention
parts = tensor_name.split('.')
for part in parts:
if part.isdigit():
return int(part)
return None
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
default_patterns = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
]
shard_specific_patterns = []
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
shard_specific_patterns.append(filename)
sorted_file_names = sorted(weight_map.values())
if shard.is_first_layer():
shard_specific_patterns.append(sorted_file_names[0])
elif shard.is_last_layer():
shard_specific_patterns.append(sorted_file_names[-1])
else:
shard_specific_patterns = ["*.safetensors"]
return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates
default_patterns = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
]
shard_specific_patterns = []
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
shard_specific_patterns.append(filename)
sorted_file_names = sorted(weight_map.values())
if shard.is_first_layer():
shard_specific_patterns.append(sorted_file_names[0])
elif shard.is_last_layer():
shard_specific_patterns.append(sorted_file_names[-1])
else:
shard_specific_patterns = ["*.safetensors"]
return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates

View File

@@ -8,72 +8,70 @@ from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
from exo.helpers import AsyncCallbackSystem, DEBUG
class HFShardDownloader(ShardDownloader):
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.quick_check = quick_check
self.max_parallel_downloads = max_parallel_downloads
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
async def ensure_shard(self, shard: Shard) -> Path:
if shard in self.completed_downloads:
return self.completed_downloads[shard]
if self.quick_check:
repo_root = get_repo_root(shard.model_id)
snapshots_dir = repo_root / "snapshots"
if snapshots_dir.exists():
most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
return most_recent_dir
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.quick_check = quick_check
self.max_parallel_downloads = max_parallel_downloads
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
# If a download on this shard is already in progress, keep that one
for active_shard in self.active_downloads:
if active_shard == shard:
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
return await self.active_downloads[shard]
async def ensure_shard(self, shard: Shard) -> Path:
if shard in self.completed_downloads:
return self.completed_downloads[shard]
if self.quick_check:
repo_root = get_repo_root(shard.model_id)
snapshots_dir = repo_root / "snapshots"
if snapshots_dir.exists():
most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
return most_recent_dir
# Cancel any downloads for this model_id on a different shard
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
for active_shard in existing_active_shards:
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
task = self.active_downloads[active_shard]
task.cancel()
try:
await task
except asyncio.CancelledError:
pass # This is expected when cancelling a task
except Exception as e:
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
traceback.print_exc()
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
# If a download on this shard is already in progress, keep that one
for active_shard in self.active_downloads:
if active_shard == shard:
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
return await self.active_downloads[shard]
# Start new download
download_task = asyncio.create_task(self._download_shard(shard))
self.active_downloads[shard] = download_task
try:
path = await download_task
self.completed_downloads[shard] = path
return path
finally:
# Ensure the task is removed even if an exception occurs
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
if shard in self.active_downloads:
self.active_downloads.pop(shard)
# Cancel any downloads for this model_id on a different shard
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
for active_shard in existing_active_shards:
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
task = self.active_downloads[active_shard]
task.cancel()
try:
await task
except asyncio.CancelledError:
pass # This is expected when cancelling a task
except Exception as e:
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
traceback.print_exc()
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
async def _download_shard(self, shard: Shard) -> Path:
async def wrapped_progress_callback(event: RepoProgressEvent):
self._on_progress.trigger_all(shard, event)
# Start new download
download_task = asyncio.create_task(self._download_shard(shard))
self.active_downloads[shard] = download_task
try:
path = await download_task
self.completed_downloads[shard] = path
return path
finally:
# Ensure the task is removed even if an exception occurs
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
if shard in self.active_downloads:
self.active_downloads.pop(shard)
weight_map = await get_weight_map(shard.model_id)
allow_patterns = get_allow_patterns(weight_map, shard)
async def _download_shard(self, shard: Shard) -> Path:
return await download_repo_files(
repo_id=shard.model_id,
progress_callback=wrapped_progress_callback,
allow_patterns=allow_patterns,
max_parallel_downloads=self.max_parallel_downloads
)
async def wrapped_progress_callback(event: RepoProgressEvent):
self._on_progress.trigger_all(shard, event)
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress
weight_map = await get_weight_map(shard.model_id)
allow_patterns = get_allow_patterns(weight_map, shard)
return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress

View File

@@ -5,10 +5,12 @@ from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(self, shard: Shard) -> Path:
"""
@abstractmethod
async def ensure_shard(self, shard: Shard) -> Path:
"""
Ensures that the shard is downloaded.
Does not allow multiple overlapping downloads at once.
If you try to download a Shard which overlaps a Shard that is already being downloaded,
@@ -17,9 +19,9 @@ class ShardDownloader(ABC):
Args:
shard (Shard): The shard to download.
"""
pass
pass
@property
@abstractmethod
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
pass
@property
@abstractmethod
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
pass

View File

@@ -20,6 +20,7 @@ exo_text = r"""
\___/_/\_\___/
"""
def get_system_info():
if psutil.MACOS:
if platform.machine() == "arm64":
@@ -87,7 +88,10 @@ def terminal_link(uri, label=None):
T = TypeVar("T")
K = TypeVar("K")
class AsyncCallback(Generic[T]):
def __init__(self) -> None:
self.condition: asyncio.Condition = asyncio.Condition()
self.result: Optional[Tuple[T, ...]] = None
@@ -95,9 +99,7 @@ class AsyncCallback(Generic[T]):
async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
async with self.condition:
await asyncio.wait_for(
self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
)
await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
assert self.result is not None # for type checking
return self.result
@@ -116,6 +118,7 @@ class AsyncCallback(Generic[T]):
class AsyncCallbackSystem(Generic[K, T]):
def __init__(self) -> None:
self.callbacks: Dict[K, AsyncCallback[T]] = {}
@@ -139,89 +142,97 @@ class AsyncCallbackSystem(Generic[K, T]):
K = TypeVar('K', bound=str)
V = TypeVar('V')
class PrefixDict(Generic[K, V]):
def __init__(self):
self.items: Dict[K, V] = {}
def add(self, key: K, value: V) -> None:
self.items[key] = value
def __init__(self):
self.items: Dict[K, V] = {}
def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
def add(self, key: K, value: V) -> None:
self.items[key] = value
def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
matches = self.find_prefix(argument)
if len(matches) == 0:
return None
def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
matches = self.find_prefix(argument)
if len(matches) == 0:
return None
return max(matches, key=lambda x: len(x[0]))
return max(matches, key=lambda x: len(x[0]))
def is_valid_uuid(val):
try:
uuid.UUID(str(val))
return True
except ValueError:
return False
try:
uuid.UUID(str(val))
return True
except ValueError:
return False
def get_or_create_node_id():
NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
try:
if NODE_ID_FILE.is_file():
with open(NODE_ID_FILE, "r") as f:
stored_id = f.read().strip()
if is_valid_uuid(stored_id):
if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
return stored_id
else:
if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
try:
if NODE_ID_FILE.is_file():
with open(NODE_ID_FILE, "r") as f:
stored_id = f.read().strip()
if is_valid_uuid(stored_id):
if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
return stored_id
else:
if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
new_id = str(uuid.uuid4())
with open(NODE_ID_FILE, "w") as f:
f.write(new_id)
new_id = str(uuid.uuid4())
with open(NODE_ID_FILE, "w") as f:
f.write(new_id)
if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
return new_id
except IOError as e:
if DEBUG >= 2: print(f"IO error creating node_id: {e}")
return str(uuid.uuid4())
except Exception as e:
if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
return str(uuid.uuid4())
if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
return new_id
except IOError as e:
if DEBUG >= 2: print(f"IO error creating node_id: {e}")
return str(uuid.uuid4())
except Exception as e:
if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
return str(uuid.uuid4())
def pretty_print_bytes(size_in_bytes: int) -> str:
if size_in_bytes < 1024:
return f"{size_in_bytes} B"
elif size_in_bytes < 1024 ** 2:
return f"{size_in_bytes / 1024:.2f} KB"
elif size_in_bytes < 1024 ** 3:
return f"{size_in_bytes / (1024 ** 2):.2f} MB"
elif size_in_bytes < 1024 ** 4:
return f"{size_in_bytes / (1024 ** 3):.2f} GB"
else:
return f"{size_in_bytes / (1024 ** 4):.2f} TB"
if size_in_bytes < 1024:
return f"{size_in_bytes} B"
elif size_in_bytes < 1024**2:
return f"{size_in_bytes / 1024:.2f} KB"
elif size_in_bytes < 1024**3:
return f"{size_in_bytes / (1024 ** 2):.2f} MB"
elif size_in_bytes < 1024**4:
return f"{size_in_bytes / (1024 ** 3):.2f} GB"
else:
return f"{size_in_bytes / (1024 ** 4):.2f} TB"
def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
if bytes_per_second < 1024:
return f"{bytes_per_second} B/s"
elif bytes_per_second < 1024 ** 2:
return f"{bytes_per_second / 1024:.2f} KB/s"
elif bytes_per_second < 1024 ** 3:
return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
elif bytes_per_second < 1024 ** 4:
return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
else:
return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
if bytes_per_second < 1024:
return f"{bytes_per_second} B/s"
elif bytes_per_second < 1024**2:
return f"{bytes_per_second / 1024:.2f} KB/s"
elif bytes_per_second < 1024**3:
return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
elif bytes_per_second < 1024**4:
return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
else:
return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
def get_all_ip_addresses():
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append(ip)
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return ["localhost"]
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append(ip)
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return ["localhost"]

View File

@@ -52,10 +52,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
assert np.array_equal(next_resp_full, resp4)
asyncio.run(
test_inference_engine(
TinygradDynamicShardInferenceEngine(),
TinygradDynamicShardInferenceEngine(),
"llama3-8b-sfr",
)
)
asyncio.run(test_inference_engine(
TinygradDynamicShardInferenceEngine(),
TinygradDynamicShardInferenceEngine(),
"llama3-8b-sfr",
))

View File

@@ -5,7 +5,9 @@ from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
class InferenceEngine(ABC):
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
pass

View File

@@ -5,5 +5,6 @@ from mlx_lm.models.base import KVCache
class IdentityBlock(nn.Module):
def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
return x

View File

@@ -7,7 +7,7 @@ import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
from .base import IdentityBlock
from ...shard import Shard
from exo.inference.shard import Shard
@dataclass
@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
class DeepseekV2Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
@@ -70,6 +71,7 @@ class DeepseekV2Model(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
@@ -107,10 +109,7 @@ class Model(nn.Module):
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
shard_state_dict[
f"{prefix}.mlp.switch_mlp.{
m}.{k}"
] = mx.stack(to_join)
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return shard_state_dict

View File

@@ -24,7 +24,9 @@ class ModelArgs(ModelArgs):
self.shard = Shard(**self.shard)
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -66,7 +68,9 @@ class LlamaModel(nn.Module):
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -116,9 +120,7 @@ class Model(nn.Module):
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
@property
def n_kv_heads(self):

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,7 @@ from exo.download.shard_download import ShardDownloader
class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader

View File

@@ -10,6 +10,7 @@ from ..shard import Shard
class StatefulShardedModel:
def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
self.shard = shard
self.model = model
@@ -26,6 +27,7 @@ class StatefulShardedModel:
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
def sample(logits: mx.array) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
@@ -74,16 +76,9 @@ class StatefulShardedModel:
return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
def init_cache(self, request_id: str):
kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
if self.max_kv_size is not None:
cache = [
RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4)
for n in kv_heads
]
cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
else:
cache = [KVCache(self.model.head_dim, n) for n in kv_heads]

View File

@@ -25,6 +25,7 @@ from ..shard import Shard
class ModelNotFoundError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
@@ -139,9 +140,10 @@ def load_model_shard(
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
def class_predicate(p, m):
if not hasattr(m, "to_quantized"):
return False
return f"{p}.scales" in weights
if not hasattr(m, "to_quantized"):
return False
return f"{p}.scales" in weights
nn.quantize(
model,
**quantization,
@@ -156,6 +158,7 @@ def load_model_shard(
model.eval()
return model
async def load_shard(
model_path: str,
shard: Shard,
@@ -179,26 +182,27 @@ async def load_shard(
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
async def get_image_from_str(_image_str: str):
image_str = _image_str.strip()
image_str = _image_str.strip()
if image_str.startswith("http"):
async with aiohttp.ClientSession() as session:
async with session.get(image_str, timeout=10) as response:
content = await response.read()
return Image.open(BytesIO(content)).convert("RGB")
elif image_str.startswith("data:image/"):
# Extract the image format and base64 data
format_prefix, base64_data = image_str.split(";base64,")
image_format = format_prefix.split("/")[1].lower()
if DEBUG >= 2: print(f"{image_str=} {image_format=}")
imgdata = base64.b64decode(base64_data)
img = Image.open(BytesIO(imgdata))
if image_str.startswith("http"):
async with aiohttp.ClientSession() as session:
async with session.get(image_str, timeout=10) as response:
content = await response.read()
return Image.open(BytesIO(content)).convert("RGB")
elif image_str.startswith("data:image/"):
# Extract the image format and base64 data
format_prefix, base64_data = image_str.split(";base64,")
image_format = format_prefix.split("/")[1].lower()
if DEBUG >= 2: print(f"{image_str=} {image_format=}")
imgdata = base64.b64decode(base64_data)
img = Image.open(BytesIO(imgdata))
# Convert to RGB if not already
if img.mode != "RGB":
img = img.convert("RGB")
# Convert to RGB if not already
if img.mode != "RGB":
img = img.convert("RGB")
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")

View File

@@ -39,8 +39,8 @@ y = full.step("full", input_ids, pixel_values, temp=0)
full_generated_tokens = [y.item()]
for _ in range(13):
y = full.step("full", y, temp=0)
full_generated_tokens.append(y.item())
y = full.step("full", y, temp=0)
full_generated_tokens.append(y.item())
full_response = full_processor.tokenizer.decode(full_generated_tokens)
print("full response:", full_response)
@@ -54,11 +54,11 @@ y = m2.step("shard", y, temp=0)
full_generated_tokens = [y.item()]
for _ in range(13):
y = m1.step("shard", y, temp=0)
y = m2.step("shard", y, temp=0)
full_generated_tokens.append(y.item())
y = m1.step("shard", y, temp=0)
y = m2.step("shard", y, temp=0)
full_generated_tokens.append(y.item())
sharded_response = processor2.tokenizer.decode(full_generated_tokens)
print("sharded response:", sharded_response)
assert full_response == sharded_response
assert full_response == sharded_response

View File

@@ -6,6 +6,7 @@ import numpy as np
class DummyModel(nn.Module):
def __init__(self, shard: Optional[Shard] = None):
self.shard = shard
self.layers = [
@@ -21,7 +22,7 @@ class DummyModel(nn.Module):
def __call__(self, x, cache=None):
if self.shard:
for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
x = layer(x)
if self.shard.is_last_layer():
x = x.reshape((1, 2, 4))

View File

@@ -34,8 +34,6 @@ class Shard:
def overlaps(self, other: 'Shard') -> bool:
return shards_overlap(self, other)
def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
return (
shard1.model_id == shard2.model_id
and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
)
return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))

View File

@@ -7,6 +7,7 @@ import os
import asyncio
import numpy as np
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
prompt = "In a single word only, what is the last name of the current president of the USA?"
@@ -22,7 +23,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
input_data=resp1,
inference_state=inference_state_1,
)
@@ -34,7 +35,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
)
resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
input_data=resp3,
inference_state=inference_state_3,
)
@@ -42,21 +43,22 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
assert np.array_equal(resp_full, resp2)
assert np.array_equal(next_resp_full, resp4)
asyncio.run(
test_inference_engine(
MLXDynamicShardInferenceEngine(HFShardDownloader()),
MLXDynamicShardInferenceEngine(HFShardDownloader()),
"mlx-community/Meta-Llama-3-8B-Instruct-4bit",
)
)
asyncio.run(test_inference_engine(
MLXDynamicShardInferenceEngine(HFShardDownloader()),
MLXDynamicShardInferenceEngine(HFShardDownloader()),
"mlx-community/Meta-Llama-3-8B-Instruct-4bit",
))
if os.getenv("RUN_TINYGRAD", default="0") == "1":
import tinygrad
import os
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
asyncio.run(test_inference_engine(
asyncio.run(
test_inference_engine(
TinygradDynamicShardInferenceEngine(HFShardDownloader()),
TinygradDynamicShardInferenceEngine(HFShardDownloader()),
"TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
))
)
)

View File

@@ -20,16 +20,11 @@ TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0
MODEL_PARAMS = {
"8B": {
"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
"files": 1
},
"70B": {
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672},
"files": 8
}
"8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
"70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
}
def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
# build model
linear = nn.Linear
@@ -48,10 +43,12 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
with Context(BEAM=0):
# replace weights in model
load_state_dict(model, weights, strict=False, consume=False) # consume=True
load_state_dict(model, weights, strict=False, consume=False) # consume=True
return model
class TinygradDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
@@ -64,7 +61,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
toks = self.tokenizer.encode(prompt)
h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
if h.shape == (1,):
if h.shape == (1, ):
start_pos += len(toks)
start_pos += 1
n_captured_toks = 0
@@ -80,7 +77,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
if h.shape == (1,):
if h.shape == (1, ):
start_pos += n_captured_toks
start_pos += 1
n_captured_toks = 0

View File

@@ -2,21 +2,24 @@ from typing import Tuple, Union, Optional, Dict, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
# TODO: move dtype outside this
return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
a,b = A[..., 0:1], A[..., 1:2]
ro = a*c - b*d
co = a*d + b*c
a, b = A[..., 0:1], A[..., 1:2]
ro = a * c - b * d
co = a * d + b * c
return ro.cat(co, dim=-1)
def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -26,16 +29,19 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
xk_out = complex_mult(xk, c, d)
return xq_out.flatten(3), xk_out.flatten(3)
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
bs, seqlen, n_kv_heads, head_dim = x.shape
if n_rep == 1: return x
# NOTE: this is different from x.repeat((1, 1, n_rep, 1))
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
class Attention:
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.max_context = max_context
@@ -45,7 +51,7 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
if getenv("WQKV"):
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
xqkv = x @ self.wqkv.T
@@ -69,10 +75,10 @@ class Attention:
# update the cache
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -80,26 +86,31 @@ class Attention:
attn = attn.reshape(bsz, seqlen, -1)
return self.wo(attn)
class FeedForward:
def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
self.w1 = linear(dim, hidden_dim, bias=False)
self.w2 = linear(hidden_dim, dim, bias=False)
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
def __call__(self, x: Tensor) -> Tensor:
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
def __call__(self, x:Tensor) -> Tensor:
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
self.feed_forward = feed_forward(dim, hidden_dim, linear)
self.attention_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
assert logits.ndim == 1, "only works on 1d tensors"
@@ -127,8 +138,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
for i in range(k):
t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
t = (counter == t_argmax).where(0, t)
# approximate top p
@@ -149,10 +160,28 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
return output_token
from exo.inference.shard import Shard
class Transformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
def __init__(
self,
dim: int,
hidden_dim: int,
n_heads: int,
n_layers: int,
norm_eps: float,
vocab_size,
shard: Shard = None,
linear=nn.Linear,
n_kv_heads=None,
rope_theta=10000,
max_context=1024,
jit=True,
feed_forward=FeedForward
):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = nn.Embedding(vocab_size, dim)
@@ -162,10 +191,10 @@ class Transformer:
self.forward_jit = TinyJit(self.forward) if jit else None
self.shard = shard
def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
seqlen = x.shape[1]
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
if self.shard.is_first_layer():
h = self.tok_embeddings(x)
@@ -182,24 +211,33 @@ class Transformer:
else:
return h
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
# TODO: better way to handle the first call v.s. the rest?
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
# *** helpers ***
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
keymap = {
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
for x in ["q", "k", "v", "o"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
for l in range(len(model.layers))},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
@@ -215,9 +253,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
sd[keymap[k]] = v
return sd
def fix_bf16(weights:Dict[Any, Tensor]):
def fix_bf16(weights: Dict[Any, Tensor]):
if getenv("SUPPORT_BF16", 1):
# TODO: without casting to float16, 70B llama OOM on tinybox.
return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
# TODO: check if device supports bf16
return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

View File

@@ -8,8 +8,10 @@ from exo.helpers import DEBUG
from exo.download.hf.hf_helpers import get_allow_patterns
from fnmatch import fnmatch
# **** helper functions ****
def concat_weights(models, device=None):
def convert(name) -> Tensor:
disk_tensors: List[Tensor] = [model[name] for model in models]
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
@@ -17,11 +19,14 @@ def concat_weights(models, device=None):
axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
lazy_tensors = [data.to(device=device) for data in disk_tensors]
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
return {name: convert(name) for name in {name: None for model in models for name in model}}
def load(fn:str, shard: Shard):
def load(fn: str, shard: Shard):
if fn.endswith('.index.json'):
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
with open(fn) as fp:
weight_map = json.load(fp)['weight_map']
parts = {}
filtered_weight_map = {}
allow_patterns = get_allow_patterns(weight_map, shard)

View File

@@ -2,6 +2,7 @@ import traceback
from transformers import AutoTokenizer, AutoProcessor
from exo.helpers import DEBUG
async def resolve_tokenizer(model_id: str):
try:
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")

View File

@@ -2,39 +2,32 @@ from exo.inference.shard import Shard
model_base_shards = {
### llama
"llama-3.1-8b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3.1-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
},
"llama-3-8b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-8b":
{
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3.1-70b":
{
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
"llama-3-8b":
{
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3-70b":
{
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
},
### mistral
"mistral-nemo": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
},
"mistral-large": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
},
"mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
"mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
### deepseek v2
"deepseek-coder-v2-lite": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
},
"deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
### llava
"llava-1.5-7b-hf": {
"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
},
"llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
}

View File

@@ -4,6 +4,7 @@ from .peer_handle import PeerHandle
class Discovery(ABC):
@abstractmethod
async def start(self) -> None:
pass

View File

@@ -11,6 +11,7 @@ from exo import DEBUG_DISCOVERY
class ListenProtocol(asyncio.DatagramProtocol):
def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
super().__init__()
self.on_message = on_message
@@ -24,6 +25,7 @@ class ListenProtocol(asyncio.DatagramProtocol):
class GRPCDiscovery(Discovery):
def __init__(
self,
node_id: str,
@@ -97,14 +99,12 @@ class GRPCDiscovery(Discovery):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
message = json.dumps(
{
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
}
).encode("utf-8")
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
}).encode("utf-8")
while True:
try:
@@ -166,14 +166,14 @@ class GRPCDiscovery(Discovery):
try:
current_time = time.time()
peers_to_remove = [
peer_handle.id()
for peer_handle, connected_at, last_seen in self.known_peers.values()
peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
]
if DEBUG_DISCOVERY >= 2:
print(
"Peer statuses:",
{peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
{peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
for peer_handle, connected_at, last_seen in self.known_peers.values()},
)
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
print(f"Cleaning up peers: {peers_to_remove}")

View File

@@ -13,6 +13,7 @@ from exo.topology.device_capabilities import DeviceCapabilities
class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
self._id = _id
self.address = address

View File

@@ -11,6 +11,7 @@ from exo.orchestration import Node
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
def __init__(self, node: Node, host: str, port: int):
self.node = node
self.host = host
@@ -81,9 +82,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
node_service_pb2.InferenceResult(
tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
is_finished=result[1],
)
if result[0] is not None
else node_service_pb2.InferenceResult(is_finished=result[1])
) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
)
async def CollectTopology(self, request, context):
@@ -91,12 +90,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
visited = set(request.visited)
topology = await self.node.collect_topology(visited, max_depth)
nodes = {
node_id: node_service_pb2.DeviceCapabilities(
model=cap.model,
chip=cap.chip,
memory=cap.memory,
flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
)
node_id:
node_service_pb2.DeviceCapabilities(
model=cap.model,
chip=cap.chip,
memory=cap.memory,
flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
)
for node_id, cap in topology.nodes.items()
}
peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}

View File

@@ -11,10 +11,9 @@ from google.protobuf.internal import builder as _builder
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,38 +24,38 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=122
_globals['_PROMPTREQUEST']._serialized_end=317
_globals['_TENSORREQUEST']._serialized_start=320
_globals['_TENSORREQUEST']._serialized_end=499
_globals['_GETINFERENCERESULTREQUEST']._serialized_start=501
_globals['_GETINFERENCERESULTREQUEST']._serialized_end=548
_globals['_INFERENCERESULT']._serialized_start=550
_globals['_INFERENCERESULT']._serialized_end=642
_globals['_TENSOR']._serialized_start=644
_globals['_TENSOR']._serialized_end=703
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=705
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=765
_globals['_TOPOLOGY']._serialized_start=768
_globals['_TOPOLOGY']._serialized_end=1038
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=889
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=967
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=969
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1038
_globals['_PEERS']._serialized_start=1040
_globals['_PEERS']._serialized_end=1065
_globals['_DEVICEFLOPS']._serialized_start=1067
_globals['_DEVICEFLOPS']._serialized_end=1122
_globals['_DEVICECAPABILITIES']._serialized_start=1124
_globals['_DEVICECAPABILITIES']._serialized_end=1231
_globals['_SENDRESULTREQUEST']._serialized_start=1233
_globals['_SENDRESULTREQUEST']._serialized_end=1309
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1311
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1372
_globals['_EMPTY']._serialized_start=1374
_globals['_EMPTY']._serialized_end=1381
_globals['_NODESERVICE']._serialized_start=1384
_globals['_NODESERVICE']._serialized_end=1862
_globals['_SHARD']._serialized_start = 36
_globals['_SHARD']._serialized_end = 119
_globals['_PROMPTREQUEST']._serialized_start = 122
_globals['_PROMPTREQUEST']._serialized_end = 317
_globals['_TENSORREQUEST']._serialized_start = 320
_globals['_TENSORREQUEST']._serialized_end = 499
_globals['_GETINFERENCERESULTREQUEST']._serialized_start = 501
_globals['_GETINFERENCERESULTREQUEST']._serialized_end = 548
_globals['_INFERENCERESULT']._serialized_start = 550
_globals['_INFERENCERESULT']._serialized_end = 642
_globals['_TENSOR']._serialized_start = 644
_globals['_TENSOR']._serialized_end = 703
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start = 705
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end = 765
_globals['_TOPOLOGY']._serialized_start = 768
_globals['_TOPOLOGY']._serialized_end = 1038
_globals['_TOPOLOGY_NODESENTRY']._serialized_start = 889
_globals['_TOPOLOGY_NODESENTRY']._serialized_end = 967
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start = 969
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end = 1038
_globals['_PEERS']._serialized_start = 1040
_globals['_PEERS']._serialized_end = 1065
_globals['_DEVICEFLOPS']._serialized_start = 1067
_globals['_DEVICEFLOPS']._serialized_end = 1122
_globals['_DEVICECAPABILITIES']._serialized_start = 1124
_globals['_DEVICECAPABILITIES']._serialized_end = 1231
_globals['_SENDRESULTREQUEST']._serialized_start = 1233
_globals['_SENDRESULTREQUEST']._serialized_end = 1309
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start = 1311
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end = 1372
_globals['_EMPTY']._serialized_start = 1374
_globals['_EMPTY']._serialized_end = 1381
_globals['_NODESERVICE']._serialized_start = 1384
_globals['_NODESERVICE']._serialized_end = 1862
# @@protoc_insertion_point(module_scope)

View File

@@ -12,306 +12,264 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
_version_not_supported = True
if _version_not_supported:
warnings.warn(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
RuntimeWarning
)
warnings.warn(
f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
)
class NodeServiceStub(object):
"""Missing associated documentation comment in .proto file."""
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True
)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True
)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
_registered_method=True
)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
_registered_method=True
)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True
)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True
)
class NodeServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def SendPrompt(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPrompt(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTensor(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTensor(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetInferenceResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetInferenceResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectTopology(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectTopology(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendOpaqueStatus(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendOpaqueStatus(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
),
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'node_service.NodeService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
rpc_method_handlers = {
'SendPrompt':
grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor':
grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'GetInferenceResult':
grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology':
grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
),
'SendResult':
grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus':
grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler, ))
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class NodeService(object):
"""Missing associated documentation comment in .proto file."""
"""Missing associated documentation comment in .proto file."""
@staticmethod
def SendPrompt(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
@staticmethod
def SendTensor(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
@staticmethod
def GetInferenceResult(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
@staticmethod
def CollectTopology(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/CollectTopology',
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/CollectTopology',
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
@staticmethod
def SendResult(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
@staticmethod
def SendOpaqueStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True
)

View File

@@ -4,6 +4,7 @@ from .grpc_discovery import GRPCDiscovery
class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)

View File

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
class PeerHandle(ABC):
@abstractmethod
def id(self) -> str:
pass

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
class Server(ABC):
@abstractmethod
async def start(self) -> None:
pass

View File

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
class Node(ABC):
@abstractmethod
async def start(self, wait_for_peers: int = 0) -> None:
pass

View File

@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
class StandardNode(Node):
def __init__(
self,
_id: str,
@@ -359,6 +360,7 @@ class StandardNode(Node):
self.on_token.trigger_all(request_id, tokens, is_finished)
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
async def send_result_to_peer(peer):
try:
await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
@@ -372,6 +374,7 @@ class StandardNode(Node):
async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
async def send_status_to_peer(peer):
try:
await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)

View File

@@ -7,6 +7,7 @@ from exo.networking.peer_handle import PeerHandle
class TestNode(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_inference_engine = AsyncMock()
self.mock_server = AsyncMock()

View File

@@ -12,6 +12,7 @@ async def main() -> None:
callback2 = callback_system.register("callback2")
def on_next_callback(name: str) -> Callable[..., None]:
def callback(*args: Any) -> None:
print(f"{name} received values: {args}")

View File

@@ -14,6 +14,7 @@ class Partition:
class PartitioningStrategy(ABC):
@abstractmethod
def partition(self, topology: Topology) -> List[Partition]:
pass

View File

@@ -5,6 +5,7 @@ from .partitioning_strategy import Partition
class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
def partition(self, topology: Topology) -> List[Partition]:
nodes = list(topology.all_nodes())
nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)

View File

@@ -4,6 +4,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
class TestMacDeviceCapabilities(unittest.TestCase):
@patch("subprocess.check_output")
def test_mac_device_capabilities_pro(self, mock_check_output):
# Mock the subprocess output

View File

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
def test_map_partitions_to_shards(self):
partitions = [
Partition("node1", 0.0, 0.42857),

View File

@@ -6,6 +6,7 @@ from exo.topology.partitioning_strategy import Partition
class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
def test_partition(self):
# triangle
# node1 -> node2 -> node3 -> node1

View File

@@ -3,6 +3,7 @@ from typing import Dict, Set, Optional
class Topology:
def __init__(self):
self.nodes: Dict[str, DeviceCapabilities] = {} # Maps node IDs to DeviceCapabilities
self.peer_graph: Dict[str, Set[str]] = {} # Adjacency list representing the graph

View File

@@ -9,58 +9,60 @@ from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
def create_hf_repo_progress_event(
completed_files: int = 5,
total_files: int = 10,
downloaded_bytes: int = 500000000,
downloaded_bytes_this_session: int = 250000000,
total_bytes: int = 1000000000,
overall_speed: int = 5000000,
overall_eta: timedelta = timedelta(seconds=100),
file_progress: dict = None,
status: str = "in_progress"
completed_files: int = 5,
total_files: int = 10,
downloaded_bytes: int = 500000000,
downloaded_bytes_this_session: int = 250000000,
total_bytes: int = 1000000000,
overall_speed: int = 5000000,
overall_eta: timedelta = timedelta(seconds=100),
file_progress: dict = None,
status: str = "in_progress"
) -> RepoProgressEvent:
if file_progress is None:
file_progress = {
"file1.bin": RepoFileProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
file_path="file1.bin",
downloaded=100000000,
downloaded_this_session=50000000,
total=200000000,
speed=1000000,
eta=timedelta(seconds=100),
status="in_progress"
),
"file2.bin": RepoFileProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
file_path="file2.bin",
downloaded=200000000,
downloaded_this_session=100000000,
total=200000000,
speed=2000000,
eta=timedelta(seconds=0),
status="complete"
)
}
if file_progress is None:
file_progress = {
"file1.bin":
RepoFileProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
file_path="file1.bin",
downloaded=100000000,
downloaded_this_session=50000000,
total=200000000,
speed=1000000,
eta=timedelta(seconds=100),
status="in_progress"
), "file2.bin":
RepoFileProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
file_path="file2.bin",
downloaded=200000000,
downloaded_this_session=100000000,
total=200000000,
speed=2000000,
eta=timedelta(seconds=0),
status="complete"
)
}
return RepoProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
completed_files=completed_files,
total_files=total_files,
downloaded_bytes=downloaded_bytes,
downloaded_bytes_this_session=downloaded_bytes_this_session,
total_bytes=total_bytes,
overall_speed=overall_speed,
overall_eta=overall_eta,
file_progress=file_progress,
status=status
)
return RepoProgressEvent(
repo_id="repo_id",
repo_revision="repo_revision",
completed_files=completed_files,
total_files=total_files,
downloaded_bytes=downloaded_bytes,
downloaded_bytes_this_session=downloaded_bytes_this_session,
total_bytes=total_bytes,
overall_speed=overall_speed,
overall_eta=overall_eta,
file_progress=file_progress,
status=status
)
class TestNodeViz(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.topology = Topology()
self.topology.update_node(

View File

@@ -16,7 +16,9 @@ from rich.syntax import Syntax
from rich.panel import Panel
from rich.markdown import Markdown
class TopologyViz:
def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
self.chatgpt_api_endpoints = chatgpt_api_endpoints
self.web_chat_urls = web_chat_urls
@@ -28,11 +30,7 @@ class TopologyViz:
self.console = Console()
self.layout = Layout()
self.layout.split(
Layout(name="main"),
Layout(name="prompt_output", size=15),
Layout(name="download", size=25)
)
self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
self.download_panel = Panel("", title="Download Progress", border_style="cyan")
@@ -75,11 +73,11 @@ class TopologyViz:
# Update and show/hide prompt and output panel
if any(r[0] or r[1] for r in self.requests.values()):
self.prompt_output_panel = self._generate_prompt_output_layout()
self.layout["prompt_output"].update(self.prompt_output_panel)
self.layout["prompt_output"].visible = True
self.prompt_output_panel = self._generate_prompt_output_layout()
self.layout["prompt_output"].update(self.prompt_output_panel)
self.layout["prompt_output"].visible = True
else:
self.layout["prompt_output"].visible = False
self.layout["prompt_output"].visible = False
# Only show download_panel if there are in-progress downloads
if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
@@ -97,33 +95,33 @@ class TopologyViz:
max_lines = 13 # Maximum number of lines for the entire panel content
for (prompt, output) in reversed(requests):
prompt_icon, output_icon = "💬️", "🤖"
prompt_icon, output_icon = "💬️", "🤖"
# Process prompt
prompt_lines = prompt.split('\n')
if len(prompt_lines) > max_lines // 2:
prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
# Process prompt
prompt_lines = prompt.split('\n')
if len(prompt_lines) > max_lines // 2:
prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
# Process output
output_lines = output.split('\n')
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
if len(output_lines) > remaining_lines:
output_lines = output_lines[:remaining_lines - 1] + ['...']
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
# Process output
output_lines = output.split('\n')
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
if len(output_lines) > remaining_lines:
output_lines = output_lines[:remaining_lines - 1] + ['...']
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
content.append(prompt_text)
content.append(output_text)
content.append(Text()) # Empty line between entries
content.append(prompt_text)
content.append(output_text)
content.append(Text()) # Empty line between entries
return Panel(
Group(*content),
title="",
border_style="cyan",
height=15, # Increased height to accommodate multiple lines
expand=True # Allow the panel to expand to full width
Group(*content),
title="",
border_style="cyan",
height=15, # Increased height to accommodate multiple lines
expand=True # Allow the panel to expand to full width
)
def _generate_main_layout(self) -> str:
@@ -185,14 +183,14 @@ class TopologyViz:
visualization[bar_y][bar_start_x + i] = segment
# Add labels
visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
# Add position indicator and FLOPS value
pos_x = bar_start_x + int(bar_pos * bar_width)
flops_str = f"{total_flops:.2f} TFLOPS"
visualization[bar_y - 1][pos_x] = ""
visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
visualization[bar_y + 2][pos_x] = ""
# Add an extra empty line for spacing
@@ -270,41 +268,41 @@ class TopologyViz:
# Current node download progress
if self.node_id in self.node_download_progress:
download_progress = self.node_download_progress[self.node_id]
title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
summary.add_row(Text(title, style="bold"))
progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
summary.add_row(progress_info)
download_progress = self.node_download_progress[self.node_id]
title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
summary.add_row(Text(title, style="bold"))
progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
summary.add_row(progress_info)
eta_info = f"{download_progress.overall_eta}"
summary.add_row(eta_info)
eta_info = f"{download_progress.overall_eta}"
summary.add_row(eta_info)
summary.add_row("") # Empty row for spacing
summary.add_row("") # Empty row for spacing
for file_path, file_progress in download_progress.file_progress.items():
if file_progress.status != "complete":
progress = int(file_progress.downloaded / file_progress.total * 30)
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
for file_path, file_progress in download_progress.file_progress.items():
if file_progress.status != "complete":
progress = int(file_progress.downloaded / file_progress.total * 30)
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
summary.add_row("") # Empty row for spacing
# Other nodes download progress summary
summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
for node_id, progress in self.node_download_progress.items():
if node_id != self.node_id:
device = self.topology.nodes.get(node_id)
partition = next((p for p in self.partitions if p.node_id == node_id), None)
partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
speed = pretty_print_bytes_per_second(progress.overall_speed)
device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
percentage_str = f"{percentage:.1f}%"
eta_str = f"{progress.overall_eta}"
summary.add_row(device_info, progress_info, percentage_str)
summary.add_row("", progress_bar, eta_str)
if node_id != self.node_id:
device = self.topology.nodes.get(node_id)
partition = next((p for p in self.partitions if p.node_id == node_id), None)
partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
speed = pretty_print_bytes_per_second(progress.overall_speed)
device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
percentage_str = f"{percentage:.1f}%"
eta_str = f"{progress.overall_eta}"
summary.add_row(device_info, progress_info, percentage_str)
summary.add_row("", progress_bar, eta_str)
return summary
return summary

View File

@@ -3,51 +3,49 @@ import asyncio
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
DEFAULT_ALLOW_PATTERNS = [
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
"*.safetensors",
"*.json",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
"*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
".git",
".git/*",
"*/.git",
"**/.git/**",
".cache/huggingface",
".cache/huggingface/*",
"*/.cache/huggingface",
"**/.cache/huggingface/**",
]
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
async def progress_callback(event: RepoProgressEvent):
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
print(f"Estimated time remaining: {event.overall_eta}")
print("File Progress:")
for file_path, progress in event.file_progress.items():
status_icon = {
'not_started': '',
'in_progress': '🔵',
'complete': ''
}[progress.status]
eta_str = str(progress.eta)
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
print("\n")
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
async def progress_callback(event: RepoProgressEvent):
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
print(f"Estimated time remaining: {event.overall_eta}")
print("File Progress:")
for file_path, progress in event.file_progress.items():
status_icon = {'not_started': '', 'in_progress': '🔵', 'complete': ''}[progress.status]
eta_str = str(progress.eta)
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
print("\n")
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
args = parser.parse_args()
args = parser.parse_args()
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

207
main.py
View File

@@ -53,131 +53,146 @@ inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
args.node_id = args.node_id or get_or_create_node_id()
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
if DEBUG >= 0:
print("Chat interface started:")
for web_chat_url in web_chat_urls:
print(f" - {terminal_link(web_chat_url)}")
print("ChatGPT API endpoint served at:")
for chatgpt_api_endpoint in chatgpt_api_endpoints:
print(f" - {terminal_link(chatgpt_api_endpoint)}")
print("Chat interface started:")
for web_chat_url in web_chat_urls:
print(f" - {terminal_link(web_chat_url)}")
print("ChatGPT API endpoint served at:")
for chatgpt_api_endpoint in chatgpt_api_endpoints:
print(f" - {terminal_link(chatgpt_api_endpoint)}")
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = StandardNode(
args.node_id,
None,
inference_engine,
discovery,
chatgpt_api_endpoints=chatgpt_api_endpoints,
web_chat_urls=web_chat_urls,
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
disable_tui=args.disable_tui,
max_generate_tokens=args.max_generate_tokens,
topology_viz=topology_viz
args.node_id,
None,
inference_engine,
discovery,
chatgpt_api_endpoints=chatgpt_api_endpoints,
web_chat_urls=web_chat_urls,
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
disable_tui=args.disable_tui,
max_generate_tokens=args.max_generate_tokens,
topology_viz=topology_viz
)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
api = ChatGPTAPI(
node,
inference_engine.__class__.__name__,
response_timeout_secs=args.chatgpt_api_response_timeout_secs,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id,
inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
)
def preemptively_start_download(request_id: str, opaque_status: str):
try:
status = json.loads(opaque_status)
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
asyncio.create_task(shard_downloader.ensure_shard(current_shard))
except Exception as e:
if DEBUG >= 2:
print(f"Failed to preemptively start download: {e}")
traceback.print_exc()
try:
status = json.loads(opaque_status)
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
asyncio.create_task(shard_downloader.ensure_shard(current_shard))
except Exception as e:
if DEBUG >= 2:
print(f"Failed to preemptively start download: {e}")
traceback.print_exc()
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
if args.prometheus_client_port:
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)
last_broadcast_time = 0
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
global last_broadcast_time
current_time = time.time()
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
global last_broadcast_time
current_time = time.time()
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
print("Thank you for using exo.")
print_yellow_exo()
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
[task.cancel() for task in server_tasks]
print(f"Cancelling {len(server_tasks)} outstanding tasks")
await asyncio.gather(*server_tasks, return_exceptions=True)
await server.stop()
loop.stop()
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
print("Thank you for using exo.")
print_yellow_exo()
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
[task.cancel() for task in server_tasks]
print(f"Cancelling {len(server_tasks)} outstanding tasks")
await asyncio.gather(*server_tasks, return_exceptions=True)
await server.stop()
loop.stop()
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
return
tokenizer = await resolve_tokenizer(shard.model_id)
request_id = str(uuid.uuid4())
callback_id = f"cli-wait-response-{request_id}"
callback = node.on_token.register(callback_id)
if topology_viz:
topology_viz.update_prompt(request_id, prompt)
prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
return
tokenizer = await resolve_tokenizer(shard.model_id)
request_id = str(uuid.uuid4())
callback_id = f"cli-wait-response-{request_id}"
callback = node.on_token.register(callback_id)
if topology_viz:
topology_viz.update_prompt(request_id, prompt)
prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
try:
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, None, request_id=request_id)
try:
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, None, request_id=request_id)
_, tokens, _ = await callback.wait(
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
timeout=300
)
_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
print("\nGenerated response:")
print(tokenizer.decode(tokens))
except Exception as e:
print(f"Error processing prompt: {str(e)}")
traceback.print_exc()
finally:
node.on_token.deregister(callback_id)
print("\nGenerated response:")
print(tokenizer.decode(tokens))
except Exception as e:
print(f"Error processing prompt: {str(e)}")
traceback.print_exc()
finally:
node.on_token.deregister(callback_id)
async def main():
loop = asyncio.get_running_loop()
loop = asyncio.get_running_loop()
# Use a more direct approach to handle signals
def handle_exit():
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
# Use a more direct approach to handle signals
def handle_exit():
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, handle_exit)
await node.start(wait_for_peers=args.wait_for_peers)
await node.start(wait_for_peers=args.wait_for_peers)
if args.run_model:
await run_model_cli(node, inference_engine, args.run_model, args.prompt)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()
if args.run_model:
await run_model_cli(node, inference_engine, args.run_model, args.prompt)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
loop.close()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
loop.close()