Merge branch 'main' into runners2

Author: Alex Cheema
Date: 2025-01-20 16:12:55 +00:00
39 changed files with 3028 additions and 291 deletions

.gitignore vendored
View File

@@ -171,3 +171,5 @@ cython_debug/
**/*.xcodeproj/*
.aider*
exo/tinychat/images/*.png

View File

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
### Wide Model Support
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
### Dynamic Model Partitioning
@@ -100,13 +102,13 @@ source install.sh
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
1. Upgrade to the latest version of MacOS 15.
1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation
### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices
#### Device 1:
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogeneous Devices (macOS + Linux)
#### Device 1 (MacOS):
#### Device 1 (macOS):
```sh
exo
@@ -244,7 +246,7 @@ python3 format.py ./exo
## Known Issues
- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:
```sh
/Applications/Python 3.x/Install Certificates.command

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

View File

@@ -0,0 +1,111 @@
import json
import re
import requests
def get_current_weather(location: str, unit: str = "celsius"):
"""Mock weather data function"""
# Hardcoded response for demo purposes
return {
"location": location,
"temperature": 22 if unit == "celsius" else 72,
"unit": unit,
"forecast": "Sunny with light clouds"
}
def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
offset = 0
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
if i == 0:
offset = m.start()
try:
func = json.loads(m.group(1))
tool_calls.append({"type": "function", "function": func})
if isinstance(func["arguments"], str):
func["arguments"] = json.loads(func["arguments"])
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
if tool_calls:
if offset > 0 and content[:offset].strip():
c = content[:offset]
else:
c = ""
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
def chat_completion(messages):
"""Send chat completion request to local server"""
response = requests.post(
"http://localhost:52415/v1/chat/completions",
json={
"model": "qwen-2.5-1.5b",
"messages": messages,
"tools": [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}],
"tool_choice": "auto"
}
)
return response.json()
def main():
# Initial conversation
messages = [{
"role": "user",
"content": "Hi there, what's the weather in Boston?"
}]
# Get initial response
response = chat_completion(messages)
print(f"First response: {response}")
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
messages.append(assistant_message)
# If there are tool calls, execute them and continue conversation
if "tool_calls" in assistant_message:
for tool_call in assistant_message["tool_calls"]:
if tool_call["function"]["name"] == "get_current_weather":
args = tool_call["function"]["arguments"]
weather_data = get_current_weather(**args)
# Add tool response to messages
messages.append({
"role": "tool",
"content": json.dumps(weather_data),
"name": tool_call["function"]["name"]
})
# Get final response with weather data
response = chat_completion(messages)
print(f"Final response: {response}")
messages.append({
"role": "assistant",
"content": response["choices"][0]["message"]["content"]
})
# Print full conversation
for msg in messages:
print(f"\n{msg['role'].upper()}: {msg['content']}")
if __name__ == "__main__":
main()

View File

@@ -5,18 +5,24 @@ import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import mlx.core as mx
import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -24,23 +30,28 @@ from exo.apputil import create_animation_mp4
from collections import defaultdict
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
self.content = content
self.tools = tools
def to_dict(self):
return {"role": self.role, "content": self.content}
data = {"role": self.role, "content": self.content}
if self.tools:
data["tools"] = self.tools
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
def generate_completion(
@@ -120,20 +131,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
return remapped_messages
def build_prompt(tokenizer, _messages: List[Message]):
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
for message in messages:
if not isinstance(message.content, list):
continue
chat_template_args = {
"conversation": [m.to_dict() for m in messages],
"tokenize": False,
"add_generation_prompt": True
}
if tools: chat_template_args["tools"] = tools
prompt = tokenizer.apply_chat_template(**chat_template_args)
print(f"!!! Prompt: {prompt}")
return prompt
def parse_message(data: dict):
if "role" not in data or "content" not in data:
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
return Message(data["role"], data["content"])
return Message(data["role"], data["content"], data.get("tools"))
def parse_chat_request(data: dict, default_model: str):
@@ -141,6 +156,7 @@ def parse_chat_request(data: dict, default_model: str):
data.get("model", default_model),
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
)
@@ -151,7 +167,7 @@ class PromptSession:
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -166,6 +182,7 @@ class ChatGPTAPI:
# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -180,6 +197,7 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -191,10 +209,12 @@ class ChatGPTAPI:
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
@@ -241,7 +261,7 @@ class ChatGPTAPI:
)
await response.prepare(request)
for model_name, pretty in pretty_name.items():
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
@@ -269,6 +289,12 @@ class ChatGPTAPI:
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
await response.write(b"data: [DONE]\n\n")
return response
@@ -281,7 +307,8 @@ class ChatGPTAPI:
)
async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
return web.json_response({"object": "list", "data": models_list})
async def handle_post_chat_token_encode(self, request):
data = await request.json()
@@ -294,7 +321,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
prompt = build_prompt(tokenizer, messages)
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
tokens = tokenizer.encode(prompt)
return web.json_response({
"length": len(prompt),
@@ -314,13 +341,13 @@ class ChatGPTAPI:
async def handle_post_chat_completions(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
chat_request = parse_chat_request(data, self.default_model)
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
chat_request.model = self.default_model
if not chat_request.model or chat_request.model not in model_cards:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
chat_request.model = self.default_model
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
if not shard:
@@ -331,34 +358,26 @@ class ChatGPTAPI:
)
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
prompt = build_prompt(tokenizer, chat_request.messages)
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
try:
self.on_chat_completion_request(request_id, chat_request, prompt)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
# request_id = None
# match = self.prompts.find_longest_prefix(prompt)
# if match and len(prompt) > len(match[1].prompt):
# if DEBUG >= 2:
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
# request_id = match[1].request_id
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
# # remove the matching prefix from the prompt
# prompt = prompt[len(match[1].prompt):]
# else:
# request_id = str(uuid.uuid4())
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
if stream:
response = web.StreamResponse(
@@ -374,10 +393,12 @@ class ChatGPTAPI:
try:
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
token, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
@@ -408,10 +429,13 @@ class ChatGPTAPI:
return response
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt: {str(e)}"},
status=500
@@ -420,6 +444,7 @@ class ChatGPTAPI:
finally:
# Clean up the queue for this request
if request_id in self.token_queues:
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
del self.token_queues[request_id]
else:
tokens = []
@@ -441,6 +466,85 @@ class ChatGPTAPI:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_post_image_generations(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
model = data.get("model", "")
prompt = data.get("prompt", "")
image_url = data.get("image_url", "")
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
try:
      if image_url:
img = self.base64_decode(image_url)
else:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step) / total_steps
        # Build the arrow and trailing spaces for the bar
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
spaces = ' ' * (bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder / image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
@@ -553,7 +657,7 @@ class ChatGPTAPI:
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",
@@ -585,3 +689,19 @@ class ChatGPTAPI:
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
img = img[None]
return img
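
For reference, a hedged client sketch (editorial, not part of the commit) for the new `/v1/image/generations` route registered above; the host, port, and model name are assumptions drawn from the README examples and the `stable-diffusion-2-1-base` check elsewhere in this diff.

```python
import json
import requests

# Hedged sketch: host, port and model name are assumptions; adjust to your setup.
payload = {
  "model": "stable-diffusion-2-1-base",
  "prompt": "a watercolor painting of a lighthouse at dawn",
  # "image_url": "data:image/png;base64,...",  # optional init image, decoded by base64_decode above
}
with requests.post("http://localhost:52415/v1/image/generations", json=payload, stream=True) as resp:
  for line in resp.iter_lines():
    if not line: continue
    event = json.loads(line)
    if "progress" in event:
      print(event["progress"])  # textual progress bar emitted while denoising
    elif "images" in event:
      print(event["images"][0]["url"])  # URL of the PNG served from the /images/ static route
```

The server streams newline-delimited JSON: progress events during the denoising loop, then an `images` entry pointing at the PNG saved under the exo images directory.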

View File

@@ -303,6 +303,10 @@ async def download_repo_files(
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
if model_index_exists:
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)

View File

@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
print(f"No snapshot directory found for {self.current_repo_id}")
return None
if not await aios.path.exists(snapshot_dir/"model_index.json"):
# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
else:
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
# Check download status for all relevant files
status = {}

View File

@@ -350,4 +350,21 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
return model_id, chip_id, memory
except Exception as e:
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
return "Unknown Model", "Unknown Chip", 0
return "Unknown Model", "Unknown Chip", 0
def get_exo_home() -> Path:
if os.name == "nt": # Check if the OS is Windows
docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else:
docs_folder = Path.home() / "Documents"
exo_folder = docs_folder / "Exo"
if not exo_folder.exists():
exo_folder.mkdir()
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
if not images_dir.exists():
images_dir.mkdir()
return images_dir

View File

@@ -39,11 +39,15 @@ class InferenceEngine(ABC):
async def clear_session(self):
self.session.empty()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
tokens = await self.encode(shard, prompt)
x = tokens.reshape(1, -1)
output_data = await self.infer_tensor(request_id, shard, x)
return output_data
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
else:
x = tokens
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",

View File

@@ -0,0 +1,307 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
import time
from typing import Optional, Tuple
import inspect
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from tqdm import tqdm
from .sd_models.vae import ModelArgs as VAEArgs
from .sd_models.vae import Autoencoder
from .sd_models.tokenizer import load_tokenizer
from .sd_models.clip import CLIPTextModel
from .sd_models.clip import ModelArgs as CLIPArgs
from .sd_models.unet import UNetConfig, UNetModel
from dataclasses import dataclass, field
from exo.inference.shard import Shard
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
#Sampler
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def current_timestep(self, step, total_steps, start_time=None):
if step < total_steps:
steps = self.timesteps(total_steps, start_time)
return steps[step]
else:
return mx.array(0),mx.array(0)
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
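
In equation form (an editorial summary of the code above, not part of the commit): writing $\sigma_t$ for the interpolated noise level and $\hat\epsilon$ for the predicted noise, `step()` computes

$$
x_{t_{\mathrm{prev}}} \;=\; \frac{\sqrt{\sigma_t^{2}+1}\,x_t \;+\; \hat\epsilon\,\bigl(\sigma_{t_{\mathrm{prev}}}-\sigma_t\bigr)}{\sqrt{\sigma_{t_{\mathrm{prev}}}^{2}+1}},
$$

i.e. one Euler step of size $\sigma_{t_{\mathrm{prev}}}-\sigma_t$, with the $\sqrt{\sigma^{2}+1}$ factors converting between the rescaled latents the pipeline passes around and the unscaled latents the Euler update acts on.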
@dataclass
class ShardConfig:
model_id:str
start_layer:int
end_layer:int
n_layers:int
@dataclass
class StableDiffusionConfig:
model_type:str
vae:VAEArgs
text_encoder:CLIPArgs
scheduler:DiffusionConfig
unet:UNetConfig
shard:ShardConfig
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(StableDiffusionConfig):
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.model_type = config.model_type
self.config = config
self.model_path = config.vae['path'].split('/vae')[0]
self.shard = config.shard
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
if self.shard_clip.start_layer != -1:
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
else:
self.text_encoder = nn.Identity()
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
self.sampler = SimpleEulerSampler(self.diffusion_config)
if self.shard_unet.start_layer!=-1:
self.config_unet = UNetConfig.from_dict(config.unet['config'])
self.unet = UNetModel(self.config_unet, self.shard_unet)
else:
self.unet = nn.Identity()
self.config_vae=VAEArgs.from_dict(config.vae['config'])
if self.shard_encoder.start_layer != -1:
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
else:
self.encoder = nn.Identity()
if self.shard_decoder.start_layer != -1:
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
else:
self.decoder = nn.Identity()
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
is_finished = False
is_step_finished = False
if t.item()==1000:
if self.shard_clip.start_layer == 0:
conditioning = x
if self.shard_clip.start_layer != -1:
conditioning, mask= self.text_encoder(conditioning,mask)
seed = int(time.time())
mx.random.seed(seed)
if image is None:
if self.shard_encoder.is_last_layer():
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
x_t_prev=x
start_step = self.sampler.max_time
else:
if self.shard_encoder.start_layer != -1:
image= self.encoder.encode(image)
if self.shard_encoder.is_last_layer():
start_step = self.sampler.max_time*strength
total_steps = int(total_steps*strength)
image = mx.broadcast_to(image, (1,) + image.shape[1:])
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
image = None
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
# Perform the denoising loop
if self.shard_unet.start_layer != -1:
with tqdm(total=total_steps,initial=step+1) as pbar:
if step<total_steps:
x = x_t_prev
if self.shard_unet.is_first_layer():
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
else:
x_t_unet = x
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
if self.shard_unet.is_last_layer():
if cfg_weight > 1:
eps_text, eps_neg = x.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
x_t_prev=x
mx.eval(x)
if self.shard_decoder.is_last_layer():
is_step_finished=True
if self.shard_decoder.start_layer != -1:
x=self.decoder.decode(x)
if self.shard_decoder.is_last_layer():
x = mx.clip(x / 2 + 0.5, 0, 1)
B, H, W, C = x.shape
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(1 * H, B // 1 * W, C)
x = (x * 255).astype(mx.uint8)
if t_prev.item() ==0:
is_finished=True
mx.eval(x)
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
def load(self):
if self.shard_encoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.encoder.sanitize(vae_weights)
self.encoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_decoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.decoder.sanitize(vae_weights)
self.decoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_clip.start_layer != -1:
clip_weights = mx.load(self.config_clip.weight_files[0])
clip_weights = self.text_encoder.sanitize(clip_weights)
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
if self.shard_unet.start_layer !=-1:
unet_weights = mx.load(self.config_unet.weight_files[0])
unet_weights = self.unet.sanitize(unet_weights)
self.unet.load_weights(list(unet_weights.items()), strict=True)
def model_shards(shard:ShardConfig):
def create_shard(shard, model_ranges):
start_layer = shard.start_layer
end_layer = shard.end_layer
shards = {}
for model_name, (range_start, range_end) in model_ranges.items():
if start_layer < range_end and end_layer >= range_start:
# Calculate the overlap with the model range
overlap_start = max(start_layer, range_start)
overlap_end = min(end_layer, range_end - 1)
# Adjust the layers relative to the model's range
relative_start = overlap_start - range_start
relative_end = overlap_end - range_start
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
else:
# If no overlap, create a zero-layer shard
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
return shards
# Define the ranges for different models
model_ranges = {
'clip': (0, 12),
'vae_encoder':(12,17),
'unet':(17,26),
    'vae_decoder': (26, 31)
}
# Call the function and get the shards for all models
shards = create_shard(shard, model_ranges)
# Access individual shards
shard_clip = shards['clip']
shard_encoder = shards['vae_encoder']
shard_unet = shards['unet']
shard_decoder = shards['vae_decoder']
return shard_clip, shard_encoder, shard_unet, shard_decoder
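
To make the mapping concrete, a worked example (editorial, not part of the commit); the `Shard` argument order `(model_id, start_layer, end_layer, n_layers)` follows its use earlier in this file.

```python
# Illustrative only: a node assigned global layers 0-15 of the 31-layer pipeline.
node_shard = Shard("stable-diffusion-2-1-base", 0, 15, 31)
shard_clip, shard_encoder, shard_unet, shard_decoder = model_shards(node_shard)

# With the ranges above, create_shard yields:
#   clip        (global  0-11) -> Shard('clip', 0, 11, 12)         # full text encoder
#   vae_encoder (global 12-16) -> Shard('vae_encoder', 0, 3, 5)    # first 4 of 5 encoder layers
#   unet        (global 17-25) -> Shard('unet', -1, -1, 9)         # no overlap: zero-layer shard
#   vae_decoder (global 26-30) -> Shard('vae_decoder', -1, -1, 5)  # no overlap: zero-layer shard
```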

View File

@@ -0,0 +1,117 @@
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__()
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
if self.args.shard.is_first_layer():
h = self.embed_tokens(inputs)
else:
h = inputs
mask = None
if h.shape[1] > 1:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Phi3Model(args)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
out = self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if "self_attn.rope.inv_freq" in key:
continue
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
shard_state_dict[key] = value
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__() # Ensure parent initializations are respected
super().__post_init__()
if isinstance(self.shard, Shard):
return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
self.shard = Shard(**self.shard)
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

View File

@@ -0,0 +1,191 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
import math
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import field, dataclass
from exo.inference.shard import Shard
from exo.inference.mlx.models.base import IdentityBlock
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return ModelArgs(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
weight_files=config.get("weight_files", [])
)
@dataclass
class ModelArgs(CLIPTextModelConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
weight_files: List[str] = field(default_factory=lambda: [])
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
@dataclass
class CLIPOutput:
pooled_output: Optional[mx.array] = None
last_hidden_state: Optional[mx.array] = None
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
super().__init__()
self.shard = shard
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
if self.shard.is_first_layer():
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = []
for i in range(math.ceil(config.num_layers/2)):
if 2*i in self.layers_range:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
else:
self.layers.append(IdentityBlock())
if self.shard.is_last_layer():
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
self.text_projection = nn.Linear(
config.model_dims, config.projection_dim, bias=False
)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def __call__(self, x, mask=None):
# Extract some shapes
if self.shard.is_first_layer():
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
if self.shard.is_last_layer():
x = self.final_layer_norm(x)
return x, mask
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
if "position_ids" in key:
continue
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
if key.startswith("layers."):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if not self.shard.is_first_layer() and "embedding" in key:
continue
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
continue
if not self.shard.is_last_layer() and key.startswith("text_projection"):
continue
sanitized_weights[key] = value
return sanitized_weights

View File

@@ -0,0 +1,131 @@
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
import regex
import json
import glob
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens
def encode(self, prompt):
tokens = [self.tokenize(prompt)]
negative_text = ""
if negative_text is not None:
tokens += [self.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
return tokens
def load_tokenizer(
model_path: str,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)
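
A minimal usage sketch (editorial, not part of the commit); the snapshot path is a placeholder, and the `vocab.json`/`merges.txt` filenames mirror how the stable-diffusion `Model` in this commit calls `load_tokenizer`.

```python
from pathlib import Path

# Placeholder path: load_tokenizer expects <model_path>/tokenizer/vocab.json and .../tokenizer/merges.txt
model_path = Path("/path/to/stable-diffusion-2-1-base")
tokenizer = load_tokenizer(model_path, "vocab.json", "merges.txt")

# encode() returns [prompt_tokens, negative_prompt_tokens], zero-padded to equal length
prompt_tokens, negative_tokens = tokenizer.encode("a photo of an astronaut riding a horse")
```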

View File

@@ -0,0 +1,629 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
from exo.inference.shard import Shard
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls,config):
n_blocks = len(config['block_out_channels'])
return UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
else config["attention_head_dim"]
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
weight_files=config.get("weight_files", [])
)
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig, shard: Shard):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
if shard.is_first_layer():
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
self.add_time_proj = nn.SinusoidalPositionalEncoding(
config.addition_time_embed_dim,
max_freq=1,
min_freq=math.exp(
-math.log(10000)
+ 2 * math.log(10000) / config.addition_time_embed_dim
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.add_embedding = TimestepEmbedding(
config.projection_class_embeddings_input_dim,
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = []
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
if i in self.layers_range:
self.down_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
)
else:
self.down_blocks.append(nn.Identity())
# Make the middle block
if 4 in self.layers_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
total_items = len(block_channels) - 3
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
self.up_blocks = []
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
i = total_items - rev_i
if rev_i+5 in self.layers_range:
self.up_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
)
else:
self.up_blocks.append(nn.Identity())
if shard.is_last_layer():
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
residuals=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
text_emb, time_ids = text_time
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
emb = mx.concatenate([text_emb, emb], axis=-1)
emb = self.add_embedding(emb)
temb = temb + emb
if self.shard.is_first_layer():
# Preprocess the input
x = self.conv_in(x)
residuals = [x]
# Run the downsampling part of the unet
for i in range(len(self.down_blocks)):
if i in self.layers_range:
x, res = self.down_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
else:
x = self.down_blocks[i](x)
if 4 in self.layers_range:
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for i in range(len(self.up_blocks)):
if i+5 in self.layers_range:
x, _ = self.up_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
else:
x = self.up_blocks[i](x)
# Postprocess the output
if self.shard.is_last_layer():
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)
return x, residuals
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
k1 = ""
k2 = ""
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = mx.split(value, 2)
if "conv_shortcut.weight" in key:
value = value.squeeze()
# Transform the weights from 1x1 convs to linear
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if key.startswith("conv_in") :
if 0 not in self.layers_range:
continue
if key.startswith("down_blocks"):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if key.startswith("mid_block"):
if 4 not in self.layers_range:
continue
if key.startswith("up_blocks"):
layer_num = int(key.split(".")[1])
if (layer_num+5) not in self.layers_range:
continue
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
if 8 not in self.layers_range:
continue
if len(k1) > 0:
sanitized_weights[k1] = v1
sanitized_weights[k2] = v2
else:
sanitized_weights[key] = value
return sanitized_weights
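# Illustrative sketch only (not used anywhere in exo): with a four-entry block_out_channels
# config such as SD 2.1 base, sanitize() above keeps weights for nine shard "layers" 0-8:
# 0-3 gate the down blocks (0 also owns conv_in), 4 gates the mid block, and 5-8 gate the
# up blocks (8 also owns conv_norm_out/conv_out). The shard ranges below are hypothetical.
def _owned_unet_parts(start_layer: int, end_layer: int) -> list:
    layers = set(range(start_layer, end_layer + 1))
    parts = ["conv_in"] if 0 in layers else []
    parts += [f"down_blocks.{i}" for i in range(4) if i in layers]
    parts += ["mid_blocks"] if 4 in layers else []
    parts += [f"up_blocks.{i - 5}" for i in range(5, 9) if i in layers]
    parts += ["conv_norm_out", "conv_out"] if 8 in layers else []
    return parts
# _owned_unet_parts(0, 4) -> ['conv_in', 'down_blocks.0', ..., 'down_blocks.3', 'mid_blocks']
# _owned_unet_parts(5, 8) -> ['up_blocks.0', ..., 'up_blocks.3', 'conv_norm_out', 'conv_out']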

View File

@@ -0,0 +1,429 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .unet import ResnetBlock2D, upsample_nearest
from dataclasses import dataclass, field
from exo.inference.shard import Shard
from typing import Tuple
import inspect
from ..base import IdentityBlock
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(AutoencoderConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
latent_channels_out: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
layers_range: List[int] = [],
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
):
super().__init__()
self.layers_range = layers_range
self.shard = shard
if self.shard.is_first_layer():
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in self.layers_range:
self.down_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
)
else:
self.down_blocks.append(IdentityBlock())
current_layer += 1
if self.shard.is_last_layer():
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
def __call__(self, x):
if self.shard.is_first_layer():
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
if self.shard.is_last_layer():
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
shard: Shard,
layer_range: List[int],
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.out_channels = out_channels
self.layers_range = layer_range
if 0 in layer_range:
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
if 0 in layer_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in layer_range:
self.up_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
)
else:
self.up_blocks.append(IdentityBlock())
current_layer += 1
if 4 in layer_range:
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
def __call__(self, x):
if 0 in self.layers_range:
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
if 4 in self.layers_range:
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.model_shard = model_shard
if self.model_shard == "vae_encoder":
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
layers_range=self.layers_range,
shard=shard
)
if self.shard.is_last_layer():
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
if self.model_shard == "vae_decoder":
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
shard,
self.layers_range,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
if self.shard.is_first_layer():
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
if self.shard.is_first_layer():
z = z / self.scaling_factor
z = self.post_quant_proj(z)
return self.decoder(z)
def encode(self, x):
x = self.encoder(x)
if self.shard.is_last_layer():
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
x = mean
return x
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
def sanitize(self, weights):
shard = self.shard
layers = self.layers_range
sanitized_weights = {}
for key, value in weights.items():
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "key" in key:
key = key.replace("key", "key_proj")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if "post_quant_conv" in key :
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
if 'decoder' in key and self.model_shard == "vae_decoder":
if key.startswith("decoder.mid_blocks."):
if 0 in layers:
sanitized_weights[key] = value
if "conv_in" in key and 0 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.up_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_norm_out") and 4 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_out") and 4 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_decoder":
if key.startswith("post_quant_proj") and 0 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_encoder":
if key.startswith("encoder."):
if "conv_in" in key and shard.is_first_layer():
sanitized_weights[key] = value
if key.startswith("encoder.down_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_norm_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if key.startswith("quant_proj") and shard.is_last_layer():
sanitized_weights[key] = value
return sanitized_weights
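# Minimal usage sketch (illustrative only; the config and Shard values below are made up,
# in exo they come from the model card and the partitioning strategy): the VAE is split
# into two independent model shards, "vae_encoder" and "vae_decoder".
if __name__ == "__main__":
    cfg = AutoencoderConfig(block_out_channels=(32, 64), layers_per_block=1, norm_num_groups=8)
    shard = Shard("sd-vae-demo", 0, 4, 5)              # single node owning layers 0-4
    enc = Autoencoder(cfg, shard, model_shard="vae_encoder")
    dec = Autoencoder(cfg, shard, model_shard="vae_decoder")
    x = mx.zeros((1, 16, 16, cfg.in_channels))         # NHWC layout, as elsewhere in this file
    latents = enc.encode(x)                            # quant_proj + scaling on the last layer
    x_hat = dec.decode(latents)                        # post_quant_proj + unscaling on layer 0
    print(x.shape, latents.shape, x_hat.shape)         # (1, 16, 16, 3) (1, 8, 8, 4) (1, 16, 16, 3)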

View File

@@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader
import asyncio
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache
from concurrent.futures import ThreadPoolExecutor
class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
@@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
self.caches = OrderedDict()
self.sampler_params: tuple[float, float, float, int] = (0.0, 0.0, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
async def _eval_mlx(self, *args):
loop = asyncio.get_running_loop()
await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
async def poll_state(self, request_id: str, max_caches=2):
if request_id in self.caches:
@@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
logits = mx.array(x)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return np.asarray(self.sampler(logprobs), dtype=int)
result = self.sampler(logprobs)
await self._eval_mlx(result)
return np.asarray(result, dtype=int)
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = self.tokenizer.encode(prompt)
return np.asarray(tokens)
loop = asyncio.get_running_loop()
return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
return self.tokenizer.decode(tokens)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
async def save_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
@@ -56,13 +66,18 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
self.model.load_weights(path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
await self.ensure_shard(shard)
state = await self.poll_state(request_id)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
output_data = np.array(self.model(x, **state), copy=False)
return output_data
if self.model.model_type != 'StableDiffusionPipeline':
output_data = self.model(x, **state, **inference_state)
else:
output_data, inference_state = self.model(x, **state, **inference_state)
output_data = np.array(output_data, copy=False)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)
@@ -87,26 +102,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
return True
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
loop = asyncio.get_running_loop()
nothin = await self.ensure_train(shard, loss, opt, lr)
await self.ensure_train(shard, loss, opt, lr)
def train_step(inp, tar, lng):
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
gradlayers = grad['model']['layers']
self.session['opt'].update(self.model, grad)
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
return lval, gradlayers
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
#print(f"{score=}")
score, gradients, eval_args = train_step(x, y, l)
await self._eval_mlx(*eval_args)
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])
return score, np.array(layers[0]['input_layernorm'], copy=False)
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
await self._eval_mlx(first_layer)
return score, first_layer
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
@@ -121,3 +135,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
self.caches = OrderedDict()
self.session = {}
async def cleanup(self):
self._mlx_thread.shutdown(wait=True)
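# Illustrative sketch of the pattern introduced above (hypothetical names, not part of exo):
# blocking work is pushed onto a dedicated single-worker ThreadPoolExecutor so the asyncio
# event loop keeps scheduling other coroutines while MLX evaluates or the tokenizer runs.
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

_worker = ThreadPoolExecutor(max_workers=1, thread_name_prefix="blocking")

def blocking_compute(n: int) -> int:
    time.sleep(0.1)  # stands in for mx.eval(...) or tokenizer.encode(...)
    return n * n

async def main():
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(_worker, blocking_compute, 7)
    print(result)    # 49; the event loop stayed free while the worker thread slept

if __name__ == "__main__":
    asyncio.run(main())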

View File

@@ -62,8 +62,16 @@ def _get_classes(config: dict):
def load_config(model_path: Path) -> dict:
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
return config
model_index_path = model_path / "model_index.json"
if model_index_path.exists():
config = load_model_index(model_path, model_index_path)
return config
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
@@ -110,6 +118,24 @@ def load_model_shard(
# Try weight for back-compat
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if config.get("model_index", False):
model.load()
return model
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
@@ -186,6 +200,9 @@ async def load_shard(
processor.eos_token_id = processor.tokenizer.eos_token_id
processor.encode = processor.tokenizer.encode
return model, processor
elif hasattr(model, "tokenizer"):
tokenizer = model.tokenizer
return model, tokenizer
else:
tokenizer = await resolve_tokenizer(model_path)
return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
# loading a combined config for all models in the index
def load_model_index(model_path: Path, model_index_path: Path):
models_config = {}
with open(model_index_path, "r") as f:
model_index = json.load(f)
models_config["model_index"] = True
models_config["model_type"] = model_index["_class_name"]
models_config["models"] = {}
for model in model_index.keys():
model_config_path = glob.glob(str(model_path / model / "*config.json"))
if len(model_config_path) > 0:
with open(model_config_path[0], "r") as f:
model_config = {}
model_config["model_type"] = model
model_config["config"] = json.load(f)
model_config["path"] = model_path / model
if model_config["path"]/"*model.safetensors":
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
model_config["path"] = str(model_path / model)
m = {}
m[model] = model_config
models_config.update(m)
return models_config
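# Hypothetical example (illustration only): build a fake model_index layout in a temp
# directory and show the combined config that load_model_index returns.
if __name__ == "__main__":
    import tempfile
    import pprint
    root = Path(tempfile.mkdtemp())
    (root / "unet").mkdir()
    (root / "model_index.json").write_text(json.dumps({"_class_name": "StableDiffusionPipeline", "unet": ["diffusers", "UNet2DConditionModel"]}))
    (root / "unet" / "config.json").write_text(json.dumps({"block_out_channels": [320, 640, 1280, 1280]}))
    pprint.pprint(load_model_index(root, root / "model_index.json"))
    # -> {'model_index': True, 'model_type': 'StableDiffusionPipeline', 'models': {},
    #     'unet': {'model_type': 'unet', 'config': {...}, 'path': '<tmpdir>/unet'}}
    # A sub-model's 'config' also gains a 'weight_files' list when *model.safetensors
    # files are present next to its config.json.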

View File

@@ -0,0 +1,81 @@
import asyncio
import time
import numpy as np
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.shard import Shard
from exo.models import build_base_shard
from collections import deque
from statistics import mean, median
async def test_non_blocking():
# Setup
shard_downloader = HFShardDownloader()
engine = MLXDynamicShardInferenceEngine(shard_downloader)
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
await engine.ensure_shard(shard)
queue = asyncio.Queue()
measurements = deque(maxlen=1000000)
running = True
async def mlx_worker():
try:
start_time = time.time()
count = 0
while running and (time.time() - start_time) < 5: # Hard time limit
start = time.perf_counter_ns()
await engine.infer_prompt("req1", shard, "test prompt")
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
count += 1
print(f"MLX operation {count} took: {duration:.3f}ms")
except asyncio.CancelledError:
pass
finally:
print(f"\nTotal MLX operations completed: {count}")
print(f"Average rate: {count/5:.1f} ops/second")
async def latency_producer():
try:
start_time = time.perf_counter_ns()
count = 0
while running:
await queue.put(time.perf_counter_ns())
count += 1
await asyncio.sleep(0) # Yield to event loop without delay
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
print(f"\nProducer iterations: {count}")
print(f"Producer rate: {count/duration:.1f} iterations/second")
except asyncio.CancelledError:
pass
async def latency_consumer():
try:
while running:
timestamp = await queue.get()
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
measurements.append(latency)
queue.task_done()
except asyncio.CancelledError:
pass
tasks = [
asyncio.create_task(mlx_worker()),
asyncio.create_task(latency_producer()),
asyncio.create_task(latency_consumer())
]
try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
except asyncio.TimeoutError:
print("\nTest timed out")
finally:
running = False
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
print(f"\nFinal measurement count: {len(measurements)}")
if __name__ == "__main__":
asyncio.run(test_non_blocking())

View File

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
from .losses import length_masked_ce_loss
from collections import OrderedDict
import asyncio
from typing import Optional
Tensor.no_grad = True
# default settings
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
safe_save(state_dict, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
await self.ensure_shard(shard)
def wrap_infer():
x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
self.states[request_id].start += x.shape[1]
return out.realize()
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
return output_data.numpy()
return output_data.numpy(), inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
def step(x, y, l):

View File

@@ -14,7 +14,7 @@ class DummyTokenizer:
self.eos_token_id = 69
self.vocab_size = 1000
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
return "dummy_tokenized_prompt"
def encode(self, text):

View File

@@ -103,6 +103,7 @@ parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailsca
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -182,11 +183,12 @@ api = ChatGPTAPI(
inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
)
# node.on_token.register("update_topology_viz").on_next(
# lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
# )
def preemptively_start_download(request_id: str, opaque_status: str):
try:

View File

@@ -92,14 +92,17 @@ model_cards = {
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
### nemotron
@@ -108,6 +111,11 @@ model_cards = {
# gemma
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# stable diffusion
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
# phi
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
# dummy
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
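# Illustrative lookup (hypothetical usage of the table above): each model card maps a short
# model id to its layer count and a per-engine Hugging Face repo, e.g.
#   model_cards["stable-diffusion-2-1-base"]["layers"]                                  # 31
#   model_cards["stable-diffusion-2-1-base"]["repo"]["MLXDynamicShardInferenceEngine"]  # "stabilityai/stable-diffusion-2-1-base"
# get_repo(model_id, inference_engine_classname) further down performs this same lookup,
# presumably returning None when no repo is registered for that engine.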
@@ -133,18 +141,24 @@ pretty_name = {
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
"qwen-2.5-3b": "Qwen 2.5 3B",
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-7b": "Qwen 2.5 7B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
"qwen-2.5-14b": "Qwen 2.5 14B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-32b": "Qwen 2.5 32B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-72b": "Qwen 2.5 72B",
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
"phi-3.5-mini": "Phi-3.5 Mini",
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
}
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:

View File

@@ -11,7 +11,8 @@ from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import mlx.core as mx
class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -90,7 +91,7 @@ class GRPCPeerHandle(PeerHandle):
traceback.print_exc()
return False
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -100,10 +101,11 @@ class GRPCPeerHandle(PeerHandle):
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
)
await self.stub.SendPrompt(request)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -113,8 +115,14 @@ class GRPCPeerHandle(PeerHandle):
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
)
await self.stub.SendTensor(request)
response = await self.stub.SendTensor(request)
if not response.tensor_data or not response.shape or not response.dtype:
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
@@ -173,10 +181,44 @@ class GRPCPeerHandle(PeerHandle):
topology.add_edge(node_id, conn.to_id, conn.description)
return topology
async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
await self.stub.SendNewToken(request)
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
result = []
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
await self.stub.SendResult(request)
async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state

View File

@@ -8,6 +8,8 @@ from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
import mlx.core as mx
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -58,9 +60,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
prompt = request.prompt
request_id = request.request_id
await self.node.process_prompt(shard, prompt, request_id)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
return node_service_pb2.Empty()
inference_state = self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendTensor(self, request, context):
shard = Shard(
@@ -71,9 +75,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id
await self.node.process_tensor(shard, tensor, request_id)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
return node_service_pb2.Empty()
inference_state = self.deserialize_inference_state(request.inference_state)
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendExample(self, request, context):
shard = Shard(
@@ -127,8 +135,12 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
request_id = request.request_id
token = request.token
is_finished = request.is_finished
if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}")
self.node.on_token.trigger_all(request_id, token, is_finished)
img = request.tensor
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
async def SendOpaqueStatus(self, request, context):
@@ -140,3 +152,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state
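# Round-trip sketch (illustrative, hypothetical keys): an inference_state dict such as
#   {"latents": mx.array(...), "history": [mx.array(...), ...], "step": 7}
# is packed by GRPCPeerHandle.serialize_inference_state into an InferenceState proto as
#   tensor_data["latents"]      -> Tensor{tensor_data: <raw bytes>, shape: [...], dtype: "float32"}
#   tensor_list_data["history"] -> TensorList{tensors: [Tensor, ...]}
#   other_data_json             -> '{"step": 7}'
# and the method above rebuilds the equivalent dict, restoring mx.array values from the
# raw bytes via np.frombuffer.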

View File

@@ -23,12 +23,14 @@ message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message ExampleRequest {
@@ -51,6 +53,16 @@ message Tensor {
string dtype = 3;
}
message TensorList {
repeated Tensor tensors = 1;
}
message InferenceState {
map<string, Tensor> tensor_data = 1;
map<string, TensorList> tensor_list_data = 2;
string other_data_json = 3;
}
message CollectTopologyRequest {
repeated string visited = 1;
int32 max_depth = 2;
@@ -85,8 +97,9 @@ message DeviceCapabilities {
message SendNewTokenRequest {
string request_id = 1;
int32 token = 2;
bool is_finished = 3;
repeated int32 result = 2;
optional Tensor tensor = 3;
bool is_finished = 4;
}
message SendOpaqueStatusRequest {

View File

@@ -24,55 +24,67 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"M\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 
.node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x84\x01\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 
\x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=121
_globals['_PROMPTREQUEST']._serialized_end=228
_globals['_TENSORREQUEST']._serialized_start=231
_globals['_TENSORREQUEST']._serialized_end=360
_globals['_EXAMPLEREQUEST']._serialized_start=363
_globals['_EXAMPLEREQUEST']._serialized_end=585
_globals['_LOSS']._serialized_start=587
_globals['_LOSS']._serialized_end=659
_globals['_TENSOR']._serialized_start=661
_globals['_TENSOR']._serialized_end=720
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=722
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=782
_globals['_TOPOLOGY']._serialized_start=785
_globals['_TOPOLOGY']._serialized_end=1065
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=906
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=984
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=986
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1065
_globals['_PEERCONNECTION']._serialized_start=1067
_globals['_PEERCONNECTION']._serialized_end=1140
_globals['_PEERCONNECTIONS']._serialized_start=1142
_globals['_PEERCONNECTIONS']._serialized_end=1210
_globals['_DEVICEFLOPS']._serialized_start=1212
_globals['_DEVICEFLOPS']._serialized_end=1267
_globals['_DEVICECAPABILITIES']._serialized_start=1269
_globals['_DEVICECAPABILITIES']._serialized_end=1376
_globals['_SENDNEWTOKENREQUEST']._serialized_start=1378
_globals['_SENDNEWTOKENREQUEST']._serialized_end=1455
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1457
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1518
_globals['_HEALTHCHECKREQUEST']._serialized_start=1520
_globals['_HEALTHCHECKREQUEST']._serialized_end=1540
_globals['_HEALTHCHECKRESPONSE']._serialized_start=1542
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1583
_globals['_EMPTY']._serialized_start=1585
_globals['_EMPTY']._serialized_end=1592
_globals['_NODESERVICE']._serialized_start=1595
_globals['_NODESERVICE']._serialized_end=2132
_globals['_PROMPTREQUEST']._serialized_start=122
_globals['_PROMPTREQUEST']._serialized_end=309
_globals['_TENSORREQUEST']._serialized_start=312
_globals['_TENSORREQUEST']._serialized_end=521
_globals['_EXAMPLEREQUEST']._serialized_start=524
_globals['_EXAMPLEREQUEST']._serialized_end=746
_globals['_LOSS']._serialized_start=748
_globals['_LOSS']._serialized_end=820
_globals['_TENSOR']._serialized_start=822
_globals['_TENSOR']._serialized_end=881
_globals['_TENSORLIST']._serialized_start=883
_globals['_TENSORLIST']._serialized_end=934
_globals['_INFERENCESTATE']._serialized_start=937
_globals['_INFERENCESTATE']._serialized_end=1275
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
_globals['_TOPOLOGY']._serialized_start=1340
_globals['_TOPOLOGY']._serialized_end=1620
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
_globals['_PEERCONNECTION']._serialized_start=1622
_globals['_PEERCONNECTION']._serialized_end=1695
_globals['_PEERCONNECTIONS']._serialized_start=1697
_globals['_PEERCONNECTIONS']._serialized_end=1765
_globals['_DEVICEFLOPS']._serialized_start=1767
_globals['_DEVICEFLOPS']._serialized_end=1822
_globals['_DEVICECAPABILITIES']._serialized_start=1824
_globals['_DEVICECAPABILITIES']._serialized_end=1931
_globals['_SENDNEWTOKENREQUEST']._serialized_start=1934
_globals['_SENDNEWTOKENREQUEST']._serialized_end=2066
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2068
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2129
_globals['_HEALTHCHECKREQUEST']._serialized_start=2131
_globals['_HEALTHCHECKREQUEST']._serialized_end=2151
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2153
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2194
_globals['_EMPTY']._serialized_start=2196
_globals['_EMPTY']._serialized_end=2203
_globals['_NODESERVICE']._serialized_start=2206
_globals['_NODESERVICE']._serialized_end=2743
# @@protoc_insertion_point(module_scope)

View File

@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor
from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,43 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
await asyncio.sleep(5.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network = topology.peers
peers_in_network.pop(self.node_id)
self._cached_peers = peers_in_network
self._last_modified_time = current_mtime
return peers_in_network
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers
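A minimal usage sketch of the dynamic-reload behaviour added above. The module paths and the config file name "topology.json" are assumptions for illustration; the constructor, start/stop, and discover_peers signatures are taken from the diff. The config file is re-read only when its modification time changes, so peers added to the file appear on a later poll.
import asyncio
from exo.networking.manual.manual_discovery import ManualDiscovery  # assumed module path
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle  # assumed module path
async def main():
  # "topology.json" is a hypothetical config file with a top-level "peers" mapping
  # (address, port, device_capabilities per node), as in the test fixture further below.
  discovery = ManualDiscovery(
    "topology.json",
    "node1",
    create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
  )
  await discovery.start()
  # Polls every 5 seconds; blocks here until at least one configured peer is healthy.
  peers = await discovery.discover_peers(wait_for_peers=1)
  print([peer.id() for peer in peers])
  await discovery.stop()
asyncio.run(main())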

View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
# Wait for the config to be reloaded again
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)
if __name__ == "__main__":
asyncio.run(unittest.main())

View File

@@ -118,44 +118,50 @@ class Node:
shard,
result: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
self.token_count += 1
if self.token_count == 1:
self.first_token_time = time.perf_counter_ns()
if self.token_count % 20 == 0:
print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
forward = token.reshape(1, -1)
self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
if shard.model_id != 'stable-diffusion-2-1-base':
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
intermediate_result = self.buffered_token_output[request_id][0]
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
if shard.model_id != 'stable-diffusion-2-1-base':
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
return np.array(self.buffered_token_output[request_id][0])
async def process_prompt(
self,
base_shard: Shard,
prompt: str,
request_id: Optional[str] = None,
) -> None:
inference_state: Optional[dict] = {},
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
asyncio.create_task(
@@ -172,7 +178,8 @@ class Node:
}),
)
)
await self._process_prompt(base_shard, prompt, request_id)
start_time = time.perf_counter_ns()
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -192,7 +199,7 @@ class Node:
)
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
@@ -201,12 +208,13 @@ class Node:
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
await self.forward_prompt(shard, prompt, request_id, 0)
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
return None
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
await self.process_inference_result(shard, result, request_id)
else:
self.outstanding_requests[request_id] = "processing"
result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return result
async def enqueue_example(
self,
@@ -350,10 +358,11 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
) -> None:
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
await self._process_tensor(shard, tensor, request_id)
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
@@ -363,15 +372,17 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
) -> None:
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
try:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
print(f"Error processing tensor for shard {shard}: {e}")
@@ -404,19 +415,20 @@ class Node:
prompt: str,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
if target_id == self.id:
await self.process_prompt(next_shard, prompt, request_id)
await self.process_prompt(next_shard, prompt, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
async def forward_tensor(
self,
@@ -424,19 +436,20 @@ class Node:
tensor: np.ndarray,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
if target_id == self.id:
await self.process_tensor(next_shard, tensor, request_id)
await self.process_tensor(next_shard, tensor, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
@@ -604,3 +617,12 @@ class Node:
@property
def current_topology(self) -> Topology:
return self.topology
def handle_stable_diffusion(self, inference_state, result):
if inference_state['is_step_finished']:
inference_state['step']+=1
progress = [inference_state['step'],inference_state['total_steps']]
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state
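For reference, a sketch of the inference_state dict that handle_stable_diffusion consumes; the keys are inferred from the code above, and the real state produced by the stable-diffusion engine may carry additional fields.
inference_state = {
  "is_finished": False,      # True once diffusion sampling has produced the final image
  "is_step_finished": True,  # True when the current denoising step has completed
  "step": 3,                 # current denoising step, incremented here when a step finishes
  "total_steps": 30,         # total denoising steps; progress is reported as [step, total_steps]
}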

View File

@@ -0,0 +1,166 @@
from dataclasses import dataclass
from typing import Dict, Optional, Any
from opentelemetry import trace, context
from opentelemetry.trace import Status, StatusCode, SpanContext
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from contextlib import contextmanager
import time
from threading import Lock
@dataclass
class TraceContext:
request_id: str
sequence_number: int
current_span: Optional[trace.Span] = None
trace_parent: Optional[str] = None
token_group_span: Optional[trace.Span] = None
token_count: int = 0
token_group_size: int = 10 # Default group size
request_span: Optional[trace.Span] = None # Track the main request span
class Tracer:
def __init__(self):
self.tracer = trace.get_tracer("exo")
self.contexts: Dict[str, TraceContext] = {}
self._lock = Lock()
self.propagator = TraceContextTextMapPropagator()
def get_context(self, request_id: str) -> Optional[TraceContext]:
with self._lock:
return self.contexts.get(request_id)
def set_context(self, request_id: str, context: TraceContext):
with self._lock:
self.contexts[request_id] = context
def inject_context(self, span: trace.Span) -> str:
"""Inject current span context into carrier for propagation"""
carrier = {}
ctx = trace.set_span_in_context(span)
self.propagator.inject(carrier, context=ctx)
return carrier.get("traceparent", "")
def extract_context(self, trace_parent: str) -> Optional[context.Context]:
"""Extract span context from carrier"""
if not trace_parent:
return None
carrier = {"traceparent": trace_parent}
return self.propagator.extract(carrier)
def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
"""Create a new context with the given trace parent"""
parent_ctx = self.extract_context(trace_parent)
if parent_ctx:
# Create a new request span that links to the parent context
request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": request_id,
"sequence_number": sequence_number
}
)
return TraceContext(
request_id=request_id,
sequence_number=sequence_number,
request_span=request_span,
current_span=request_span,
trace_parent=trace_parent
)
return TraceContext(request_id=request_id, sequence_number=sequence_number)
def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
"""Handle token generation and manage token group spans"""
context.token_count += 1
# Start a new token group span if needed
if not context.token_group_span and context.request_span:
group_number = (context.token_count - 1) // context.token_group_size + 1
# Create token group span as child of request span
parent_ctx = trace.set_span_in_context(context.request_span)
context.token_group_span = self.tracer.start_span(
f"token_group_{group_number}",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"group.number": group_number,
"group.start_token": context.token_count,
"group.max_tokens": context.token_group_size
}
)
# Add token to current group span
if context.token_group_span:
relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
context.token_group_span.set_attribute("token.count", relative_pos)
# End current group span if we've reached the group size or if generation is finished
if context.token_count % context.token_group_size == 0 or is_finished:
context.token_group_span.set_attribute("token.final_count", relative_pos)
context.token_group_span.end()
context.token_group_span = None
@contextmanager
def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
"""Start a new span with proper parent context"""
attributes = {
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
if extra_attributes:
attributes.update(extra_attributes)
# Use request span as parent if available
parent_ctx = None
if context.request_span:
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.trace_parent:
parent_ctx = self.extract_context(context.trace_parent)
if parent_ctx and not context.request_span:
# Create a new request span that links to the parent context
context.request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
)
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.current_span:
parent_ctx = trace.set_span_in_context(context.current_span)
# Create span with parent context if it exists
if parent_ctx:
span = self.tracer.start_span(
name,
context=parent_ctx,
attributes=attributes
)
else:
span = self.tracer.start_span(
name,
attributes=attributes
)
# Update context with current span
prev_span = context.current_span
context.current_span = span
try:
start_time = time.perf_counter()
yield span
duration = time.perf_counter() - start_time
span.set_attribute("duration_s", duration)
span.set_status(Status(StatusCode.OK))
except Exception as e:
span.set_status(Status(StatusCode.ERROR, str(e)))
raise
finally:
span.end()
context.current_span = prev_span
# Global tracer instance
tracer = Tracer()
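A minimal usage sketch of the Tracer above. The import path and the attribute values are illustrative assumptions, and an OpenTelemetry exporter/provider would normally be configured elsewhere; the API calls (set_context, start_span, handle_token) match the class as defined in this file.
from exo.tracing import tracer, TraceContext  # module path is an assumption
ctx = TraceContext(request_id="req-123", sequence_number=0)
tracer.set_context("req-123", ctx)
with tracer.start_span("process_prompt", ctx, extra_attributes={"shard": "llama-3.2-1b"}):
  for token in (42, 7, 1337):
    # Increments the per-request token count; token-group child spans are only
    # opened once the context carries a request span (e.g. when it was created
    # via create_context_from_parent with a traceparent header).
    tracer.handle_token(ctx, token)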

View File

@@ -197,7 +197,25 @@
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
if (content.includes('![Generated Image]')) {
const imageUrl = content.match(/\((.*?)\)/)[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
}
} catch (e) {
console.log(content);
console.error(e);
@@ -281,7 +299,7 @@
</span>
</div>
<div class="input">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
<i class="fas fa-image"></i>
</button>
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>

View File

@@ -231,53 +231,110 @@ document.addEventListener("alpine:init", () => {
};
}
});
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
// Send a request to the image generation endpoint
console.log(apiMessages[apiMessages.length - 1].content)
console.log(this.cstate.selectedModel)
console.log(this.endpoint)
const response = await fetch(`${this.endpoint}/image/generations`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
"model": 'stable-diffusion-2-1-base',
"prompt": apiMessages[apiMessages.length - 1].content,
"image_url": this.imageUrl
}),
});
}
// start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
if (!response.ok) {
throw new Error("Failed to fetch");
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
const reader = response.body.getReader();
let done = false;
let gottenFirstChunk = false;
while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
const decoder = new TextDecoder();
if (value) {
// Assume non-binary data (text) comes first
const chunk = decoder.decode(value, { stream: true });
const parsed = JSON.parse(chunk);
console.log(parsed)
if (parsed.progress) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
}
else if (parsed.images) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
const imageUrl = parsed.images[0].url;
console.log(imageUrl)
this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
}
}
}
}
else{
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
});
}
console.log(apiMessages)
//start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
}
}
}
}
// Clean the cstate before adding it to histories
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
cleanedCstate.messages = cleanedCstate.messages.map(msg => {

View File

@@ -91,25 +91,70 @@ class TopologyViz:
content = []
requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
max_width = self.console.width - 6 # Full width minus padding and icon
max_lines = 13 # Maximum number of lines for the entire panel content
# Calculate available height for content
panel_height = 15 # Fixed panel height
available_lines = panel_height - 2 # Subtract 2 for panel borders
lines_per_entry = available_lines // len(requests) if requests else 0
for (prompt, output) in reversed(requests):
prompt_icon, output_icon = "💬️", "🤖"
# Process prompt
prompt_lines = prompt.split('\n')
if len(prompt_lines) > max_lines // 2:
prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
# Calculate max lines for prompt and output
max_prompt_lines = lines_per_entry // 3 # Allocate 1/3 for prompt
max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing
# Process prompt
prompt_lines = []
for line in prompt.split('\n'):
words = line.split()
current_line = []
current_length = 0
for word in words:
if current_length + len(word) + 1 <= max_width:
current_line.append(word)
current_length += len(word) + 1
else:
if current_line:
prompt_lines.append(' '.join(current_line))
current_line = [word]
current_length = len(word)
if current_line:
prompt_lines.append(' '.join(current_line))
if len(prompt_lines) > max_prompt_lines:
prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(prompt_lines), style="white")
# Process output - same word-aware wrapping
output_lines = []
for line in output.split('\n'):
words = line.split()
current_line = []
current_length = 0
for word in words:
if current_length + len(word) + 1 <= max_width:
current_line.append(word)
current_length += len(word) + 1
else:
if current_line:
output_lines.append(' '.join(current_line))
current_line = [word]
current_length = len(word)
if current_line:
output_lines.append(' '.join(current_line))
if len(output_lines) > max_output_lines:
output_lines = output_lines[:max_output_lines - 1] + ['...']
# Process output
output_lines = output.split('\n')
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
if len(output_lines) > remaining_lines:
output_lines = output_lines[:remaining_lines - 1] + ['...']
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
output_text.append('\n'.join(output_lines), style="white")
content.append(prompt_text)
content.append(output_text)
@@ -119,8 +164,8 @@ class TopologyViz:
Group(*content),
title="",
border_style="cyan",
height=15, # Increased height to accommodate multiple lines
expand=True # Allow the panel to expand to full width
height=panel_height,
expand=True
)
def _generate_main_layout(self) -> str:

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
if command -v python3.12 &>/dev/null; then
echo "Python 3.12 is installed, proceeding with python3.12..."

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
source ./install.sh
pushd exo/networking/grpc
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
echo "Starting node 1"
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &

View File

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
models = []
for model_id in model_cards: