Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)

Commit: reformat with yapf format.py
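The commit message refers to running yapf through a format.py helper. As a rough, hypothetical sketch only (the real format.py, its style settings, and the directories it targets are not shown on this page), such a wrapper could simply shell out to the yapf CLI:

import subprocess
import sys


def main() -> int:
  # Reformat the package in place with yapf. The target directory "exo" and the
  # flag choices here are illustrative assumptions, not taken from the repo.
  return subprocess.run(["yapf", "--in-place", "--recursive", "exo"]).returncode


if __name__ == "__main__":
  sys.exit(main())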
@@ -13,8 +13,8 @@ import argparse
import uuid

models = {
  "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
  "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
}

path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
@@ -29,60 +29,53 @@ tokenizer = load_tokenizer(model_path, tokenizer_config)
# "localhost:8080",
# DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
# )
-peer2 = GRPCPeerHandle(
-  "node2",
-  "localhost:8081",
-  DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-)
+peer2 = GRPCPeerHandle("node2", "localhost:8081", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
shard = models[path_or_hf_repo]
request_id = str(uuid.uuid4())


async def run_prompt(prompt: str):
  if tokenizer.chat_template is None:
    tokenizer.chat_template = tokenizer.default_chat_template
-  if (
-    hasattr(tokenizer, "apply_chat_template")
-    and tokenizer.chat_template is not None
-  ):
+  if (hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None):
    messages = [{"role": "user", "content": prompt}]
-    prompt = tokenizer.apply_chat_template(
-      messages, tokenize=False, add_generation_prompt=True
-    )
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

  await peer2.connect()

  try:
    await peer2.send_prompt(shard, prompt, request_id)
  except Exception as e:
    print(e)

  import time
  # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
  previous_length = 0
  n_tokens = 0
  start_time = time.perf_counter()
  while True:
    try:
      result, is_finished = await peer2.get_inference_result(request_id)
    except Exception as e:
      continue
    await asyncio.sleep(0.1)
    # Print the updated string in place
    updated_string = tokenizer.decode(result)
    n_tokens = len(result)
    print(updated_string[previous_length:], end='', flush=True)
    previous_length = len(updated_string)
    if is_finished:
      print("\nDone")
      break
  end_time = time.perf_counter()
  print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Run prompt")
  parser.add_argument("--prompt", type=str, help="The prompt to run")
  args = parser.parse_args()

  asyncio.run(run_prompt(args.prompt))
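Assuming the example script above is saved as main.py (the filename is not shown in this diff), it would be invoked along the lines of: python main.py --prompt "Who are you?"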
@@ -16,29 +16,27 @@ from exo.orchestration import Node
from exo.models import model_base_shards
from typing import Callable


class Message:

  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
    self.role = role
    self.content = content

  def to_dict(self):
-    return {
-      "role": self.role,
-      "content": self.content
-    }
+    return {"role": self.role, "content": self.content}


class ChatCompletionRequest:

  def __init__(self, model: str, messages: List[Message], temperature: float):
    self.model = model
    self.messages = messages
    self.temperature = temperature

  def to_dict(self):
-    return {
-      "model": self.model,
-      "messages": [message.to_dict() for message in self.messages],
-      "temperature": self.temperature
-    }
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}


def generate_completion(
  chat_request: ChatCompletionRequest,
@@ -56,14 +54,12 @@ def generate_completion(
    "created": int(time.time()),
    "model": chat_request.model,
    "system_fingerprint": f"exo_{VERSION}",
-    "choices": [
-      {
-        "index": 0,
-        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
-        "logprobs": None,
-        "finish_reason": finish_reason,
-      }
-    ],
+    "choices": [{
+      "index": 0,
+      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+      "logprobs": None,
+      "finish_reason": finish_reason,
+    }],
  }

  if not stream:
@@ -86,37 +82,38 @@ def generate_completion(


def remap_messages(messages: List[Message]) -> List[Message]:
  remapped_messages = []
  last_image = None
  for message in messages:
    if not isinstance(message.content, list):
      remapped_messages.append(message)
      continue

    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append(content)
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if isinstance(content, dict):
          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
            message.content[i] = last_image
            return remapped_messages

  return remapped_messages


def build_prompt(tokenizer, _messages: List[Message]):
  messages = remap_messages(_messages)
@@ -149,13 +146,17 @@ def parse_chat_request(data: dict):
    data.get("temperature", 0.0),
  )


class PromptSession:

  def __init__(self, request_id: str, timestamp: int, prompt: str):
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt


class ChatGPTAPI:

  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
@@ -182,6 +183,7 @@ class ChatGPTAPI:
    self.app.middlewares.append(self.log_request)

  async def log_request(self, app, handler):

    async def middleware(request):
      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
      return await handler(request)
@@ -268,7 +270,8 @@ class ChatGPTAPI:
    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
    new_tokens = tokens[prev_last_tokens_len:]
    finish_reason = None
-    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
+                                                                                                                       AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
      new_tokens = new_tokens[:-1]
      if is_finished:
@@ -2,81 +2,67 @@ from typing import Dict, Callable, Coroutine, Any, Literal
from dataclasses import dataclass
from datetime import timedelta


@dataclass
class RepoFileProgressEvent:
  repo_id: str
  repo_revision: str
  file_path: str
  downloaded: int
  downloaded_this_session: int
  total: int
  speed: int
  eta: timedelta
  status: Literal["not_started", "in_progress", "complete"]

  def to_dict(self):
-    return {
-      "repo_id": self.repo_id,
-      "repo_revision": self.repo_revision,
-      "file_path": self.file_path,
-      "downloaded": self.downloaded,
-      "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total,
-      "speed": self.speed,
-      "eta": self.eta.total_seconds(),
-      "status": self.status
-    }
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+    }

  @classmethod
  def from_dict(cls, data):
    # Convert eta from seconds back to timedelta
    if 'eta' in data:
      data['eta'] = timedelta(seconds=data['eta'])
    return cls(**data)


@dataclass
class RepoProgressEvent:
  repo_id: str
  repo_revision: str
  completed_files: int
  total_files: int
  downloaded_bytes: int
  downloaded_bytes_this_session: int
  total_bytes: int
  overall_speed: int
  overall_eta: timedelta
  file_progress: Dict[str, RepoFileProgressEvent]
  status: Literal["not_started", "in_progress", "complete"]

  def to_dict(self):
-    return {
-      "repo_id": self.repo_id,
-      "repo_revision": self.repo_revision,
-      "completed_files": self.completed_files,
-      "total_files": self.total_files,
-      "downloaded_bytes": self.downloaded_bytes,
-      "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-      "total_bytes": self.total_bytes,
-      "overall_speed": self.overall_speed,
-      "overall_eta": self.overall_eta.total_seconds(),
-      "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-      "status": self.status
-    }
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
+      "file_progress": {k: v.to_dict()
+                        for k, v in self.file_progress.items()}, "status": self.status
+    }

  @classmethod
  def from_dict(cls, data):
    # Convert overall_eta from seconds back to timedelta
    if 'overall_eta' in data:
      data['overall_eta'] = timedelta(seconds=data['overall_eta'])

    # Parse file_progress
    if 'file_progress' in data:
-      data['file_progress'] = {
-        k: RepoFileProgressEvent.from_dict(v)
-        for k, v in data['file_progress'].items()
-      }
+      data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}

    return cls(**data)


RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
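A quick sketch of the serialization round-trip these dataclasses implement (the field values below are made up for illustration): to_dict flattens eta to seconds, and from_dict rebuilds it as a timedelta, so an event survives the round-trip unchanged.

from datetime import timedelta

# Assumes RepoFileProgressEvent as defined above.
event = RepoFileProgressEvent(repo_id="org/repo", repo_revision="main", file_path="model.safetensors", downloaded=1024,
                              downloaded_this_session=512, total=4096, speed=256, eta=timedelta(seconds=12), status="in_progress")
assert RepoFileProgressEvent.from_dict(event.to_dict()) == event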
@@ -16,282 +16,322 @@ import aiofiles
|
||||
from aiofiles import os as aios
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
key: Optional[Callable[[T], str]] = None,
|
||||
items: Iterable[T],
|
||||
*,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
key: Optional[Callable[[T], str]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
if isinstance(allow_patterns, str):
|
||||
allow_patterns = [allow_patterns]
|
||||
if isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
if allow_patterns is not None:
|
||||
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
|
||||
if ignore_patterns is not None:
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
if isinstance(allow_patterns, str):
|
||||
allow_patterns = [allow_patterns]
|
||||
if isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
if allow_patterns is not None:
|
||||
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
|
||||
if ignore_patterns is not None:
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
|
||||
if key is None:
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
key = _identity
|
||||
if key is None:
|
||||
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
|
||||
key = _identity
|
||||
|
||||
for item in items:
|
||||
path = key(item)
|
||||
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
|
||||
continue
|
||||
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
|
||||
continue
|
||||
yield item
|
||||
|
||||
for item in items:
|
||||
path = key(item)
|
||||
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
|
||||
continue
|
||||
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
|
||||
continue
|
||||
yield item
|
||||
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
|
||||
|
||||
def get_hf_home() -> Path:
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
|
||||
|
||||
|
||||
async def get_hf_token():
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home() / "token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, 'r') as f:
|
||||
return (await f.read()).strip()
|
||||
return None
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home() / "token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, 'r') as f:
|
||||
return (await f.read()).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def get_auth_headers():
|
||||
"""Get authentication headers if a token is available."""
|
||||
token = await get_hf_token()
|
||||
if token:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
return {}
|
||||
"""Get authentication headers if a token is available."""
|
||||
token = await get_hf_token()
|
||||
if token:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
return {}
|
||||
|
||||
|
||||
def get_repo_root(repo_id: str) -> Path:
|
||||
"""Get the root directory for a given repo ID in the Hugging Face cache."""
|
||||
sanitized_repo_id = repo_id.replace("/", "--")
|
||||
return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
|
||||
"""Get the root directory for a given repo ID in the Hugging Face cache."""
|
||||
sanitized_repo_id = repo_id.replace("/", "--")
|
||||
return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
|
||||
|
||||
|
||||
async def fetch_file_list(session, repo_id, revision, path=""):
|
||||
api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
|
||||
reraise=True
|
||||
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
|
||||
)
|
||||
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
|
||||
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
local_path = os.path.join(save_directory, file_path)
|
||||
async def download_file(
|
||||
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
|
||||
):
|
||||
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
local_path = os.path.join(save_directory, file_path)
|
||||
|
||||
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
|
||||
# Check if file already exists and get its size
|
||||
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
|
||||
# Check if file already exists and get its size
|
||||
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
|
||||
|
||||
headers = await get_auth_headers()
|
||||
if use_range_request:
|
||||
headers["Range"] = f"bytes={local_file_size}-"
|
||||
headers = await get_auth_headers()
|
||||
if use_range_request:
|
||||
headers["Range"] = f"bytes={local_file_size}-"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded_size = local_file_size
|
||||
downloaded_this_session = 0
|
||||
mode = 'ab' if use_range_request else 'wb'
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded_size = local_file_size
|
||||
downloaded_this_session = 0
|
||||
mode = 'ab' if use_range_request else 'wb'
|
||||
if downloaded_size == total_size:
|
||||
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
if response.status == 200:
|
||||
# File doesn't support range requests or we're not using them, start from beginning
|
||||
mode = 'wb'
|
||||
downloaded_size = 0
|
||||
elif response.status == 206:
|
||||
# Partial content, resume download
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable, get the actual file size
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
if downloaded_size == total_size:
|
||||
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
if response.status == 200:
|
||||
# File doesn't support range requests or we're not using them, start from beginning
|
||||
mode = 'wb'
|
||||
downloaded_size = 0
|
||||
elif response.status == 206:
|
||||
# Partial content, resume download
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable, get the actual file size
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
if downloaded_size == total_size:
|
||||
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
else:
|
||||
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
|
||||
|
||||
if downloaded_size == total_size:
|
||||
print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE = 32768
|
||||
start_time = datetime.now()
|
||||
async with aiofiles.open(local_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||
await f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
downloaded_this_session += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_size = total_size - downloaded_size
|
||||
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
|
||||
status = "in_progress" if downloaded_size < total_size else "complete"
|
||||
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
||||
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
||||
|
||||
async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
refs_dir = repo_root / "refs"
|
||||
snapshots_dir = repo_root / "snapshots"
|
||||
cachedreqs_dir = repo_root / "cachedreqs"
|
||||
|
||||
# Ensure directories exist
|
||||
await aios.makedirs(refs_dir, exist_ok=True)
|
||||
await aios.makedirs(snapshots_dir, exist_ok=True)
|
||||
await aios.makedirs(cachedreqs_dir, exist_ok=True)
|
||||
|
||||
# Check if we have a cached commit hash
|
||||
refs_file = refs_dir / revision
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
|
||||
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
else:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Fetch the commit hash for the given revision
|
||||
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(api_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
|
||||
revision_info = await response.json()
|
||||
commit_hash = revision_info['sha']
|
||||
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
|
||||
|
||||
# Cache the commit hash
|
||||
async with aiofiles.open(refs_file, 'w') as f:
|
||||
await f.write(commit_hash)
|
||||
if downloaded_size == total_size:
|
||||
print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
# Set up the snapshot directory
|
||||
snapshot_dir = snapshots_dir / commit_hash
|
||||
await aios.makedirs(snapshot_dir, exist_ok=True)
|
||||
DOWNLOAD_CHUNK_SIZE = 32768
|
||||
start_time = datetime.now()
|
||||
async with aiofiles.open(local_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||
await f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
downloaded_this_session += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_size = total_size - downloaded_size
|
||||
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
|
||||
status = "in_progress" if downloaded_size < total_size else "complete"
|
||||
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
||||
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
||||
|
||||
# Set up the cached file list directory
|
||||
cached_file_list_dir = cachedreqs_dir / commit_hash
|
||||
await aios.makedirs(cached_file_list_dir, exist_ok=True)
|
||||
cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
|
||||
|
||||
async def download_repo_files(
|
||||
repo_id: str,
|
||||
revision: str = "main",
|
||||
progress_callback: Optional[RepoProgressCallback] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_parallel_downloads: int = 4
|
||||
) -> Path:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
refs_dir = repo_root / "refs"
|
||||
snapshots_dir = repo_root / "snapshots"
|
||||
cachedreqs_dir = repo_root / "cachedreqs"
|
||||
|
||||
# Ensure directories exist
|
||||
await aios.makedirs(refs_dir, exist_ok=True)
|
||||
await aios.makedirs(snapshots_dir, exist_ok=True)
|
||||
await aios.makedirs(cachedreqs_dir, exist_ok=True)
|
||||
|
||||
# Check if we have a cached commit hash
|
||||
refs_file = refs_dir / revision
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
|
||||
else:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Check if we have a cached file list
|
||||
if await aios.path.exists(cached_file_list_path):
|
||||
async with aiofiles.open(cached_file_list_path, 'r') as f:
|
||||
file_list = json.loads(await f.read())
|
||||
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
|
||||
else:
|
||||
file_list = await fetch_file_list(session, repo_id, revision)
|
||||
# Cache the file list
|
||||
async with aiofiles.open(cached_file_list_path, 'w') as f:
|
||||
await f.write(json.dumps(file_list))
|
||||
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
||||
# Fetch the commit hash for the given revision
|
||||
api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(api_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
|
||||
revision_info = await response.json()
|
||||
commit_hash = revision_info['sha']
|
||||
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
||||
total_files = len(filtered_file_list)
|
||||
total_bytes = sum(file["size"] for file in filtered_file_list)
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
|
||||
start_time = datetime.now()
|
||||
# Cache the commit hash
|
||||
async with aiofiles.open(refs_file, 'w') as f:
|
||||
await f.write(commit_hash)
|
||||
|
||||
async def download_with_progress(file_info, progress_state):
|
||||
local_path = snapshot_dir / file_info["path"]
|
||||
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
|
||||
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
|
||||
progress_state['completed_files'] += 1
|
||||
progress_state['downloaded_bytes'] += file_info["size"]
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
|
||||
return
|
||||
# Set up the snapshot directory
|
||||
snapshot_dir = snapshots_dir / commit_hash
|
||||
await aios.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
async def file_progress_callback(event: RepoFileProgressEvent):
|
||||
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
|
||||
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
|
||||
file_progress[event.file_path] = event
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
|
||||
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
|
||||
# Set up the cached file list directory
|
||||
cached_file_list_dir = cachedreqs_dir / commit_hash
|
||||
await aios.makedirs(cached_file_list_dir, exist_ok=True)
|
||||
cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
|
||||
|
||||
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
||||
progress_state['completed_files'] += 1
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Check if we have a cached file list
|
||||
if await aios.path.exists(cached_file_list_path):
|
||||
async with aiofiles.open(cached_file_list_path, 'r') as f:
|
||||
file_list = json.loads(await f.read())
|
||||
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
|
||||
else:
|
||||
file_list = await fetch_file_list(session, repo_id, revision)
|
||||
# Cache the file list
|
||||
async with aiofiles.open(cached_file_list_path, 'w') as f:
|
||||
await f.write(json.dumps(file_list))
|
||||
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
||||
|
||||
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
||||
total_files = len(filtered_file_list)
|
||||
total_bytes = sum(file["size"] for file in filtered_file_list)
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {
|
||||
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
|
||||
for file in filtered_file_list
|
||||
}
|
||||
start_time = datetime.now()
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
async def download_with_semaphore(file_info):
|
||||
async with semaphore:
|
||||
await download_with_progress(file_info, progress_state)
|
||||
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
|
||||
await asyncio.gather(*tasks)
|
||||
async def download_with_progress(file_info, progress_state):
|
||||
local_path = snapshot_dir / file_info["path"]
|
||||
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
|
||||
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
|
||||
progress_state['completed_files'] += 1
|
||||
progress_state['downloaded_bytes'] += file_info["size"]
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
async def file_progress_callback(event: RepoFileProgressEvent):
|
||||
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
|
||||
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
|
||||
file_progress[event.file_path] = event
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
||||
progress_state['completed_files'] += 1
|
||||
file_progress[
|
||||
file_info["path"]
|
||||
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
|
||||
async def download_with_semaphore(file_info):
|
||||
async with semaphore:
|
||||
await download_with_progress(file_info, progress_state)
|
||||
|
||||
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return snapshot_dir
|
||||
|
||||
return snapshot_dir
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
"""
|
||||
Retrieve the weight map from the model.safetensors.index.json file.
|
||||
|
||||
Args:
|
||||
@@ -302,55 +342,52 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
|
||||
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
|
||||
"""
|
||||
|
||||
# Download the index file
|
||||
await download_repo_files(
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
allow_patterns="model.safetensors.index.json"
|
||||
)
|
||||
# Download the index file
|
||||
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
|
||||
|
||||
# Check if the file exists
|
||||
repo_root = get_repo_root(repo_id)
|
||||
snapshot_dir = repo_root / "snapshots"
|
||||
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
|
||||
# Check if the file exists
|
||||
repo_root = get_repo_root(repo_id)
|
||||
snapshot_dir = repo_root / "snapshots"
|
||||
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
|
||||
|
||||
if index_file:
|
||||
index_file_path = snapshot_dir / index_file
|
||||
if await aios.path.exists(index_file_path):
|
||||
async with aiofiles.open(index_file_path, 'r') as f:
|
||||
index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
if index_file:
|
||||
index_file_path = snapshot_dir / index_file
|
||||
if await aios.path.exists(index_file_path):
|
||||
async with aiofiles.open(index_file_path, 'r') as f:
|
||||
index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
# This is a simple example and might need to be adjusted based on the actual naming convention
|
||||
parts = tensor_name.split('.')
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
return int(part)
|
||||
return None
|
||||
# This is a simple example and might need to be adjusted based on the actual naming convention
|
||||
parts = tensor_name.split('.')
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
return int(part)
|
||||
return None
|
||||
|
||||
|
||||
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
default_patterns = [
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
]
|
||||
shard_specific_patterns = []
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
|
||||
shard_specific_patterns.append(filename)
|
||||
sorted_file_names = sorted(weight_map.values())
|
||||
if shard.is_first_layer():
|
||||
shard_specific_patterns.append(sorted_file_names[0])
|
||||
elif shard.is_last_layer():
|
||||
shard_specific_patterns.append(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = ["*.safetensors"]
|
||||
return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates
|
||||
default_patterns = [
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
]
|
||||
shard_specific_patterns = []
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
|
||||
shard_specific_patterns.append(filename)
|
||||
sorted_file_names = sorted(weight_map.values())
|
||||
if shard.is_first_layer():
|
||||
shard_specific_patterns.append(sorted_file_names[0])
|
||||
elif shard.is_last_layer():
|
||||
shard_specific_patterns.append(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = ["*.safetensors"]
|
||||
return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates
|
||||
|
||||
@@ -8,72 +8,70 @@ from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
|
||||
from exo.helpers import AsyncCallbackSystem, DEBUG
|
||||
|
||||
|
||||
class HFShardDownloader(ShardDownloader):
|
||||
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
|
||||
self.quick_check = quick_check
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
self.completed_downloads: Dict[Shard, Path] = {}
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
|
||||
async def ensure_shard(self, shard: Shard) -> Path:
|
||||
if shard in self.completed_downloads:
|
||||
return self.completed_downloads[shard]
|
||||
if self.quick_check:
|
||||
repo_root = get_repo_root(shard.model_id)
|
||||
snapshots_dir = repo_root / "snapshots"
|
||||
if snapshots_dir.exists():
|
||||
most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
|
||||
return most_recent_dir
|
||||
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
|
||||
self.quick_check = quick_check
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
self.completed_downloads: Dict[Shard, Path] = {}
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
|
||||
# If a download on this shard is already in progress, keep that one
|
||||
for active_shard in self.active_downloads:
|
||||
if active_shard == shard:
|
||||
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
|
||||
return await self.active_downloads[shard]
|
||||
async def ensure_shard(self, shard: Shard) -> Path:
  if shard in self.completed_downloads:
    return self.completed_downloads[shard]
  if self.quick_check:
    repo_root = get_repo_root(shard.model_id)
    snapshots_dir = repo_root / "snapshots"
    if snapshots_dir.exists():
      most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
      return most_recent_dir

  # Cancel any downloads for this model_id on a different shard
  existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
  for active_shard in existing_active_shards:
    if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
    task = self.active_downloads[active_shard]
    task.cancel()
    try:
      await task
    except asyncio.CancelledError:
      pass  # This is expected when cancelling a task
    except Exception as e:
      if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
      traceback.print_exc()
  self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}

  # If a download on this shard is already in progress, keep that one
  for active_shard in self.active_downloads:
    if active_shard == shard:
      if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
      return await self.active_downloads[shard]

  # Start new download
  download_task = asyncio.create_task(self._download_shard(shard))
  self.active_downloads[shard] = download_task
  try:
    path = await download_task
    self.completed_downloads[shard] = path
    return path
  finally:
    # Ensure the task is removed even if an exception occurs
    print(f"Removing download task for {shard}: {shard in self.active_downloads}")
    if shard in self.active_downloads:
      self.active_downloads.pop(shard)

async def _download_shard(self, shard: Shard) -> Path:
  async def wrapped_progress_callback(event: RepoProgressEvent):
    self._on_progress.trigger_all(shard, event)

  weight_map = await get_weight_map(shard.model_id)
  allow_patterns = get_allow_patterns(weight_map, shard)

  return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)

@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
  return self._on_progress
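To make the deduplication behaviour above concrete, here is a minimal usage sketch (not part of the diff). It assumes the class holding ensure_shard is HFShardDownloader, that it can be imported from exo.download.hf.hf_shard_download, and that its constructor accepts quick_check; those names are assumptions for illustration only.

import asyncio

from exo.inference.shard import Shard
from exo.download.hf.hf_shard_download import HFShardDownloader  # import path assumed


async def main():
  downloader = HFShardDownloader(quick_check=False)  # constructor arguments assumed
  shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)

  # Two concurrent requests for the same shard share one underlying download task.
  path_a, path_b = await asyncio.gather(downloader.ensure_shard(shard), downloader.ensure_shard(shard))
  assert path_a == path_b

  # A request for a different shard of the same model_id cancels the other in-flight download first.
  other = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=16, end_layer=31, n_layers=32)
  print(await downloader.ensure_shard(other))


if __name__ == "__main__":
  asyncio.run(main())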
@@ -5,10 +5,12 @@ from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem


class ShardDownloader(ABC):

  @abstractmethod
  async def ensure_shard(self, shard: Shard) -> Path:
    """
    Ensures that the shard is downloaded.
    Does not allow multiple overlapping downloads at once.
    If you try to download a Shard which overlaps a Shard that is already being downloaded,
@@ -17,9 +19,9 @@ class ShardDownloader(ABC):
    Args:
      shard (Shard): The shard to download.
    """
    pass

  @property
  @abstractmethod
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    pass
exo/helpers.py
@@ -20,6 +20,7 @@ exo_text = r"""
 \___/_/\_\___/
"""


def get_system_info():
  if psutil.MACOS:
    if platform.machine() == "arm64":
@@ -87,7 +88,10 @@ def terminal_link(uri, label=None):

T = TypeVar("T")
K = TypeVar("K")


class AsyncCallback(Generic[T]):

  def __init__(self) -> None:
    self.condition: asyncio.Condition = asyncio.Condition()
    self.result: Optional[Tuple[T, ...]] = None
@@ -95,9 +99,7 @@ class AsyncCallback(Generic[T]):

  async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
    async with self.condition:
      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
      assert self.result is not None  # for type checking
      return self.result

@@ -116,6 +118,7 @@ class AsyncCallback(Generic[T]):


class AsyncCallbackSystem(Generic[K, T]):

  def __init__(self) -> None:
    self.callbacks: Dict[K, AsyncCallback[T]] = {}

@@ -139,89 +142,97 @@ class AsyncCallbackSystem(Generic[K, T]):

K = TypeVar('K', bound=str)
V = TypeVar('V')


class PrefixDict(Generic[K, V]):

  def __init__(self):
    self.items: Dict[K, V] = {}

  def add(self, key: K, value: V) -> None:
    self.items[key] = value

  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]

  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
    matches = self.find_prefix(argument)
    if len(matches) == 0:
      return None

    return max(matches, key=lambda x: len(x[0]))


def is_valid_uuid(val):
  try:
    uuid.UUID(str(val))
    return True
  except ValueError:
    return False


def get_or_create_node_id():
  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
  try:
    if NODE_ID_FILE.is_file():
      with open(NODE_ID_FILE, "r") as f:
        stored_id = f.read().strip()
      if is_valid_uuid(stored_id):
        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
        return stored_id
      else:
        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")

    new_id = str(uuid.uuid4())
    with open(NODE_ID_FILE, "w") as f:
      f.write(new_id)

    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
    return new_id
  except IOError as e:
    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
    return str(uuid.uuid4())
  except Exception as e:
    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
    return str(uuid.uuid4())


def pretty_print_bytes(size_in_bytes: int) -> str:
  if size_in_bytes < 1024:
    return f"{size_in_bytes} B"
  elif size_in_bytes < 1024**2:
    return f"{size_in_bytes / 1024:.2f} KB"
  elif size_in_bytes < 1024**3:
    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
  elif size_in_bytes < 1024**4:
    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
  else:
    return f"{size_in_bytes / (1024 ** 4):.2f} TB"


def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
  if bytes_per_second < 1024:
    return f"{bytes_per_second} B/s"
  elif bytes_per_second < 1024**2:
    return f"{bytes_per_second / 1024:.2f} KB/s"
  elif bytes_per_second < 1024**3:
    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
  elif bytes_per_second < 1024**4:
    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
  else:
    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"


def get_all_ip_addresses():
  try:
    ip_addresses = []
    for interface in netifaces.interfaces():
      ifaddresses = netifaces.ifaddresses(interface)
      if netifaces.AF_INET in ifaddresses:
        for link in ifaddresses[netifaces.AF_INET]:
          ip = link['addr']
          ip_addresses.append(ip)
    return list(set(ip_addresses))
  except:
    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
    return ["localhost"]
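A quick illustrative check of the helpers above (not part of the diff); the values follow directly from the code and are easy to verify in a REPL.

# PrefixDict: the longest registered key that prefixes the argument wins.
d = PrefixDict()
d.add("/v1/chat", "chat-handler")
d.add("/v1", "v1-handler")
assert d.find_longest_prefix("/v1/chat/completions") == ("/v1/chat", "chat-handler")
assert d.find_longest_prefix("/v2/other") is None

# Pretty-printers: 1536 bytes is 1.50 KB; 3 * 1024**2 bytes per second is 3.00 MB/s.
assert pretty_print_bytes(1536) == "1.50 KB"
assert pretty_print_bytes_per_second(3 * 1024**2) == "3.00 MB/s"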
@@ -52,10 +52,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
  assert np.array_equal(next_resp_full, resp4)


asyncio.run(test_inference_engine(
  TinygradDynamicShardInferenceEngine(),
  TinygradDynamicShardInferenceEngine(),
  "llama3-8b-sfr",
))
@@ -5,7 +5,9 @@ from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard


class InferenceEngine(ABC):

  @abstractmethod
  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    pass

@@ -5,5 +5,6 @@ from mlx_lm.models.base import KVCache


class IdentityBlock(nn.Module):

  def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
    return x
@@ -7,7 +7,7 @@ import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
from .base import IdentityBlock
from exo.inference.shard import Shard


@dataclass
@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):


class DeepseekV2Model(nn.Module):

  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
@@ -70,6 +71,7 @@ class DeepseekV2Model(nn.Module):


class Model(nn.Module):

  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
@@ -107,10 +109,7 @@ class Model(nn.Module):
        for k in ["weight", "scales", "biases"]:
          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)

    return shard_state_dict


@@ -24,7 +24,9 @@ class ModelArgs(ModelArgs):

    self.shard = Shard(**self.shard)


class LlamaModel(nn.Module):

  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
@@ -66,7 +68,9 @@ class LlamaModel(nn.Module):
    h = self.norm(h)
    return h


class Model(nn.Module):

  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
@@ -116,9 +120,7 @@ class Model(nn.Module):

  @property
  def head_dim(self):
    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)

  @property
  def n_kv_heads(self):
File diff suppressed because it is too large
@@ -9,6 +9,7 @@ from exo.download.shard_download import ShardDownloader


class MLXDynamicShardInferenceEngine(InferenceEngine):

  def __init__(self, shard_downloader: ShardDownloader):
    self.shard = None
    self.shard_downloader = shard_downloader

@@ -10,6 +10,7 @@ from ..shard import Shard


class StatefulShardedModel:

  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
    self.shard = shard
    self.model = model
@@ -26,6 +27,7 @@ class StatefulShardedModel:
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
  ) -> Generator[Tuple[mx.array, mx.array], None, None]:

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
      if logit_bias:
        indices = mx.array(list(logit_bias.keys()))
@@ -74,16 +76,9 @@ class StatefulShardedModel:
    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)

  def init_cache(self, request_id: str):
    kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
    if self.max_kv_size is not None:
      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
    else:
      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
@@ -25,6 +25,7 @@ from ..shard import Shard


class ModelNotFoundError(Exception):

  def __init__(self, message):
    self.message = message
    super().__init__(self.message)
@@ -139,9 +140,10 @@ def load_model_shard(
  if (quantization := config.get("quantization", None)) is not None:
    # Handle legacy models which may not have everything quantized
    def class_predicate(p, m):
      if not hasattr(m, "to_quantized"):
        return False
      return f"{p}.scales" in weights

    nn.quantize(
      model,
      **quantization,
@@ -156,6 +158,7 @@ def load_model_shard(
  model.eval()
  return model


async def load_shard(
  model_path: str,
  shard: Shard,
@@ -179,26 +182,27 @@ async def load_shard(
  tokenizer = load_tokenizer(model_path, tokenizer_config)
  return model, tokenizer


async def get_image_from_str(_image_str: str):
  image_str = _image_str.strip()

  if image_str.startswith("http"):
    async with aiohttp.ClientSession() as session:
      async with session.get(image_str, timeout=10) as response:
        content = await response.read()
        return Image.open(BytesIO(content)).convert("RGB")
  elif image_str.startswith("data:image/"):
    # Extract the image format and base64 data
    format_prefix, base64_data = image_str.split(";base64,")
    image_format = format_prefix.split("/")[1].lower()
    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
    imgdata = base64.b64decode(base64_data)
    img = Image.open(BytesIO(imgdata))

    # Convert to RGB if not already
    if img.mode != "RGB":
      img = img.convert("RGB")

    return img
  else:
    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
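An illustrative call to get_image_from_str using the data-URI branch above (not part of the diff); it builds a tiny in-memory PNG so no network or files are needed.

import asyncio
import base64
from io import BytesIO
from PIL import Image


async def demo():
  # Encode a 1x1 red PNG as a data URI and decode it back through get_image_from_str.
  buf = BytesIO()
  Image.new("RGB", (1, 1), color=(255, 0, 0)).save(buf, format="PNG")
  data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
  img = await get_image_from_str(data_uri)
  print(img.size, img.mode)  # (1, 1) RGB


asyncio.run(demo())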
@@ -39,8 +39,8 @@ y = full.step("full", input_ids, pixel_values, temp=0)
full_generated_tokens = [y.item()]

for _ in range(13):
  y = full.step("full", y, temp=0)
  full_generated_tokens.append(y.item())

full_response = full_processor.tokenizer.decode(full_generated_tokens)
print("full response:", full_response)
@@ -54,11 +54,11 @@ y = m2.step("shard", y, temp=0)
full_generated_tokens = [y.item()]

for _ in range(13):
  y = m1.step("shard", y, temp=0)
  y = m2.step("shard", y, temp=0)
  full_generated_tokens.append(y.item())

sharded_response = processor2.tokenizer.decode(full_generated_tokens)
print("sharded response:", sharded_response)

assert full_response == sharded_response
@@ -6,6 +6,7 @@ import numpy as np


class DummyModel(nn.Module):

  def __init__(self, shard: Optional[Shard] = None):
    self.shard = shard
    self.layers = [
@@ -21,7 +22,7 @@ class DummyModel(nn.Module):

  def __call__(self, x, cache=None):
    if self.shard:
      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
        x = layer(x)
      if self.shard.is_last_layer():
        x = x.reshape((1, 2, 4))
@@ -34,8 +34,6 @@ class Shard:
  def overlaps(self, other: 'Shard') -> bool:
    return shards_overlap(self, other)


def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))
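The predicate above treats shards as closed layer intervals on the same model; a few concrete cases (not part of the diff):

a = Shard(model_id="m", start_layer=0, end_layer=15, n_layers=32)
b = Shard(model_id="m", start_layer=16, end_layer=31, n_layers=32)
c = Shard(model_id="m", start_layer=10, end_layer=20, n_layers=32)
d = Shard(model_id="other", start_layer=0, end_layer=31, n_layers=32)

assert not shards_overlap(a, b)  # adjacent but disjoint layer ranges
assert shards_overlap(a, c) and shards_overlap(b, c)  # c straddles both halves
assert not shards_overlap(a, d)  # a different model_id never overlaps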
@@ -7,6 +7,7 @@ import os
import asyncio
import numpy as np


# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
  prompt = "In a single word only, what is the last name of the current president of the USA?"
@@ -22,7 +23,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
    "B",
    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
    input_data=resp1,
    inference_state=inference_state_1,
  )
@@ -34,7 +35,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
  )
  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
    "B",
    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
    input_data=resp3,
    inference_state=inference_state_3,
  )
@@ -42,21 +43,22 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
  assert np.array_equal(resp_full, resp2)
  assert np.array_equal(next_resp_full, resp4)


asyncio.run(test_inference_engine(
  MLXDynamicShardInferenceEngine(HFShardDownloader()),
  MLXDynamicShardInferenceEngine(HFShardDownloader()),
  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
))

if os.getenv("RUN_TINYGRAD", default="0") == "1":
  import tinygrad
  import os
  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
  asyncio.run(
    test_inference_engine(
      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
    )
  )
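The "continuous Shards" requirement in the comment above just means the two layer ranges must tile 0..n_layers-1 with no gap, as in the test's split at pp; a small standalone sketch (not part of the diff):

pp = 15
first = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=pp, n_layers=32)
second = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=pp + 1, end_layer=31, n_layers=32)

assert first.end_layer + 1 == second.start_layer  # no gap between the shards
assert first.start_layer == 0 and second.end_layer == second.n_layers - 1  # together they cover every layer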
@@ -20,16 +20,11 @@ TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0
MODEL_PARAMS = {
  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
}


def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
  # build model
  linear = nn.Linear
@@ -48,10 +43,12 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No

  with Context(BEAM=0):
    # replace weights in model
    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
  return model


class TinygradDynamicShardInferenceEngine(InferenceEngine):

  def __init__(self, shard_downloader: ShardDownloader):
    self.shard = None
    self.shard_downloader = shard_downloader
@@ -64,7 +61,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
    toks = self.tokenizer.encode(prompt)
    h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()

    if h.shape == (1, ):
      start_pos += len(toks)
      start_pos += 1
      n_captured_toks = 0
@@ -80,7 +77,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):

    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()

    if h.shape == (1, ):
      start_pos += n_captured_toks
      start_pos += 1
      n_captured_toks = 0
@@ -2,21 +2,24 @@ from typing import Tuple, Union, Optional, Dict, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv


# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
  freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  # TODO: move dtype outside this
  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)


# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a, b = A[..., 0:1], A[..., 1:2]
  ro = a * c - b * d
  co = a * d + b * c
  return ro.cat(co, dim=-1)


def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -26,16 +29,19 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)


def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)


class Attention:

  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context
@@ -45,7 +51,7 @@ class Attention:
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
    if getenv("WQKV"):
      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
      xqkv = x @ self.wqkv.T
@@ -69,10 +75,10 @@ class Attention:

    # update the cache
    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -80,26 +86,31 @@ class Attention:
    attn = attn.reshape(bsz, seqlen, -1)
    return self.wo(attn)


class FeedForward:

  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit

  def __call__(self, x: Tensor) -> Tensor:
    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]


class TransformerBlock:

  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()


# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  assert logits.ndim == 1, "only works on 1d tensors"
@@ -127,8 +138,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
    for i in range(k):
      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
      t = (counter == t_argmax).where(0, t)

    # approximate top p
@@ -149,10 +160,28 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):

  return output_token


from exo.inference.shard import Shard


class Transformer:

  def __init__(
    self,
    dim: int,
    hidden_dim: int,
    n_heads: int,
    n_layers: int,
    norm_eps: float,
    vocab_size,
    shard: Shard = None,
    linear=nn.Linear,
    n_kv_heads=None,
    rope_theta=10000,
    max_context=1024,
    jit=True,
    feed_forward=FeedForward
  ):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
@@ -162,10 +191,10 @@ class Transformer:
    self.forward_jit = TinyJit(self.forward) if jit else None
    self.shard = shard

  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
    seqlen = x.shape[1]
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None

    if self.shard.is_first_layer():
      h = self.tok_embeddings(x)
@@ -182,24 +211,33 @@ class Transformer:
    else:
      return h

  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
    # TODO: better way to handle the first call v.s. the rest?
    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)


# *** helpers ***


def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):

  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
       for x in ["q", "k", "v", "o"]
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
       for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
@@ -215,9 +253,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
      sd[keymap[k]] = v
  return sd


def fix_bf16(weights: Dict[Any, Tensor]):
  if getenv("SUPPORT_BF16", 1):
    # TODO: without casting to float16, 70B llama OOM on tinybox.
    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
  # TODO: check if device supports bf16
  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
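The comment on complex_mult above encodes ordinary complex multiplication; a quick plain-Python check of the same identity, independent of tinygrad (not part of the diff):

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc), verified with Python complex numbers.
a, b, c, d = 1.0, 2.0, 3.0, 4.0
lhs = complex(a, b) * complex(c, d)
assert lhs.real == a * c - b * d  # -5.0
assert lhs.imag == a * d + b * c  # 10.0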
@@ -8,8 +8,10 @@ from exo.helpers import DEBUG
from exo.download.hf.hf_helpers import get_allow_patterns
from fnmatch import fnmatch


# **** helper functions ****
def concat_weights(models, device=None):

  def convert(name) -> Tensor:
    disk_tensors: List[Tensor] = [model[name] for model in models]
    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
@@ -17,11 +19,14 @@ def concat_weights(models, device=None):
    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
    lazy_tensors = [data.to(device=device) for data in disk_tensors]
    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)

  return {name: convert(name) for name in {name: None for model in models for name in model}}


def load(fn: str, shard: Shard):
  if fn.endswith('.index.json'):
    with open(fn) as fp:
      weight_map = json.load(fp)['weight_map']
    parts = {}
    filtered_weight_map = {}
    allow_patterns = get_allow_patterns(weight_map, shard)

@@ -2,6 +2,7 @@ import traceback
from transformers import AutoTokenizer, AutoProcessor
from exo.helpers import DEBUG


async def resolve_tokenizer(model_id: str):
  try:
    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")

@@ -2,39 +2,32 @@ from exo.inference.shard import Shard

model_base_shards = {
  ### llama
  "llama-3.1-8b":
    {
      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
      "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
    },
  "llama-3.1-70b":
    {
      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
      "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
    },
  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
  "llama-3-8b":
    {
      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
    },
  "llama-3-70b":
    {
      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
    },
  ### mistral
  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
  ### deepseek v2
  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
  ### llava
  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
}
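A minimal sketch of how a table like model_base_shards is typically consulted: pick the entry for a model name, then the shard registered for the inference engine class in use. The lookup pattern is an illustration, not code shown in this diff.

engine_name = "MLXDynamicShardInferenceEngine"  # e.g. type(inference_engine).__name__
base_shard = model_base_shards.get("llama-3-8b", {}).get(engine_name)
if base_shard is None:
  raise ValueError(f"No shard configured for llama-3-8b with {engine_name}")
print(base_shard.model_id, base_shard.n_layers)  # mlx-community/Meta-Llama-3-8B-Instruct-4bit 32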
@@ -4,6 +4,7 @@ from .peer_handle import PeerHandle


class Discovery(ABC):

  @abstractmethod
  async def start(self) -> None:
    pass

@@ -11,6 +11,7 @@ from exo import DEBUG_DISCOVERY


class ListenProtocol(asyncio.DatagramProtocol):

  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
    super().__init__()
    self.on_message = on_message
@@ -24,6 +25,7 @@ class ListenProtocol(asyncio.DatagramProtocol):


class GRPCDiscovery(Discovery):

  def __init__(
    self,
    node_id: str,
@@ -97,14 +99,12 @@ class GRPCDiscovery(Discovery):
    sock = transport.get_extra_info("socket")
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

    message = json.dumps({
      "type": "discovery",
      "node_id": self.node_id,
      "grpc_port": self.node_port,
      "device_capabilities": self.device_capabilities.to_dict(),
    }).encode("utf-8")

    while True:
      try:
@@ -166,14 +166,14 @@ class GRPCDiscovery(Discovery):
      try:
        current_time = time.time()
        peers_to_remove = [
          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
        ]
        if DEBUG_DISCOVERY >= 2:
          print(
            "Peer statuses:",
            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
             for peer_handle, connected_at, last_seen in self.known_peers.values()},
          )
        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
          print(f"Cleaning up peers: {peers_to_remove}")
@@ -13,6 +13,7 @@ from exo.topology.device_capabilities import DeviceCapabilities


class GRPCPeerHandle(PeerHandle):

  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
    self._id = _id
    self.address = address

@@ -11,6 +11,7 @@ from exo.orchestration import Node


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):

  def __init__(self, node: Node, host: str, port: int):
    self.node = node
    self.host = host
@@ -81,9 +82,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
      node_service_pb2.InferenceResult(
        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
        is_finished=result[1],
      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
    )

  async def CollectTopology(self, request, context):
@@ -91,12 +90,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    visited = set(request.visited)
    topology = await self.node.collect_topology(visited, max_depth)
    nodes = {
      node_id:
        node_service_pb2.DeviceCapabilities(
          model=cap.model,
          chip=cap.chip,
          memory=cap.memory,
          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
        )
      for node_id, cap in topology.nodes.items()
    }
    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
@@ -11,10 +11,9 @@ from google.protobuf.internal import builder as _builder
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3'
|
||||
)
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
@@ -25,38 +24,38 @@ if not _descriptor._USE_C_DESCRIPTORS:
  _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
  _globals['_SHARD']._serialized_start = 36
  _globals['_SHARD']._serialized_end = 119
  _globals['_PROMPTREQUEST']._serialized_start = 122
  _globals['_PROMPTREQUEST']._serialized_end = 317
  _globals['_TENSORREQUEST']._serialized_start = 320
  _globals['_TENSORREQUEST']._serialized_end = 499
  _globals['_GETINFERENCERESULTREQUEST']._serialized_start = 501
  _globals['_GETINFERENCERESULTREQUEST']._serialized_end = 548
  _globals['_INFERENCERESULT']._serialized_start = 550
  _globals['_INFERENCERESULT']._serialized_end = 642
  _globals['_TENSOR']._serialized_start = 644
  _globals['_TENSOR']._serialized_end = 703
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start = 705
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end = 765
  _globals['_TOPOLOGY']._serialized_start = 768
  _globals['_TOPOLOGY']._serialized_end = 1038
  _globals['_TOPOLOGY_NODESENTRY']._serialized_start = 889
  _globals['_TOPOLOGY_NODESENTRY']._serialized_end = 967
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start = 969
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end = 1038
  _globals['_PEERS']._serialized_start = 1040
  _globals['_PEERS']._serialized_end = 1065
  _globals['_DEVICEFLOPS']._serialized_start = 1067
  _globals['_DEVICEFLOPS']._serialized_end = 1122
  _globals['_DEVICECAPABILITIES']._serialized_start = 1124
  _globals['_DEVICECAPABILITIES']._serialized_end = 1231
  _globals['_SENDRESULTREQUEST']._serialized_start = 1233
  _globals['_SENDRESULTREQUEST']._serialized_end = 1309
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start = 1311
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end = 1372
  _globals['_EMPTY']._serialized_start = 1374
  _globals['_EMPTY']._serialized_end = 1381
  _globals['_NODESERVICE']._serialized_start = 1384
  _globals['_NODESERVICE']._serialized_end = 1862
# @@protoc_insertion_point(module_scope)
@@ -12,306 +12,264 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
warnings.warn(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in node_service_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
|
||||
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
|
||||
RuntimeWarning
|
||||
)
|
||||
warnings.warn(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
|
||||
f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
|
||||
f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
|
||||
)
|
||||
|
||||
|
||||
class NodeServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendPrompt = channel.unary_unary(
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTensor = channel.unary_unary(
|
||||
'/node_service.NodeService/SendTensor',
|
||||
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.GetInferenceResult = channel.unary_unary(
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.InferenceResult.FromString,
|
||||
_registered_method=True)
|
||||
self.CollectTopology = channel.unary_unary(
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Topology.FromString,
|
||||
_registered_method=True)
|
||||
self.SendResult = channel.unary_unary(
|
||||
'/node_service.NodeService/SendResult',
|
||||
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendOpaqueStatus = channel.unary_unary(
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendPrompt = channel.unary_unary(
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True
|
||||
)
|
||||
self.SendTensor = channel.unary_unary(
|
||||
'/node_service.NodeService/SendTensor',
|
||||
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True
|
||||
)
|
||||
self.GetInferenceResult = channel.unary_unary(
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.InferenceResult.FromString,
|
||||
_registered_method=True
|
||||
)
|
||||
self.CollectTopology = channel.unary_unary(
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Topology.FromString,
|
||||
_registered_method=True
|
||||
)
|
||||
self.SendResult = channel.unary_unary(
|
||||
'/node_service.NodeService/SendResult',
|
||||
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True
|
||||
)
|
||||
self.SendOpaqueStatus = channel.unary_unary(
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True
|
||||
)
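(A minimal client-side usage sketch of this generated stub, for orientation only — the address, model id, prompt, and request id below are illustrative assumptions, not part of this diff:)

import grpc
import node_service_pb2
import node_service_pb2_grpc

# Hypothetical client: connect to a node and send a prompt for one shard.
channel = grpc.insecure_channel("localhost:8081")  # address is an assumption
stub = node_service_pb2_grpc.NodeServiceStub(channel)
shard = node_service_pb2.Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32)
reply = stub.SendPrompt(node_service_pb2.PromptRequest(shard=shard, prompt="hello", request_id="example-request"))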
|
||||
|
||||
|
||||
class NodeServiceServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def SendPrompt(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendPrompt(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
def SendTensor(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTensor(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
def GetInferenceResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetInferenceResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
def CollectTopology(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CollectTopology(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
def SendResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendOpaqueStatus(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
def SendOpaqueStatus(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_NodeServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendPrompt': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPrompt,
|
||||
request_deserializer=node__service__pb2.PromptRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendTensor': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTensor,
|
||||
request_deserializer=node__service__pb2.TensorRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetInferenceResult,
|
||||
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
|
||||
),
|
||||
'CollectTopology': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectTopology,
|
||||
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
|
||||
response_serializer=node__service__pb2.Topology.SerializeToString,
|
||||
),
|
||||
'SendResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendResult,
|
||||
request_deserializer=node__service__pb2.SendResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendOpaqueStatus,
|
||||
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'node_service.NodeService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
|
||||
rpc_method_handlers = {
|
||||
'SendPrompt':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPrompt,
|
||||
request_deserializer=node__service__pb2.PromptRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendTensor':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTensor,
|
||||
request_deserializer=node__service__pb2.TensorRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'GetInferenceResult':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetInferenceResult,
|
||||
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
|
||||
),
|
||||
'CollectTopology':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectTopology,
|
||||
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
|
||||
response_serializer=node__service__pb2.Topology.SerializeToString,
|
||||
),
|
||||
'SendResult':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendResult,
|
||||
request_deserializer=node__service__pb2.SendResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendOpaqueStatus':
|
||||
grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendOpaqueStatus,
|
||||
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler, ))
|
||||
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
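(A minimal server-side wiring sketch, for orientation only — MyNodeServiceServicer, the thread pool size, and the port are illustrative assumptions, not part of this diff:)

import grpc
from concurrent import futures

# Hypothetical server: register a servicer implementation and serve it.
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
add_NodeServiceServicer_to_server(MyNodeServiceServicer(), server)
server.add_insecure_port("[::]:8081")
server.start()
server.wait_for_termination()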
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class NodeService(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def SendPrompt(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
node__service__pb2.PromptRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
node__service__pb2.PromptRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def SendTensor(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendTensor',
|
||||
node__service__pb2.TensorRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendTensor',
|
||||
node__service__pb2.TensorRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def GetInferenceResult(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
node__service__pb2.InferenceResult.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
node__service__pb2.InferenceResult.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def CollectTopology(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
node__service__pb2.Topology.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
node__service__pb2.Topology.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def SendResult(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendResult',
|
||||
node__service__pb2.SendResultRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendResult',
|
||||
node__service__pb2.SendResultRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def SendOpaqueStatus(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@staticmethod
|
||||
def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from .grpc_discovery import GRPCDiscovery
|
||||
|
||||
|
||||
class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
|
||||
self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
|
||||
|
||||
@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
|
||||
|
||||
|
||||
class PeerHandle(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def id(self) -> str:
|
||||
pass
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Server(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
|
||||
|
||||
|
||||
class Node(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, wait_for_peers: int = 0) -> None:
|
||||
pass
|
||||
|
||||
@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
|
||||
|
||||
|
||||
class StandardNode(Node):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
@@ -359,6 +360,7 @@ class StandardNode(Node):
|
||||
self.on_token.trigger_all(request_id, tokens, is_finished)
|
||||
|
||||
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||
|
||||
async def send_result_to_peer(peer):
|
||||
try:
|
||||
await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
|
||||
@@ -372,6 +374,7 @@ class StandardNode(Node):
|
||||
|
||||
async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
|
||||
if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
|
||||
|
||||
async def send_status_to_peer(peer):
|
||||
try:
|
||||
await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
|
||||
|
||||
@@ -7,6 +7,7 @@ from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
|
||||
class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_inference_engine = AsyncMock()
|
||||
self.mock_server = AsyncMock()
|
||||
|
||||
@@ -12,6 +12,7 @@ async def main() -> None:
|
||||
callback2 = callback_system.register("callback2")
|
||||
|
||||
def on_next_callback(name: str) -> Callable[..., None]:
|
||||
|
||||
def callback(*args: Any) -> None:
|
||||
print(f"{name} received values: {args}")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ class Partition:
|
||||
|
||||
|
||||
class PartitioningStrategy(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def partition(self, topology: Topology) -> List[Partition]:
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ from .partitioning_strategy import Partition
|
||||
|
||||
|
||||
class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
|
||||
|
||||
def partition(self, topology: Topology) -> List[Partition]:
|
||||
nodes = list(topology.all_nodes())
|
||||
nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
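(An illustrative sketch of what a ring memory-weighted split computes — not the actual exo implementation; names and rounding are assumptions. Each node gets a slice of [0, 1) proportional to its memory, in descending memory order:)

def memory_weighted_partitions(nodes):
  # nodes: list of (node_id, memory) tuples; guard against a zero total.
  nodes = sorted(nodes, key=lambda x: (x[1], x[0]), reverse=True)
  total = sum(mem for _, mem in nodes) or 1
  partitions, start = [], 0.0
  for node_id, mem in nodes:
    end = start + mem / total
    partitions.append((node_id, round(start, 5), round(end, 5)))
    start = end
  return partitions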
|
||||
|
||||
@@ -4,6 +4,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
|
||||
|
||||
|
||||
class TestMacDeviceCapabilities(unittest.TestCase):
|
||||
|
||||
@patch("subprocess.check_output")
|
||||
def test_mac_device_capabilities_pro(self, mock_check_output):
|
||||
# Mock the subprocess output
|
||||
|
||||
@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
|
||||
|
||||
|
||||
class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
|
||||
def test_map_partitions_to_shards(self):
|
||||
partitions = [
|
||||
Partition("node1", 0.0, 0.42857),
|
||||
|
||||
@@ -6,6 +6,7 @@ from exo.topology.partitioning_strategy import Partition
|
||||
|
||||
|
||||
class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
|
||||
|
||||
def test_partition(self):
|
||||
# triangle
|
||||
# node1 -> node2 -> node3 -> node1
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Dict, Set, Optional
|
||||
|
||||
|
||||
class Topology:
|
||||
|
||||
def __init__(self):
|
||||
self.nodes: Dict[str, DeviceCapabilities] = {} # Maps node IDs to DeviceCapabilities
|
||||
self.peer_graph: Dict[str, Set[str]] = {} # Adjacency list representing the graph
|
||||
|
||||
@@ -9,58 +9,60 @@ from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
|
||||
|
||||
|
||||
def create_hf_repo_progress_event(
|
||||
completed_files: int = 5,
|
||||
total_files: int = 10,
|
||||
downloaded_bytes: int = 500000000,
|
||||
downloaded_bytes_this_session: int = 250000000,
|
||||
total_bytes: int = 1000000000,
|
||||
overall_speed: int = 5000000,
|
||||
overall_eta: timedelta = timedelta(seconds=100),
|
||||
file_progress: dict = None,
|
||||
status: str = "in_progress"
|
||||
completed_files: int = 5,
|
||||
total_files: int = 10,
|
||||
downloaded_bytes: int = 500000000,
|
||||
downloaded_bytes_this_session: int = 250000000,
|
||||
total_bytes: int = 1000000000,
|
||||
overall_speed: int = 5000000,
|
||||
overall_eta: timedelta = timedelta(seconds=100),
|
||||
file_progress: dict = None,
|
||||
status: str = "in_progress"
|
||||
) -> RepoProgressEvent:
|
||||
if file_progress is None:
|
||||
file_progress = {
|
||||
"file1.bin": RepoFileProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
file_path="file1.bin",
|
||||
downloaded=100000000,
|
||||
downloaded_this_session=50000000,
|
||||
total=200000000,
|
||||
speed=1000000,
|
||||
eta=timedelta(seconds=100),
|
||||
status="in_progress"
|
||||
),
|
||||
"file2.bin": RepoFileProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
file_path="file2.bin",
|
||||
downloaded=200000000,
|
||||
downloaded_this_session=100000000,
|
||||
total=200000000,
|
||||
speed=2000000,
|
||||
eta=timedelta(seconds=0),
|
||||
status="complete"
|
||||
)
|
||||
}
|
||||
if file_progress is None:
|
||||
file_progress = {
|
||||
"file1.bin":
|
||||
RepoFileProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
file_path="file1.bin",
|
||||
downloaded=100000000,
|
||||
downloaded_this_session=50000000,
|
||||
total=200000000,
|
||||
speed=1000000,
|
||||
eta=timedelta(seconds=100),
|
||||
status="in_progress"
|
||||
), "file2.bin":
|
||||
RepoFileProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
file_path="file2.bin",
|
||||
downloaded=200000000,
|
||||
downloaded_this_session=100000000,
|
||||
total=200000000,
|
||||
speed=2000000,
|
||||
eta=timedelta(seconds=0),
|
||||
status="complete"
|
||||
)
|
||||
}
|
||||
|
||||
return RepoProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
completed_files=completed_files,
|
||||
total_files=total_files,
|
||||
downloaded_bytes=downloaded_bytes,
|
||||
downloaded_bytes_this_session=downloaded_bytes_this_session,
|
||||
total_bytes=total_bytes,
|
||||
overall_speed=overall_speed,
|
||||
overall_eta=overall_eta,
|
||||
file_progress=file_progress,
|
||||
status=status
|
||||
)
|
||||
return RepoProgressEvent(
|
||||
repo_id="repo_id",
|
||||
repo_revision="repo_revision",
|
||||
completed_files=completed_files,
|
||||
total_files=total_files,
|
||||
downloaded_bytes=downloaded_bytes,
|
||||
downloaded_bytes_this_session=downloaded_bytes_this_session,
|
||||
total_bytes=total_bytes,
|
||||
overall_speed=overall_speed,
|
||||
overall_eta=overall_eta,
|
||||
file_progress=file_progress,
|
||||
status=status
|
||||
)
|
||||
|
||||
|
||||
class TestNodeViz(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.topology = Topology()
|
||||
self.topology.update_node(
|
||||
|
||||
@@ -16,7 +16,9 @@ from rich.syntax import Syntax
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
|
||||
|
||||
class TopologyViz:
|
||||
|
||||
def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
|
||||
self.chatgpt_api_endpoints = chatgpt_api_endpoints
|
||||
self.web_chat_urls = web_chat_urls
|
||||
@@ -28,11 +30,7 @@ class TopologyViz:
|
||||
|
||||
self.console = Console()
|
||||
self.layout = Layout()
|
||||
self.layout.split(
|
||||
Layout(name="main"),
|
||||
Layout(name="prompt_output", size=15),
|
||||
Layout(name="download", size=25)
|
||||
)
|
||||
self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
|
||||
self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
|
||||
self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
|
||||
self.download_panel = Panel("", title="Download Progress", border_style="cyan")
|
||||
@@ -75,11 +73,11 @@ class TopologyViz:
|
||||
|
||||
# Update and show/hide prompt and output panel
|
||||
if any(r[0] or r[1] for r in self.requests.values()):
|
||||
self.prompt_output_panel = self._generate_prompt_output_layout()
|
||||
self.layout["prompt_output"].update(self.prompt_output_panel)
|
||||
self.layout["prompt_output"].visible = True
|
||||
self.prompt_output_panel = self._generate_prompt_output_layout()
|
||||
self.layout["prompt_output"].update(self.prompt_output_panel)
|
||||
self.layout["prompt_output"].visible = True
|
||||
else:
|
||||
self.layout["prompt_output"].visible = False
|
||||
self.layout["prompt_output"].visible = False
|
||||
|
||||
# Only show download_panel if there are in-progress downloads
|
||||
if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
|
||||
@@ -97,33 +95,33 @@ class TopologyViz:
|
||||
max_lines = 13 # Maximum number of lines for the entire panel content
|
||||
|
||||
for (prompt, output) in reversed(requests):
|
||||
prompt_icon, output_icon = "💬️", "🤖"
|
||||
prompt_icon, output_icon = "💬️", "🤖"
|
||||
|
||||
# Process prompt
|
||||
prompt_lines = prompt.split('\n')
|
||||
if len(prompt_lines) > max_lines // 2:
|
||||
prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
|
||||
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
||||
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
||||
# Process prompt
|
||||
prompt_lines = prompt.split('\n')
|
||||
if len(prompt_lines) > max_lines // 2:
|
||||
prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
|
||||
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
||||
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
||||
|
||||
# Process output
|
||||
output_lines = output.split('\n')
|
||||
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
|
||||
if len(output_lines) > remaining_lines:
|
||||
output_lines = output_lines[:remaining_lines - 1] + ['...']
|
||||
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
||||
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
|
||||
# Process output
|
||||
output_lines = output.split('\n')
|
||||
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
|
||||
if len(output_lines) > remaining_lines:
|
||||
output_lines = output_lines[:remaining_lines - 1] + ['...']
|
||||
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
||||
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
|
||||
|
||||
content.append(prompt_text)
|
||||
content.append(output_text)
|
||||
content.append(Text()) # Empty line between entries
|
||||
content.append(prompt_text)
|
||||
content.append(output_text)
|
||||
content.append(Text()) # Empty line between entries
|
||||
|
||||
return Panel(
|
||||
Group(*content),
|
||||
title="",
|
||||
border_style="cyan",
|
||||
height=15, # Increased height to accommodate multiple lines
|
||||
expand=True # Allow the panel to expand to full width
|
||||
Group(*content),
|
||||
title="",
|
||||
border_style="cyan",
|
||||
height=15, # Increased height to accommodate multiple lines
|
||||
expand=True # Allow the panel to expand to full width
|
||||
)
|
||||
|
||||
def _generate_main_layout(self) -> str:
|
||||
@@ -185,14 +183,14 @@ class TopologyViz:
|
||||
visualization[bar_y][bar_start_x + i] = segment
|
||||
|
||||
# Add labels
|
||||
visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
|
||||
visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
|
||||
visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
|
||||
visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
|
||||
|
||||
# Add position indicator and FLOPS value
|
||||
pos_x = bar_start_x + int(bar_pos * bar_width)
|
||||
flops_str = f"{total_flops:.2f} TFLOPS"
|
||||
visualization[bar_y - 1][pos_x] = "▼"
|
||||
visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
|
||||
visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
|
||||
visualization[bar_y + 2][pos_x] = "▲"
|
||||
|
||||
# Add an extra empty line for spacing
|
||||
@@ -270,41 +268,41 @@ class TopologyViz:
|
||||
|
||||
# Current node download progress
|
||||
if self.node_id in self.node_download_progress:
|
||||
download_progress = self.node_download_progress[self.node_id]
|
||||
title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
|
||||
summary.add_row(Text(title, style="bold"))
|
||||
progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
|
||||
summary.add_row(progress_info)
|
||||
download_progress = self.node_download_progress[self.node_id]
|
||||
title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
|
||||
summary.add_row(Text(title, style="bold"))
|
||||
progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
|
||||
summary.add_row(progress_info)
|
||||
|
||||
eta_info = f"{download_progress.overall_eta}"
|
||||
summary.add_row(eta_info)
|
||||
eta_info = f"{download_progress.overall_eta}"
|
||||
summary.add_row(eta_info)
|
||||
|
||||
summary.add_row("") # Empty row for spacing
|
||||
summary.add_row("") # Empty row for spacing
|
||||
|
||||
for file_path, file_progress in download_progress.file_progress.items():
|
||||
if file_progress.status != "complete":
|
||||
progress = int(file_progress.downloaded / file_progress.total * 30)
|
||||
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
|
||||
percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
|
||||
summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
|
||||
for file_path, file_progress in download_progress.file_progress.items():
|
||||
if file_progress.status != "complete":
|
||||
progress = int(file_progress.downloaded / file_progress.total * 30)
|
||||
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
|
||||
percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
|
||||
summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
|
||||
|
||||
summary.add_row("") # Empty row for spacing
|
||||
|
||||
# Other nodes download progress summary
|
||||
summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
|
||||
for node_id, progress in self.node_download_progress.items():
|
||||
if node_id != self.node_id:
|
||||
device = self.topology.nodes.get(node_id)
|
||||
partition = next((p for p in self.partitions if p.node_id == node_id), None)
|
||||
partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
|
||||
percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
|
||||
speed = pretty_print_bytes_per_second(progress.overall_speed)
|
||||
device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
|
||||
progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
|
||||
progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
|
||||
percentage_str = f"{percentage:.1f}%"
|
||||
eta_str = f"{progress.overall_eta}"
|
||||
summary.add_row(device_info, progress_info, percentage_str)
|
||||
summary.add_row("", progress_bar, eta_str)
|
||||
if node_id != self.node_id:
|
||||
device = self.topology.nodes.get(node_id)
|
||||
partition = next((p for p in self.partitions if p.node_id == node_id), None)
|
||||
partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
|
||||
percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
|
||||
speed = pretty_print_bytes_per_second(progress.overall_speed)
|
||||
device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
|
||||
progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
|
||||
progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
|
||||
percentage_str = f"{percentage:.1f}%"
|
||||
eta_str = f"{progress.overall_eta}"
|
||||
summary.add_row(device_info, progress_info, percentage_str)
|
||||
summary.add_row("", progress_bar, eta_str)
|
||||
|
||||
return summary
|
||||
return summary
|
||||
|
||||
@@ -3,51 +3,49 @@ import asyncio
|
||||
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
|
||||
|
||||
DEFAULT_ALLOW_PATTERNS = [
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.safetensors",
|
||||
]
|
||||
# Always ignore `.git` and `.cache/huggingface` folders in commits
|
||||
DEFAULT_IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".git/*",
|
||||
"*/.git",
|
||||
"**/.git/**",
|
||||
".cache/huggingface",
|
||||
".cache/huggingface/*",
|
||||
"*/.cache/huggingface",
|
||||
"**/.cache/huggingface/**",
|
||||
".git",
|
||||
".git/*",
|
||||
"*/.git",
|
||||
"**/.git/**",
|
||||
".cache/huggingface",
|
||||
".cache/huggingface/*",
|
||||
"*/.cache/huggingface",
|
||||
"**/.cache/huggingface/**",
|
||||
]
|
||||
|
||||
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
|
||||
async def progress_callback(event: RepoProgressEvent):
|
||||
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
|
||||
print(f"Estimated time remaining: {event.overall_eta}")
|
||||
print("File Progress:")
|
||||
for file_path, progress in event.file_progress.items():
|
||||
status_icon = {
|
||||
'not_started': '⚪',
|
||||
'in_progress': '🔵',
|
||||
'complete': '✅'
|
||||
}[progress.status]
|
||||
eta_str = str(progress.eta)
|
||||
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
|
||||
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
|
||||
print("\n")
|
||||
|
||||
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
|
||||
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
|
||||
|
||||
async def progress_callback(event: RepoProgressEvent):
|
||||
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
|
||||
print(f"Estimated time remaining: {event.overall_eta}")
|
||||
print("File Progress:")
|
||||
for file_path, progress in event.file_progress.items():
|
||||
status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
|
||||
eta_str = str(progress.eta)
|
||||
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
|
||||
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
|
||||
print("\n")
|
||||
|
||||
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
|
||||
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
|
||||
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
|
||||
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
|
||||
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
|
||||
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
|
||||
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
|
||||
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
|
||||
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
|
||||
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
|
||||
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
|
||||
|
||||
main.py
@@ -53,131 +53,146 @@ inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
|
||||
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
|
||||
|
||||
if args.node_port is None:
|
||||
args.node_port = find_available_port(args.node_host)
|
||||
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
|
||||
args.node_port = find_available_port(args.node_host)
|
||||
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
|
||||
|
||||
args.node_id = args.node_id or get_or_create_node_id()
|
||||
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
|
||||
chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
|
||||
web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
|
||||
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
|
||||
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
|
||||
if DEBUG >= 0:
|
||||
print("Chat interface started:")
|
||||
for web_chat_url in web_chat_urls:
|
||||
print(f" - {terminal_link(web_chat_url)}")
|
||||
print("ChatGPT API endpoint served at:")
|
||||
for chatgpt_api_endpoint in chatgpt_api_endpoints:
|
||||
print(f" - {terminal_link(chatgpt_api_endpoint)}")
|
||||
print("Chat interface started:")
|
||||
for web_chat_url in web_chat_urls:
|
||||
print(f" - {terminal_link(web_chat_url)}")
|
||||
print("ChatGPT API endpoint served at:")
|
||||
for chatgpt_api_endpoint in chatgpt_api_endpoints:
|
||||
print(f" - {terminal_link(chatgpt_api_endpoint)}")
|
||||
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
|
||||
node = StandardNode(
|
||||
args.node_id,
|
||||
None,
|
||||
inference_engine,
|
||||
discovery,
|
||||
chatgpt_api_endpoints=chatgpt_api_endpoints,
|
||||
web_chat_urls=web_chat_urls,
|
||||
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
|
||||
disable_tui=args.disable_tui,
|
||||
max_generate_tokens=args.max_generate_tokens,
|
||||
topology_viz=topology_viz
|
||||
args.node_id,
|
||||
None,
|
||||
inference_engine,
|
||||
discovery,
|
||||
chatgpt_api_endpoints=chatgpt_api_endpoints,
|
||||
web_chat_urls=web_chat_urls,
|
||||
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
|
||||
disable_tui=args.disable_tui,
|
||||
max_generate_tokens=args.max_generate_tokens,
|
||||
topology_viz=topology_viz
|
||||
)
|
||||
server = GRPCServer(node, args.node_host, args.node_port)
|
||||
node.server = server
|
||||
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
|
||||
node.on_token.register("update_topology_viz").on_next(
|
||||
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
|
||||
api = ChatGPTAPI(
|
||||
node,
|
||||
inference_engine.__class__.__name__,
|
||||
response_timeout_secs=args.chatgpt_api_response_timeout_secs,
|
||||
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
|
||||
)
|
||||
node.on_token.register("update_topology_viz").on_next(
|
||||
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id,
|
||||
inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
|
||||
)
|
||||
|
||||
|
||||
def preemptively_start_download(request_id: str, opaque_status: str):
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
|
||||
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
||||
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
||||
asyncio.create_task(shard_downloader.ensure_shard(current_shard))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to preemptively start download: {e}")
|
||||
traceback.print_exc()
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
|
||||
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
||||
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
||||
asyncio.create_task(shard_downloader.ensure_shard(current_shard))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to preemptively start download: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
|
||||
if args.prometheus_client_port:
|
||||
from exo.stats.metrics import start_metrics_server
|
||||
start_metrics_server(node, args.prometheus_client_port)
|
||||
from exo.stats.metrics import start_metrics_server
|
||||
start_metrics_server(node, args.prometheus_client_port)
|
||||
|
||||
last_broadcast_time = 0
|
||||
|
||||
|
||||
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
||||
global last_broadcast_time
|
||||
current_time = time.time()
|
||||
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
|
||||
last_broadcast_time = current_time
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
global last_broadcast_time
|
||||
current_time = time.time()
|
||||
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
|
||||
last_broadcast_time = current_time
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
|
||||
|
||||
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
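(The registration above throttles download-progress broadcasts to at most one every 0.1 s while always forwarding completion events. A standalone sketch of the same throttling pattern, with illustrative names:)

import time

def make_throttled(emit, min_interval=0.1):
  last = 0.0
  def maybe_emit(event, is_final=False):
    nonlocal last
    now = time.time()
    if is_final or now - last >= min_interval:
      last = now
      emit(event)
  return maybe_emit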
|
||||
|
||||
|
||||
async def shutdown(signal, loop):
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
print("Thank you for using exo.")
|
||||
print_yellow_exo()
|
||||
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
[task.cancel() for task in server_tasks]
|
||||
print(f"Cancelling {len(server_tasks)} outstanding tasks")
|
||||
await asyncio.gather(*server_tasks, return_exceptions=True)
|
||||
await server.stop()
|
||||
loop.stop()
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
print("Thank you for using exo.")
|
||||
print_yellow_exo()
|
||||
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
[task.cancel() for task in server_tasks]
|
||||
print(f"Cancelling {len(server_tasks)} outstanding tasks")
|
||||
await asyncio.gather(*server_tasks, return_exceptions=True)
|
||||
await server.stop()
|
||||
loop.stop()
|
||||
|
||||
|
||||
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
|
||||
shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(shard.model_id)
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"cli-wait-response-{request_id}"
|
||||
callback = node.on_token.register(callback_id)
|
||||
if topology_viz:
|
||||
topology_viz.update_prompt(request_id, prompt)
|
||||
prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
|
||||
shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(shard.model_id)
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"cli-wait-response-{request_id}"
|
||||
callback = node.on_token.register(callback_id)
|
||||
if topology_viz:
|
||||
topology_viz.update_prompt(request_id, prompt)
|
||||
prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
|
||||
|
||||
try:
|
||||
print(f"Processing prompt: {prompt}")
|
||||
await node.process_prompt(shard, prompt, None, request_id=request_id)
|
||||
try:
|
||||
print(f"Processing prompt: {prompt}")
|
||||
await node.process_prompt(shard, prompt, None, request_id=request_id)
|
||||
|
||||
_, tokens, _ = await callback.wait(
|
||||
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
|
||||
timeout=300
|
||||
)
|
||||
_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
|
||||
|
||||
print("\nGenerated response:")
|
||||
print(tokenizer.decode(tokens))
|
||||
except Exception as e:
|
||||
print(f"Error processing prompt: {str(e)}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
node.on_token.deregister(callback_id)
|
||||
|
||||
print("\nGenerated response:")
|
||||
print(tokenizer.decode(tokens))
|
||||
except Exception as e:
|
||||
print(f"Error processing prompt: {str(e)}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
node.on_token.deregister(callback_id)
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Use a more direct approach to handle signals
|
||||
def handle_exit():
|
||||
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
|
||||
# Use a more direct approach to handle signals
|
||||
def handle_exit():
|
||||
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
|
||||
|
||||
for s in [signal.SIGINT, signal.SIGTERM]:
|
||||
loop.add_signal_handler(s, handle_exit)
|
||||
for s in [signal.SIGINT, signal.SIGTERM]:
|
||||
loop.add_signal_handler(s, handle_exit)
|
||||
|
||||
await node.start(wait_for_peers=args.wait_for_peers)
|
||||
await node.start(wait_for_peers=args.wait_for_peers)
|
||||
|
||||
if args.run_model:
|
||||
await run_model_cli(node, inference_engine, args.run_model, args.prompt)
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
await asyncio.Event().wait()
|
||||
|
||||
if args.run_model:
|
||||
await run_model_cli(node, inference_engine, args.run_model, args.prompt)
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
await asyncio.Event().wait()
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Received keyboard interrupt. Shutting down...")
|
||||
finally:
|
||||
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
|
||||
loop.close()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Received keyboard interrupt. Shutting down...")
|
||||
finally:
|
||||
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
|
||||
loop.close()
|
||||
|
||||