diff --git a/.gitignore b/.gitignore
index 93227e3c..bc6a38d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,5 @@ cython_debug/
**/*.xcodeproj/*
.aider*
+
+exo/tinychat/images/*.png
diff --git a/README.md b/README.md
index 9881097a..2b2496aa 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
+exo-explore%2Fexo | Trendshift
+
---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
### Wide Model Support
-exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
+exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
### Dynamic Model Partitioning
@@ -100,13 +102,13 @@ source install.sh
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
-1. Upgrade to the latest version of MacOS 15.
+1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation
-### Example Usage on Multiple MacOS Devices
+### Example Usage on Multiple macOS Devices
#### Device 1:
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
-### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
+### Example Usage on Multiple Heterogeneous Devices (macOS + Linux)
-#### Device 1 (MacOS):
+#### Device 1 (macOS):
```sh
exo
@@ -244,7 +246,7 @@ python3 format.py ./exo
## Known Issues
-- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
+- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co).
To resolve this, run the `Install Certificates` command, typically as follows:
```sh
/Applications/Python 3.x/Install Certificates.command
diff --git a/configure_mlx.sh b/configure_mlx.sh
index 77796055..f1cfe6e6 100755
--- a/configure_mlx.sh
+++ b/configure_mlx.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
diff --git a/examples/function_calling.py b/examples/function_calling.py
new file mode 100644
index 00000000..dcab2a8d
--- /dev/null
+++ b/examples/function_calling.py
@@ -0,0 +1,111 @@
+import json
+import re
+import requests
+
+def get_current_weather(location: str, unit: str = "celsius"):
+  """Mock weather data function"""
+  # Hardcoded response for demo purposes
+  return {
+    "location": location,
+    "temperature": 22 if unit == "celsius" else 72,
+    "unit": unit,
+    "forecast": "Sunny with light clouds"
+  }
+
+def try_parse_tool_calls(content: str):
+  """Try parse the tool calls."""
+  tool_calls = []
+  offset = 0
+  for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
+    if i == 0:
+      offset = m.start()
+    try:
+      func = json.loads(m.group(1))
+      tool_calls.append({"type": "function", "function": func})
+      if isinstance(func["arguments"], str):
+        func["arguments"] = json.loads(func["arguments"])
+    except json.JSONDecodeError as e:
+      print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
+      pass
+  if tool_calls:
+    if offset > 0 and content[:offset].strip():
+      c = content[:offset]
+    else:
+      c = ""
+    return {"role": "assistant", "content": c, "tool_calls": tool_calls}
+  return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
+
+def chat_completion(messages):
+  """Send chat completion request to local server"""
+  response = requests.post(
+    "http://localhost:52415/v1/chat/completions",
+    json={
+      "model": "qwen-2.5-1.5b",
+      "messages": messages,
+      "tools": [{
+        "type": "function",
+        "function": {
+          "name": "get_current_weather",
+          "description": "Get the current weather in a given location",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "location": {
+                "type": "string",
+                "description": "The city and state, e.g. San Francisco, CA"
+              },
+              "unit": {
+                "type": "string",
+                "enum": ["celsius", "fahrenheit"]
+              }
+            },
+            "required": ["location"]
+          }
+        }
+      }],
+      "tool_choice": "auto"
+    }
+  )
+  return response.json()
+
+def main():
+  # Initial conversation
+  messages = [{
+    "role": "user",
+    "content": "Hi there, what's the weather in Boston?"
+ }] + + # Get initial response + response = chat_completion(messages) + print(f"First response: {response}") + assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"]) + messages.append(assistant_message) + + # If there are tool calls, execute them and continue conversation + if "tool_calls" in assistant_message: + for tool_call in assistant_message["tool_calls"]: + if tool_call["function"]["name"] == "get_current_weather": + args = tool_call["function"]["arguments"] + weather_data = get_current_weather(**args) + + # Add tool response to messages + messages.append({ + "role": "tool", + "content": json.dumps(weather_data), + "name": tool_call["function"]["name"] + }) + + # Get final response with weather data + response = chat_completion(messages) + print(f"Final response: {response}") + messages.append({ + "role": "assistant", + "content": response["choices"][0]["message"]["content"] + }) + + # Print full conversation + for msg in messages: + print(f"\n{msg['role'].upper()}: {msg['content']}") + +if __name__ == "__main__": + main() diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 5bc9fb96..e05ccde0 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -5,18 +5,24 @@ import json import os from pathlib import Path from transformers import AutoTokenizer -from typing import List, Literal, Union, Dict +from typing import List, Literal, Union, Dict, Optional from aiohttp import web import aiohttp_cors import traceback import signal from exo import DEBUG, VERSION from exo.download.download_progress import RepoProgressEvent -from exo.helpers import PrefixDict, shutdown +from exo.helpers import PrefixDict, shutdown, get_exo_images_dir from exo.inference.tokenizers import resolve_tokenizer from exo.orchestration import Node from exo.models import build_base_shard, model_cards, get_repo, pretty_name from typing import Callable, Optional +from PIL import Image +import numpy as np +import base64 +from io import BytesIO +import mlx.core as mx +import tempfile from exo.download.hf.hf_shard_download import HFShardDownloader import shutil from exo.download.hf.hf_helpers import get_hf_home, get_repo_root @@ -24,23 +30,28 @@ from exo.apputil import create_animation_mp4 from collections import defaultdict class Message: - def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]): + def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None): self.role = role self.content = content + self.tools = tools def to_dict(self): - return {"role": self.role, "content": self.content} + data = {"role": self.role, "content": self.content} + if self.tools: + data["tools"] = self.tools + return data class ChatCompletionRequest: - def __init__(self, model: str, messages: List[Message], temperature: float): + def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None): self.model = model self.messages = messages self.temperature = temperature + self.tools = tools def to_dict(self): - return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature} + return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools} def generate_completion( @@ -120,20 +131,24 @@ def remap_messages(messages: List[Message]) -> List[Message]: return remapped_messages -def 
build_prompt(tokenizer, _messages: List[Message]): +def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None): messages = remap_messages(_messages) - prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True) - for message in messages: - if not isinstance(message.content, list): - continue + chat_template_args = { + "conversation": [m.to_dict() for m in messages], + "tokenize": False, + "add_generation_prompt": True + } + if tools: chat_template_args["tools"] = tools + prompt = tokenizer.apply_chat_template(**chat_template_args) + print(f"!!! Prompt: {prompt}") return prompt def parse_message(data: dict): if "role" not in data or "content" not in data: raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'") - return Message(data["role"], data["content"]) + return Message(data["role"], data["content"], data.get("tools")) def parse_chat_request(data: dict, default_model: str): @@ -141,6 +156,7 @@ def parse_chat_request(data: dict, default_model: str): data.get("model", default_model), [parse_message(msg) for msg in data["messages"]], data.get("temperature", 0.0), + data.get("tools", None), ) @@ -151,7 +167,7 @@ class PromptSession: self.prompt = prompt class ChatGPTAPI: - def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None): + def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None): self.node = node self.inference_engine_classname = inference_engine_classname self.response_timeout = response_timeout @@ -166,6 +182,7 @@ class ChatGPTAPI: # Get the callback system and register our handler self.token_callback = node.on_token.register("chatgpt-api-token-handler") self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished))) + self.system_prompt = system_prompt cors = aiohttp_cors.setup(self.app) cors_options = aiohttp_cors.ResourceOptions( @@ -180,6 +197,7 @@ class ChatGPTAPI: cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) + cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options}) cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options}) cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options}) cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options}) @@ -191,10 +209,12 @@ class ChatGPTAPI: cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options}) cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options}) + if "__compiled__" not in globals(): self.static_dir = Path(__file__).parent.parent/"tinychat" self.app.router.add_get("/", self.handle_root) self.app.router.add_static("/", self.static_dir, 
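+      # The /images/ static route added below serves generated images back to clients;
+      # handle_post_image_generations writes them as {request_id}.png into the user's
+      # Exo images directory (Documents/Exo/Images, created by get_exo_images_dir in exo/helpers.py).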
name="static") + self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images') self.app.middlewares.append(self.timeout_middleware) self.app.middlewares.append(self.log_request) @@ -241,7 +261,7 @@ class ChatGPTAPI: ) await response.prepare(request) - for model_name, pretty in pretty_name.items(): + async def process_model(model_name, pretty): if model_name in model_cards: model_info = model_cards[model_name] @@ -269,6 +289,12 @@ class ChatGPTAPI: await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) + # Process all models in parallel + await asyncio.gather(*[ + process_model(model_name, pretty) + for model_name, pretty in pretty_name.items() + ]) + await response.write(b"data: [DONE]\n\n") return response @@ -281,7 +307,8 @@ class ChatGPTAPI: ) async def handle_get_models(self, request): - return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]) + models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()] + return web.json_response({"object": "list", "data": models_list}) async def handle_post_chat_token_encode(self, request): data = await request.json() @@ -294,7 +321,7 @@ class ChatGPTAPI: shard = build_base_shard(model, self.inference_engine_classname) messages = [parse_message(msg) for msg in data.get("messages", [])] tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname)) - prompt = build_prompt(tokenizer, messages) + prompt = build_prompt(tokenizer, messages, data.get("tools", None)) tokens = tokenizer.encode(prompt) return web.json_response({ "length": len(prompt), @@ -314,13 +341,13 @@ class ChatGPTAPI: async def handle_post_chat_completions(self, request): data = await request.json() - if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}") + if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}") stream = data.get("stream", False) chat_request = parse_chat_request(data, self.default_model) if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model chat_request.model = self.default_model if not chat_request.model or chat_request.model not in model_cards: - if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}") + if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. 
Defaulting to {self.default_model}") chat_request.model = self.default_model shard = build_base_shard(chat_request.model, self.inference_engine_classname) if not shard: @@ -331,34 +358,26 @@ class ChatGPTAPI: ) tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname)) - if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") + if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}") - prompt = build_prompt(tokenizer, chat_request.messages) + # Add system prompt if set + if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages): + chat_request.messages.insert(0, Message("system", self.system_prompt)) + + prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools) request_id = str(uuid.uuid4()) if self.on_chat_completion_request: try: self.on_chat_completion_request(request_id, chat_request, prompt) except Exception as e: if DEBUG >= 2: traceback.print_exc() - # request_id = None - # match = self.prompts.find_longest_prefix(prompt) - # if match and len(prompt) > len(match[1].prompt): - # if DEBUG >= 2: - # print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}") - # request_id = match[1].request_id - # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) - # # remove the matching prefix from the prompt - # prompt = prompt[len(match[1].prompt):] - # else: - # request_id = str(uuid.uuid4()) - # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) - if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}") + if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}") try: await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout) - if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s") + if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. 
timeout={self.response_timeout}s") if stream: response = web.StreamResponse( @@ -374,10 +393,12 @@ class ChatGPTAPI: try: # Stream tokens while waiting for inference to complete while True: + if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}") token, is_finished = await asyncio.wait_for( self.token_queues[request_id].get(), timeout=self.response_timeout ) + if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}") finish_reason = None eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None) @@ -408,10 +429,13 @@ class ChatGPTAPI: return response except asyncio.TimeoutError: + if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}") return web.json_response({"detail": "Response generation timed out"}, status=408) except Exception as e: - if DEBUG >= 2: traceback.print_exc() + if DEBUG >= 2: + print(f"[ChatGPTAPI] Error processing prompt: {e}") + traceback.print_exc() return web.json_response( {"detail": f"Error processing prompt: {str(e)}"}, status=500 @@ -420,6 +444,7 @@ class ChatGPTAPI: finally: # Clean up the queue for this request if request_id in self.token_queues: + if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}") del self.token_queues[request_id] else: tokens = [] @@ -441,6 +466,85 @@ class ChatGPTAPI: if DEBUG >= 2: traceback.print_exc() return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) + + async def handle_post_image_generations(self, request): + data = await request.json() + + if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}") + stream = data.get("stream", False) + model = data.get("model", "") + prompt = data.get("prompt", "") + image_url = data.get("image_url", "") + if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}") + shard = build_base_shard(model, self.inference_engine_classname) + if DEBUG >= 2: print(f"shard: {shard}") + if not shard: + return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400) + + request_id = str(uuid.uuid4()) + callback_id = f"chatgpt-api-wait-response-{request_id}" + callback = self.node.on_token.register(callback_id) + try: + if image_url != "" and image_url != None: + img = self.base64_decode(image_url) + else: + img = None + await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout) + + + response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",}) + await response.prepare(request) + + def get_progress_bar(current_step, total_steps, bar_length=50): + # Calculate the percentage of completion + percent = float(current_step) / total_steps + # Calculate the number of hashes to display + arrow = '-' * int(round(percent * bar_length) - 1) + '>' + spaces = ' ' * (bar_length - len(arrow)) + + # Create the progress bar string + progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})' + return progress_bar + + async def stream_image(_request_id: str, result, is_finished: bool): + if isinstance(result, list): + await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') 
+ b'\n') + + elif isinstance(result, np.ndarray): + im = Image.fromarray(np.array(result)) + images_folder = get_exo_images_dir() + # Save the image to a file + image_filename = f"{_request_id}.png" + image_path = images_folder / image_filename + im.save(image_path) + image_url = request.app.router['static_images'].url_for(filename=image_filename) + base_url = f"{request.scheme}://{request.host}" + # Construct the full URL correctly + full_image_url = base_url + str(image_url) + + await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + if is_finished: + await response.write_eof() + + + stream_task = None + def on_result(_request_id: str, result, is_finished: bool): + nonlocal stream_task + stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished)) + return _request_id == request_id and is_finished + + await callback.wait(on_result, timeout=self.response_timeout*10) + + if stream_task: + # Wait for the stream task to complete before returning + await stream_task + + return response + + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) + async def handle_delete_model(self, request): try: model_name = request.match_info.get('model_name') @@ -553,7 +657,7 @@ class ChatGPTAPI: if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400) shard = build_base_shard(model_name, self.inference_engine_classname) if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400) - asyncio.create_task(self.node.inference_engine.ensure_shard(shard)) + asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname)) return web.json_response({ "status": "success", @@ -585,3 +689,19 @@ class ChatGPTAPI: await runner.setup() site = web.TCPSite(runner, host, port) await site.start() + + def base64_decode(self, base64_string): + #decode and reshape image + if base64_string.startswith('data:image'): + base64_string = base64_string.split(',')[1] + image_data = base64.b64decode(base64_string) + img = Image.open(BytesIO(image_data)) + W, H = (dim - dim % 64 for dim in (img.width, img.height)) + if W != img.width or H != img.height: + if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}") + img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter + img = mx.array(np.array(img)) + img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1 + img = img[None] + return img + diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index d248dd37..119d321c 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -303,6 +303,10 @@ async def download_repo_files( await f.write(json.dumps(file_list)) if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}") + model_index_exists = any(file["path"] == "model_index.json" for file in file_list) + if model_index_exists: + allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"] + filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"])) total_files = len(filtered_file_list) total_bytes = sum(file["size"] for file in filtered_file_list) diff --git 
a/exo/download/hf/hf_shard_download.py b/exo/download/hf/hf_shard_download.py index 7428909f..b1981728 100644 --- a/exo/download/hf/hf_shard_download.py +++ b/exo/download/hf/hf_shard_download.py @@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader): print(f"No snapshot directory found for {self.current_repo_id}") return None + if not await aios.path.exists(snapshot_dir/"model_index.json"): # Get the weight map to know what files we need - weight_map = await get_weight_map(self.current_repo_id, self.revision) - if not weight_map: - if DEBUG >= 2: - print(f"No weight map found for {self.current_repo_id}") - return None + weight_map = await get_weight_map(self.current_repo_id, self.revision) + if not weight_map: + if DEBUG >= 2: + print(f"No weight map found for {self.current_repo_id}") + return None + + # Get all files needed for this shard + patterns = get_allow_patterns(weight_map, self.current_shard) + else: + patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"] - # Get all files needed for this shard - patterns = get_allow_patterns(weight_map, self.current_shard) # Check download status for all relevant files status = {} diff --git a/exo/helpers.py b/exo/helpers.py index 1b2f7bea..da286bbe 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -350,4 +350,21 @@ async def get_mac_system_info() -> Tuple[str, str, int]: return model_id, chip_id, memory except Exception as e: if DEBUG >= 2: print(f"Error getting Mac system info: {e}") - return "Unknown Model", "Unknown Chip", 0 \ No newline at end of file + return "Unknown Model", "Unknown Chip", 0 + +def get_exo_home() -> Path: + if os.name == "nt": # Check if the OS is Windows + docs_folder = Path(os.environ["USERPROFILE"]) / "Documents" + else: + docs_folder = Path.home() / "Documents" + exo_folder = docs_folder / "Exo" + if not exo_folder.exists(): + exo_folder.mkdir() + return exo_folder + +def get_exo_images_dir() -> Path: + exo_home = get_exo_home() + images_dir = exo_home / "Images" + if not images_dir.exists(): + images_dir.mkdir() + return images_dir diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index 85f1e14c..97cd6aa5 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -39,11 +39,15 @@ class InferenceEngine(ABC): async def clear_session(self): self.session.empty() - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray: + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray: tokens = await self.encode(shard, prompt) - x = tokens.reshape(1, -1) - output_data = await self.infer_tensor(request_id, shard, x) - return output_data + if shard.model_id != 'stable-diffusion-2-1-base': + x = tokens.reshape(1, -1) + else: + x = tokens + output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state) + + return output_data, inference_state inference_engine_classes = { "mlx": "MLXDynamicShardInferenceEngine", diff --git a/exo/inference/mlx/models/StableDiffusionPipeline.py b/exo/inference/mlx/models/StableDiffusionPipeline.py new file mode 100644 index 00000000..f8e7c054 --- /dev/null +++ b/exo/inference/mlx/models/StableDiffusionPipeline.py @@ -0,0 +1,307 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py + +import time +from typing import Optional, Tuple +import inspect + +import mlx.core as mx +import mlx.nn as nn +from pathlib import Path 
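+# Sharding note: model_shards() at the bottom of this file splits the pipeline into four
+# sequential stages - CLIP text encoder (layers 0-11), VAE encoder (12-16), UNet (17-25)
+# and VAE decoder (26-30) - so each node only builds and loads the sub-models whose layer
+# range overlaps its assigned shard; the remaining stages are replaced with nn.Identity().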
+ +from tqdm import tqdm + +from .sd_models.vae import ModelArgs as VAEArgs +from .sd_models.vae import Autoencoder +from .sd_models.tokenizer import load_tokenizer +from .sd_models.clip import CLIPTextModel +from .sd_models.clip import ModelArgs as CLIPArgs +from .sd_models.unet import UNetConfig, UNetModel + +from dataclasses import dataclass, field +from exo.inference.shard import Shard + +@dataclass +class DiffusionConfig: + beta_schedule: str = "scaled_linear" + beta_start: float = 0.00085 + beta_end: float = 0.012 + num_train_steps: int = 1000 + + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + +#Sampler +def _linspace(a, b, num): + x = mx.arange(0, num) / (num - 1) + return (b - a) * x + a + + +def _interp(y, x_new): + """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new.""" + x_low = x_new.astype(mx.int32) + x_high = mx.minimum(x_low + 1, len(y) - 1) + + y_low = y[x_low] + y_high = y[x_high] + delta_x = x_new - x_low + y_new = y_low * (1 - delta_x) + delta_x * y_high + + return y_new + +class SimpleEulerSampler: + """A simple Euler integrator that can be used to sample from our diffusion models. + + The method ``step()`` performs one Euler step from x_t to x_t_prev. + """ + + def __init__(self, config: DiffusionConfig): + # Compute the noise schedule + if config.beta_schedule == "linear": + betas = _linspace( + config.beta_start, config.beta_end, config.num_train_steps + ) + elif config.beta_schedule == "scaled_linear": + betas = _linspace( + config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps + ).square() + else: + raise NotImplementedError(f"{config.beta_schedule} is not implemented.") + + alphas = 1 - betas + alphas_cumprod = mx.cumprod(alphas) + + self._sigmas = mx.concatenate( + [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()] + ) + + @property + def max_time(self): + return len(self._sigmas) - 1 + + def sample_prior(self, shape, dtype=mx.float32, key=None): + noise = mx.random.normal(shape, key=key) + return ( + noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt() + ).astype(dtype) + + def add_noise(self, x, t, key=None): + noise = mx.random.normal(x.shape, key=key) + s = self.sigmas(t) + return (x + noise * s) * (s.square() + 1).rsqrt() + + def sigmas(self, t): + return _interp(self._sigmas, t) + + def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32): + start_time = start_time or (len(self._sigmas) - 1) + assert 0 < start_time <= (len(self._sigmas) - 1) + steps = _linspace(start_time, 0, num_steps + 1).astype(dtype) + return list(zip(steps, steps[1:])) + + def current_timestep(self, step, total_steps, start_time=None): + if step < total_steps: + steps = self.timesteps(total_steps, start_time) + return steps[step] + else: + return mx.array(0),mx.array(0) + + def step(self, eps_pred, x_t, t, t_prev): + sigma = self.sigmas(t).astype(eps_pred.dtype) + sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype) + + dt = sigma_prev - sigma + x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt + + x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt() + + return x_t_prev + +@dataclass +class ShardConfig: + model_id:str + start_layer:int + end_layer:int + n_layers:int + +@dataclass +class StableDiffusionConfig: + model_type:str + vae:VAEArgs + text_encoder:CLIPArgs + scheduler:DiffusionConfig + unet:UNetConfig + shard:ShardConfig + + @classmethod + def from_dict(cls, params): + return cls(**{k: v 
for k, v in params.items() if k in inspect.signature(cls).parameters}) + +@dataclass +class ModelArgs(StableDiffusionConfig): + shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, dict): + self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + +class Model(nn.Module): + def __init__(self, config): + super().__init__() + self.model_type = config.model_type + self.config = config + self.model_path = config.vae['path'].split('/vae')[0] + self.shard = config.shard + self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard) + self.config_clip=CLIPArgs.from_dict(config.text_encoder['config']) + if self.shard_clip.start_layer != -1: + self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip) + else: + self.text_encoder = nn.Identity() + self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt") + self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config']) + self.sampler = SimpleEulerSampler(self.diffusion_config) + if self.shard_unet.start_layer!=-1: + self.config_unet = UNetConfig.from_dict(config.unet['config']) + self.unet = UNetModel(self.config_unet, self.shard_unet) + else: + self.unet = nn.Identity() + self.config_vae=VAEArgs.from_dict(config.vae['config']) + if self.shard_encoder.start_layer != -1: + self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") + else: + self.encoder = nn.Identity() + if self.shard_decoder.start_layer != -1: + self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") + else: + self.decoder = nn.Identity() + + def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None): + t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step) + is_finished = False + is_step_finished = False + if t.item()==1000: + if self.shard_clip.start_layer == 0: + conditioning = x + if self.shard_clip.start_layer != -1: + conditioning, mask= self.text_encoder(conditioning,mask) + seed = int(time.time()) + mx.random.seed(seed) + if image is None: + if self.shard_encoder.is_last_layer(): + x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32) + x_t_prev=x + start_step = self.sampler.max_time + else: + if self.shard_encoder.start_layer != -1: + image= self.encoder.encode(image) + if self.shard_encoder.is_last_layer(): + start_step = self.sampler.max_time*strength + total_steps = int(total_steps*strength) + image = mx.broadcast_to(image, (1,) + image.shape[1:]) + x_t_prev=self.sampler.add_noise(image, mx.array(start_step)) + image = None + t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step) + # Perform the denoising loop + if self.shard_unet.start_layer != -1: + with tqdm(total=total_steps,initial=step+1) as pbar: + if step 1 else x + else: + x_t_unet = x + t_unet = mx.broadcast_to(t, [len(x_t_unet)]) + x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual) + if self.shard_unet.is_last_layer(): + if cfg_weight > 1: + eps_text, eps_neg = x.split(2) + eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg) + x = self.sampler.step(eps_pred, x_t_prev, 
t, t_prev) + x_t_prev=x + mx.eval(x) + + if self.shard_decoder.is_last_layer(): + is_step_finished=True + if self.shard_decoder.start_layer != -1: + x=self.decoder.decode(x) + if self.shard_decoder.is_last_layer(): + x = mx.clip(x / 2 + 0.5, 0, 1) + B, H, W, C = x.shape + x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(1 * H, B // 1 * W, C) + x = (x * 255).astype(mx.uint8) + if t_prev.item() ==0: + is_finished=True + mx.eval(x) + + return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image} + + + def load(self): + if self.shard_encoder.start_layer != -1: + vae_weights = mx.load(self.config_vae.weight_files[0]) + vae_weights = self.encoder.sanitize(vae_weights) + self.encoder.load_weights(list(vae_weights.items()), strict=True) + if self.shard_decoder.start_layer != -1: + vae_weights = mx.load(self.config_vae.weight_files[0]) + vae_weights = self.decoder.sanitize(vae_weights) + self.decoder.load_weights(list(vae_weights.items()), strict=True) + if self.shard_clip.start_layer != -1: + clip_weights = mx.load(self.config_clip.weight_files[0]) + clip_weights = self.text_encoder.sanitize(clip_weights) + self.text_encoder.load_weights(list(clip_weights.items()), strict=True) + if self.shard_unet.start_layer !=-1: + unet_weights = mx.load(self.config_unet.weight_files[0]) + unet_weights = self.unet.sanitize(unet_weights) + self.unet.load_weights(list(unet_weights.items()), strict=True) + +def model_shards(shard:ShardConfig): + def create_shard(shard, model_ranges): + start_layer = shard.start_layer + end_layer = shard.end_layer + + shards = {} + + for model_name, (range_start, range_end) in model_ranges.items(): + if start_layer < range_end and end_layer >= range_start: + # Calculate the overlap with the model range + overlap_start = max(start_layer, range_start) + overlap_end = min(end_layer, range_end - 1) + + # Adjust the layers relative to the model's range + relative_start = overlap_start - range_start + relative_end = overlap_end - range_start + shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start) + else: + # If no overlap, create a zero-layer shard + shards[model_name] = Shard(model_name, -1, -1, range_end - range_start) + + return shards + + # Define the ranges for different models + model_ranges = { + 'clip': (0, 12), + 'vae_encoder':(12,17), + 'unet':(17,26), + 'vae_decoder': (26, 31) # Example range for unet + } + + # Call the function and get the shards for all models + shards = create_shard(shard, model_ranges) + + # Access individual shards + shard_clip = shards['clip'] + shard_encoder = shards['vae_encoder'] + shard_unet = shards['unet'] + shard_decoder = shards['vae_decoder'] + + return shard_clip, shard_encoder, shard_unet, shard_decoder + + + diff --git a/exo/inference/mlx/models/phi3.py b/exo/inference/mlx/models/phi3.py new file mode 100644 index 00000000..acd4114a --- /dev/null +++ b/exo/inference/mlx/models/phi3.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass, field + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.phi3 import TransformerBlock, ModelArgs + +from ...shard import Shard +from .base import IdentityBlock + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + 
super().__post_init__() + + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + +class Phi3Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(TransformerBlock(args=args)) + else: + self.layers.append(IdentityBlock()) + + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + if self.args.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + + mask = None + if h.shape[1] > 1: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Phi3Model(args) + if self.args.shard.is_last_layer(): + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if "self_attn.rope.inv_freq" in key: + continue + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')): + shard_state_dict[key] = value + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/exo/inference/mlx/models/qwen2.py b/exo/inference/mlx/models/qwen2.py index 46b5ce78..f2ffc7fb 100644 --- a/exo/inference/mlx/models/qwen2.py +++ b/exo/inference/mlx/models/qwen2.py @@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs from ...shard import Shard from .base import IdentityBlock - @dataclass class ModelArgs(ModelArgs): shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) def __post_init__(self): - super().__post_init__() # Ensure parent initializations are respected + super().__post_init__() if isinstance(self.shard, Shard): return @@ -24,7 +23,6 @@ class ModelArgs(ModelArgs): self.shard = Shard(**self.shard) - class Qwen2Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -32,14 +30,17 @@ class Qwen2Model(nn.Module): self.vocab_size = args.vocab_size self.num_hidden_layers = args.num_hidden_layers 
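+    # Only layers in [shard.start_layer, shard.end_layer] become real TransformerBlocks
+    # below; the rest are IdentityBlock placeholders, so this node only instantiates and
+    # loads the weights for its own partition of the model.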
assert self.vocab_size > 0 + if self.args.shard.is_first_layer(): self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [] for i in range(self.num_hidden_layers): if self.args.shard.start_layer <= i <= self.args.shard.end_layer: self.layers.append(TransformerBlock(args=args)) else: self.layers.append(IdentityBlock()) + if self.args.shard.is_last_layer(): self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) diff --git a/exo/inference/mlx/models/sd_models/clip.py b/exo/inference/mlx/models/sd_models/clip.py new file mode 100644 index 00000000..849460f4 --- /dev/null +++ b/exo/inference/mlx/models/sd_models/clip.py @@ -0,0 +1,191 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py + +import math +from dataclasses import dataclass +from typing import List, Optional + +import mlx.core as mx +import mlx.nn as nn +from dataclasses import field, dataclass +from exo.inference.shard import Shard +from exo.inference.mlx.models.base import IdentityBlock + +_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} + + + +@dataclass +class CLIPTextModelConfig: + num_layers: int = 23 + model_dims: int = 1024 + num_heads: int = 16 + max_length: int = 77 + vocab_size: int = 49408 + projection_dim: Optional[int] = None + hidden_act: str = "quick_gelu" + + @classmethod + def from_dict(cls, config): + return ModelArgs( + num_layers=config["num_hidden_layers"], + model_dims=config["hidden_size"], + num_heads=config["num_attention_heads"], + max_length=config["max_position_embeddings"], + vocab_size=config["vocab_size"], + projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None, + hidden_act=config.get("hidden_act", "quick_gelu"), + weight_files=config.get("weight_files", []) + ) + +@dataclass +class ModelArgs(CLIPTextModelConfig): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + weight_files: List[str] = field(default_factory=lambda: []) + def __post_init__(self): + if isinstance(self.shard, dict): + self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + if not self.shard.is_first_layer(): + self.vision_config = None + + +@dataclass +class CLIPOutput: + pooled_output: Optional[mx.array] = None + last_hidden_state: Optional[mx.array] = None + hidden_states: Optional[List[mx.array]] = None + + +class CLIPEncoderLayer(nn.Module): + """The transformer encoder layer from CLIP.""" + + def __init__(self, model_dims: int, num_heads: int, activation: str): + super().__init__() + + self.layer_norm1 = nn.LayerNorm(model_dims) + self.layer_norm2 = nn.LayerNorm(model_dims) + + self.attention = nn.MultiHeadAttention(model_dims, num_heads) + self.attention.query_proj.bias = mx.zeros(model_dims) + self.attention.key_proj.bias = mx.zeros(model_dims) + self.attention.value_proj.bias = mx.zeros(model_dims) + self.attention.out_proj.bias = mx.zeros(model_dims) + + self.linear1 = nn.Linear(model_dims, 4 * model_dims) + self.linear2 = nn.Linear(4 * model_dims, model_dims) + + self.act = _ACTIVATIONS[activation] + + def __call__(self, x, attn_mask=None): + + y = self.layer_norm1(x) + y = self.attention(y, y, y, attn_mask) + x = y + x + + y = self.layer_norm2(x) + y = self.linear1(y) + y = self.act(y) + y = self.linear2(y) + x = y + x + return x + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from 
CLIP.""" + + def __init__(self, config: CLIPTextModelConfig, shard: Shard): + super().__init__() + + self.shard = shard + self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2) + if self.shard.is_first_layer(): + self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) + self.position_embedding = nn.Embedding(config.max_length, config.model_dims) + self.layers = [] + for i in range(math.ceil(config.num_layers/2)): + if 2*i in self.layers_range: + self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)) + if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers: + self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)) + else: + self.layers.append(IdentityBlock()) + if self.shard.is_last_layer(): + self.final_layer_norm = nn.LayerNorm(config.model_dims) + + if config.projection_dim is not None: + self.text_projection = nn.Linear( + config.model_dims, config.projection_dim, bias=False + ) + + def _get_mask(self, N, dtype): + indices = mx.arange(N) + mask = indices[:, None] < indices[None] + mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9) + return mask + + def __call__(self, x, mask=None): + # Extract some shapes + if self.shard.is_first_layer(): + B, N = x.shape + eos_tokens = x.argmax(-1) + + # Compute the embeddings + x = self.token_embedding(x) + + x = x + self.position_embedding.weight[:N] + # Compute the features from the transformer + mask = self._get_mask(N, x.dtype) + + for l in self.layers: + x = l(x, mask) + # Apply the final layernorm and return + + if self.shard.is_last_layer(): + x = self.final_layer_norm(x) + + + + return x, mask + def sanitize(self, weights): + sanitized_weights = {} + for key, value in weights.items(): + if "position_ids" in key: + continue + if key.startswith("text_model."): + key = key[11:] + if key.startswith("embeddings."): + key = key[11:] + if key.startswith("encoder."): + key = key[8:] + + # Map attention layers + if "self_attn." in key: + key = key.replace("self_attn.", "attention.") + if "q_proj." in key: + key = key.replace("q_proj.", "query_proj.") + if "k_proj." in key: + key = key.replace("k_proj.", "key_proj.") + if "v_proj." 
in key: + key = key.replace("v_proj.", "value_proj.") + + # Map ffn layers + if "mlp.fc1" in key: + key = key.replace("mlp.fc1", "linear1") + if "mlp.fc2" in key: + key = key.replace("mlp.fc2", "linear2") + + if key.startswith("layers."): + layer_num = int(key.split(".")[1]) + if layer_num not in self.layers_range: + continue + if not self.shard.is_first_layer() and "embedding" in key: + continue + if not self.shard.is_last_layer() and key.startswith("final_layer_norm"): + continue + if not self.shard.is_last_layer() and key.startswith("text_projection"): + continue + sanitized_weights[key] = value + return sanitized_weights diff --git a/exo/inference/mlx/models/sd_models/tokenizer.py b/exo/inference/mlx/models/sd_models/tokenizer.py new file mode 100644 index 00000000..4987bb90 --- /dev/null +++ b/exo/inference/mlx/models/sd_models/tokenizer.py @@ -0,0 +1,131 @@ +# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py + +import regex +import json +import glob + + +class Tokenizer: + """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" + + def __init__(self, bpe_ranks, vocab): + self.bpe_ranks = bpe_ranks + self.vocab = vocab + self.pat = regex.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + regex.IGNORECASE, + ) + + self._cache = {self.bos: self.bos, self.eos: self.eos} + + @property + def bos(self): + return "<|startoftext|>" + + @property + def bos_token(self): + return self.vocab[self.bos] + + @property + def eos(self): + return "<|endoftext|>" + + @property + def eos_token(self): + return self.vocab[self.eos] + + def bpe(self, text): + if text in self._cache: + return self._cache[text] + + unigrams = list(text[:-1]) + [text[-1] + ""] + unique_bigrams = set(zip(unigrams, unigrams[1:])) + + if not unique_bigrams: + return unigrams + + # In every iteration try to merge the two most likely bigrams. If none + # was merged we are done. + # + # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py + while unique_bigrams: + bigram = min( + unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) + ) + if bigram not in self.bpe_ranks: + break + + new_unigrams = [] + skip = False + for a, b in zip(unigrams, unigrams[1:]): + if skip: + skip = False + continue + + if (a, b) == bigram: + new_unigrams.append(a + b) + skip = True + + else: + new_unigrams.append(a) + + if not skip: + new_unigrams.append(b) + + unigrams = new_unigrams + unique_bigrams = set(zip(unigrams, unigrams[1:])) + + self._cache[text] = unigrams + + return unigrams + + def tokenize(self, text, prepend_bos=True, append_eos=True): + if isinstance(text, list): + return [self.tokenize(t, prepend_bos, append_eos) for t in text] + + # Lower case cleanup and split according to self.pat. Hugging Face does + # a much more thorough job here but this should suffice for 95% of + # cases. 
+ clean_text = regex.sub(r"\s+", " ", text.lower()) + tokens = regex.findall(self.pat, clean_text) + + # Split the tokens according to the byte-pair merge file + bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] + + # Map to token ids and return + tokens = [self.vocab[t] for t in bpe_tokens] + if prepend_bos: + tokens = [self.bos_token] + tokens + if append_eos: + tokens.append(self.eos_token) + + return tokens + + def encode(self, prompt): + tokens = [self.tokenize(prompt)] + negative_text = "" + if negative_text is not None: + tokens += [self.tokenize(negative_text)] + lengths = [len(t) for t in tokens] + N = max(lengths) + tokens = [t + [0] * (N - len(t)) for t in tokens] + return tokens + +def load_tokenizer( + model_path: str, + vocab_key: str = "tokenizer_vocab", + merges_key: str = "tokenizer_merges", +): + + vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0] + with open(vocab_file, encoding="utf-8") as f: + vocab = json.load(f) + + merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0] + with open(merges_file, encoding="utf-8") as f: + bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] + bpe_merges = [tuple(m.split()) for m in bpe_merges] + bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) + + return Tokenizer(bpe_ranks, vocab) + diff --git a/exo/inference/mlx/models/sd_models/unet.py b/exo/inference/mlx/models/sd_models/unet.py new file mode 100644 index 00000000..3fe44b86 --- /dev/null +++ b/exo/inference/mlx/models/sd_models/unet.py @@ -0,0 +1,629 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py + +import math +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from dataclasses import dataclass, field +from typing import Tuple, Optional, List +from exo.inference.shard import Shard + +@dataclass +class UNetConfig: + in_channels: int = 4 + out_channels: int = 4 + conv_in_kernel: int = 3 + conv_out_kernel: int = 3 + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: Tuple[int] = (2, 2, 2, 2) + mid_block_layers: int = 2 + transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1) + num_attention_heads: Tuple[int] = (5, 10, 20, 20) + cross_attention_dim: Tuple[int] = (1024,) * 4 + norm_num_groups: int = 32 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ) + addition_embed_type: Optional[str] = None + addition_time_embed_dim: Optional[int] = None + projection_class_embeddings_input_dim: Optional[int] = None + weight_files: List[str] = field(default_factory=lambda: []) + + + + @classmethod + def from_dict(cls,config): + n_blocks = len(config['block_out_channels']) + return UNetConfig( + in_channels=config["in_channels"], + out_channels=config["out_channels"], + block_out_channels=config["block_out_channels"], + layers_per_block=[config["layers_per_block"]] * n_blocks, + transformer_layers_per_block=config.get( + "transformer_layers_per_block", (1,) * 4 + ), + num_attention_heads=( + [config["attention_head_dim"]] * n_blocks + if isinstance(config["attention_head_dim"], int) + else config["attention_head_dim"] + ), + cross_attention_dim=[config["cross_attention_dim"]] * n_blocks, + norm_num_groups=config["norm_num_groups"], + down_block_types=config["down_block_types"], + up_block_types=config["up_block_types"][::-1], + 
addition_embed_type=config.get("addition_embed_type", None), + addition_time_embed_dim=config.get("addition_time_embed_dim", None), + projection_class_embeddings_input_dim=config.get( + "projection_class_embeddings_input_dim", None + ), + weight_files=config.get("weight_files", []) + + ) + + +def upsample_nearest(x, scale: int = 2): + B, H, W, C = x.shape + x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C)) + x = x.reshape(B, H * scale, W * scale, C) + + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def __call__(self, x): + x = self.linear_1(x) + x = nn.silu(x) + x = self.linear_2(x) + + return x + + +class TransformerBlock(nn.Module): + def __init__( + self, + model_dims: int, + num_heads: int, + hidden_dims: Optional[int] = None, + memory_dims: Optional[int] = None, + ): + super().__init__() + + self.norm1 = nn.LayerNorm(model_dims) + self.attn1 = nn.MultiHeadAttention(model_dims, num_heads) + self.attn1.out_proj.bias = mx.zeros(model_dims) + + memory_dims = memory_dims or model_dims + self.norm2 = nn.LayerNorm(model_dims) + self.attn2 = nn.MultiHeadAttention( + model_dims, num_heads, key_input_dims=memory_dims + ) + self.attn2.out_proj.bias = mx.zeros(model_dims) + + hidden_dims = hidden_dims or 4 * model_dims + self.norm3 = nn.LayerNorm(model_dims) + self.linear1 = nn.Linear(model_dims, hidden_dims) + self.linear2 = nn.Linear(model_dims, hidden_dims) + self.linear3 = nn.Linear(hidden_dims, model_dims) + + def __call__(self, x, memory, attn_mask, memory_mask): + # Self attention + y = self.norm1(x) + y = self.attn1(y, y, y, attn_mask) + x = x + y + + # Cross attention + y = self.norm2(x) + y = self.attn2(y, memory, memory, memory_mask) + x = x + y + + # FFN + y = self.norm3(x) + y_a = self.linear1(y) + y_b = self.linear2(y) + y = y_a * nn.gelu(y_b) + y = self.linear3(y) + x = x + y + + return x + + +class Transformer2D(nn.Module): + """A transformer model for inputs with 2 spatial dimensions.""" + + def __init__( + self, + in_channels: int, + model_dims: int, + encoder_dims: int, + num_heads: int, + num_layers: int = 1, + norm_num_groups: int = 32, + ): + super().__init__() + + self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True) + self.proj_in = nn.Linear(in_channels, model_dims) + self.transformer_blocks = [ + TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims) + for i in range(num_layers) + ] + self.proj_out = nn.Linear(model_dims, in_channels) + + def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask): + # Save the input to add to the output + input_x = x + dtype = x.dtype + + # Perform the input norm and projection + B, H, W, C = x.shape + x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C) + x = self.proj_in(x) + + # Apply the transformer + for block in self.transformer_blocks: + x = block(x, encoder_x, attn_mask, encoder_attn_mask) + + # Apply the output projection and reshape + x = self.proj_out(x) + x = x.reshape(B, H, W, C) + + return x + input_x + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + groups: int = 32, + temb_channels: Optional[int] = None, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True) + self.conv1 = 
nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if in_channels != out_channels: + self.conv_shortcut = nn.Linear(in_channels, out_channels) + + def __call__(self, x, temb=None): + dtype = x.dtype + + if temb is not None: + temb = self.time_emb_proj(nn.silu(temb)) + y = self.norm1(x.astype(mx.float32)).astype(dtype) + + y = nn.silu(y) + + y = self.conv1(y) + + + if temb is not None: + y = y + temb[:, None, None, :] + y = self.norm2(y.astype(mx.float32)).astype(dtype) + y = nn.silu(y) + y = self.conv2(y) + + x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x)) + return x + + +class UNetBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + prev_out_channels: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + num_attention_heads: int = 8, + cross_attention_dim=1280, + resnet_groups: int = 32, + add_downsample=True, + add_upsample=True, + add_cross_attention=True, + ): + super().__init__() + + # Prepare the in channels list for the resnets + if prev_out_channels is None: + in_channels_list = [in_channels] + [out_channels] * (num_layers - 1) + else: + in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1) + res_channels_list = [out_channels] * (num_layers - 1) + [in_channels] + in_channels_list = [ + a + b for a, b in zip(in_channels_list, res_channels_list) + ] + + # Add resnet blocks that also process the time embedding + self.resnets = [ + ResnetBlock2D( + in_channels=ic, + out_channels=out_channels, + temb_channels=temb_channels, + groups=resnet_groups, + ) + for ic in in_channels_list + ] + + # Add optional cross attention layers + if add_cross_attention: + self.attentions = [ + Transformer2D( + in_channels=out_channels, + model_dims=out_channels, + num_heads=num_attention_heads, + num_layers=transformer_layers_per_block, + encoder_dims=cross_attention_dim, + ) + for i in range(num_layers) + ] + + # Add an optional downsampling layer + if add_downsample: + self.downsample = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + + # or upsampling layer + if add_upsample: + self.upsample = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__( + self, + x, + encoder_x=None, + temb=None, + attn_mask=None, + encoder_attn_mask=None, + residual_hidden_states=None, + ): + output_states = [] + + for i in range(len(self.resnets)): + if residual_hidden_states is not None: + x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1) + + x = self.resnets[i](x, temb) + + if "attentions" in self: + x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask) + + output_states.append(x) + + if "downsample" in self: + x = self.downsample(x) + output_states.append(x) + + if "upsample" in self: + x = self.upsample(upsample_nearest(x)) + output_states.append(x) + + return x, output_states + + +class UNetModel(nn.Module): + """The conditional 2D UNet model that actually performs the denoising.""" + + def __init__(self, config: UNetConfig, shard: Shard): + super().__init__() + self.shard = shard + self.start_layer = shard.start_layer + self.end_layer = shard.end_layer + self.layers_range = 
list(range(self.start_layer, self.end_layer+1)) + if shard.is_first_layer(): + self.conv_in = nn.Conv2d( + config.in_channels, + config.block_out_channels[0], + config.conv_in_kernel, + padding=(config.conv_in_kernel - 1) // 2, + ) + + self.timesteps = nn.SinusoidalPositionalEncoding( + config.block_out_channels[0], + max_freq=1, + min_freq=math.exp( + -math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0] + ), + scale=1.0, + cos_first=True, + full_turns=False, + ) + self.time_embedding = TimestepEmbedding( + config.block_out_channels[0], + config.block_out_channels[0] * 4, + ) + + if config.addition_embed_type == "text_time": + self.add_time_proj = nn.SinusoidalPositionalEncoding( + config.addition_time_embed_dim, + max_freq=1, + min_freq=math.exp( + -math.log(10000) + + 2 * math.log(10000) / config.addition_time_embed_dim + ), + scale=1.0, + cos_first=True, + full_turns=False, + ) + self.add_embedding = TimestepEmbedding( + config.projection_class_embeddings_input_dim, + config.block_out_channels[0] * 4, + ) + + # Make the downsampling blocks + block_channels = [config.block_out_channels[0]] + list( + config.block_out_channels + ) + self.down_blocks = [] + + for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])): + if i in self.layers_range: + self.down_blocks.append( + UNetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=config.block_out_channels[0] * 4, + num_layers=config.layers_per_block[i], + transformer_layers_per_block=config.transformer_layers_per_block[i], + num_attention_heads=config.num_attention_heads[i], + cross_attention_dim=config.cross_attention_dim[i], + resnet_groups=config.norm_num_groups, + add_downsample=(i < len(config.block_out_channels) - 1), + add_upsample=False, + add_cross_attention="CrossAttn" in config.down_block_types[i], + ) + ) + else: + self.down_blocks.append(nn.Identity()) + + + # Make the middle block + if 4 in self.layers_range: + self.mid_blocks = [ + ResnetBlock2D( + in_channels=config.block_out_channels[-1], + out_channels=config.block_out_channels[-1], + temb_channels=config.block_out_channels[0] * 4, + groups=config.norm_num_groups, + ), + Transformer2D( + in_channels=config.block_out_channels[-1], + model_dims=config.block_out_channels[-1], + num_heads=config.num_attention_heads[-1], + num_layers=config.transformer_layers_per_block[-1], + encoder_dims=config.cross_attention_dim[-1], + ), + ResnetBlock2D( + in_channels=config.block_out_channels[-1], + out_channels=config.block_out_channels[-1], + temb_channels=config.block_out_channels[0] * 4, + groups=config.norm_num_groups, + ), + ] + + # Make the upsampling blocks + block_channels = ( + [config.block_out_channels[0]] + + list(config.block_out_channels) + + [config.block_out_channels[-1]] + ) + + total_items = len(block_channels) - 3 + reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:])))) + + self.up_blocks = [] + for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels): + i = total_items - rev_i + if rev_i+5 in self.layers_range: + self.up_blocks.append( + UNetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=config.block_out_channels[0] * 4, + prev_out_channels=prev_out_channels, + num_layers=config.layers_per_block[i] + 1, + transformer_layers_per_block=config.transformer_layers_per_block[i], + num_attention_heads=config.num_attention_heads[i], + cross_attention_dim=config.cross_attention_dim[i], + 
resnet_groups=config.norm_num_groups, + add_downsample=False, + add_upsample=(i > 0), + add_cross_attention="CrossAttn" in config.up_block_types[i], + ) + ) + else: + self.up_blocks.append(nn.Identity()) + + + if shard.is_last_layer(): + self.conv_norm_out = nn.GroupNorm( + config.norm_num_groups, + config.block_out_channels[0], + pytorch_compatible=True, + ) + self.conv_out = nn.Conv2d( + config.block_out_channels[0], + config.out_channels, + config.conv_out_kernel, + padding=(config.conv_out_kernel - 1) // 2, + ) + + def __call__( + self, + x, + timestep, + encoder_x, + attn_mask=None, + encoder_attn_mask=None, + text_time=None, + residuals=None, + ): + # Compute the time embeddings + + temb = self.timesteps(timestep).astype(x.dtype) + temb = self.time_embedding(temb) + + # Add the extra text_time conditioning + if text_time is not None: + text_emb, time_ids = text_time + emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype) + emb = mx.concatenate([text_emb, emb], axis=-1) + emb = self.add_embedding(emb) + temb = temb + emb + + if self.shard.is_first_layer(): + # Preprocess the input + x = self.conv_in(x) + residuals = [x] + # Run the downsampling part of the unet + + for i in range(len(self.down_blocks)): + if i in self.layers_range: + x, res = self.down_blocks[i]( + x, + encoder_x=encoder_x, + temb=temb, + attn_mask=attn_mask, + encoder_attn_mask=encoder_attn_mask, + ) + residuals.extend(res) + else: + x= self.down_blocks[i](x) + + if 4 in self.layers_range: + # Run the middle part of the unet + x = self.mid_blocks[0](x, temb) + x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask) + x = self.mid_blocks[2](x, temb) + + # Run the upsampling part of the unet + for i in range(len(self.up_blocks)): + if i+5 in self.layers_range: + x, _ = self.up_blocks[i]( + x, + encoder_x=encoder_x, + temb=temb, + attn_mask=attn_mask, + encoder_attn_mask=encoder_attn_mask, + residual_hidden_states=residuals, + ) + else: + x= self.up_blocks[i](x) + + # Postprocess the output + if self.shard.is_last_layer(): + dtype = x.dtype + x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype) + x = nn.silu(x) + x = self.conv_out(x) + + return x, residuals + def sanitize(self, weights): + sanitized_weights = {} + for key, value in weights.items(): + k1="" + k2="" + if "downsamplers" in key: + key = key.replace("downsamplers.0.conv", "downsample") + if "upsamplers" in key: + key = key.replace("upsamplers.0.conv", "upsample") + + # Map the mid block + if "mid_block.resnets.0" in key: + key = key.replace("mid_block.resnets.0", "mid_blocks.0") + if "mid_block.attentions.0" in key: + key = key.replace("mid_block.attentions.0", "mid_blocks.1") + if "mid_block.resnets.1" in key: + key = key.replace("mid_block.resnets.1", "mid_blocks.2") + + # Map attention layers + if "to_k" in key: + key = key.replace("to_k", "key_proj") + if "to_out.0" in key: + key = key.replace("to_out.0", "out_proj") + if "to_q" in key: + key = key.replace("to_q", "query_proj") + if "to_v" in key: + key = key.replace("to_v", "value_proj") + + # Map transformer ffn + if "ff.net.2" in key: + key = key.replace("ff.net.2", "linear3") + if "ff.net.0" in key: + k1 = key.replace("ff.net.0.proj", "linear1") + k2 = key.replace("ff.net.0.proj", "linear2") + v1, v2 = mx.split(value, 2) + + + if "conv_shortcut.weight" in key: + value = value.squeeze() + + # Transform the weights from 1x1 convs to linear + if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key): + value = value.squeeze() + + if len(value.shape) == 4: + value = 
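To make the shard gating in this UNet concrete: the `layers_range` checks in `__init__` and `__call__` treat the model as nine shardable slots, indices 0-3 for the down blocks (with `conv_in` riding on slot 0), 4 for the mid block, and 5-8 for the up blocks (`conv_norm_out`/`conv_out` sit on the final slot, as the index-8 check in `sanitize` just below makes explicit). The helper here is a reading aid under that assumption, not an exo API:

```python
# Sketch: which UNet sub-modules a shard covering [start_layer, end_layer]
# would own, following the layers_range checks in __init__, __call__ and sanitize.
def owned_unet_modules(start_layer: int, end_layer: int) -> list[str]:
    layers = set(range(start_layer, end_layer + 1))
    owned = ["conv_in"] if 0 in layers else []
    owned += [f"down_blocks.{i}" for i in range(4) if i in layers]          # slots 0-3
    if 4 in layers:                                                          # slot 4
        owned.append("mid_blocks")
    owned += [f"up_blocks.{i}" for i in range(4) if i + 5 in layers]         # slots 5-8
    if 8 in layers:
        owned += ["conv_norm_out", "conv_out"]
    return owned

print(owned_unet_modules(3, 6))
# ['down_blocks.3', 'mid_blocks', 'up_blocks.0', 'up_blocks.1']
```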
value.transpose(0, 2, 3, 1) + value = value.reshape(-1).reshape(value.shape) + + if key.startswith("conv_in") : + if 0 not in self.layers_range: + continue + + if key.startswith("down_blocks"): + layer_num = int(key.split(".")[1]) + if layer_num not in self.layers_range: + continue + + if key.startswith("mid_block"): + if 4 not in self.layers_range: + continue + + if key.startswith("up_blocks"): + layer_num = int(key.split(".")[1]) + if (layer_num+5) not in self.layers_range: + continue + + if key.startswith("conv_out") or key.startswith("conv_norm_out"): + if 8 not in self.layers_range: + continue + + if len(k1)>0: + sanitized_weights[k1] = v1 + sanitized_weights[k2] = v2 + else: + sanitized_weights[key] = value + + + return sanitized_weights diff --git a/exo/inference/mlx/models/sd_models/vae.py b/exo/inference/mlx/models/sd_models/vae.py new file mode 100644 index 00000000..0f148517 --- /dev/null +++ b/exo/inference/mlx/models/sd_models/vae.py @@ -0,0 +1,429 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py + +import math +from typing import List + +import mlx.core as mx +import mlx.nn as nn + +from .unet import ResnetBlock2D, upsample_nearest +from dataclasses import dataclass, field +from exo.inference.shard import Shard +from typing import Tuple +import inspect +from ..base import IdentityBlock + +@dataclass +class AutoencoderConfig: + in_channels: int = 3 + out_channels: int = 3 + latent_channels_out: int = 8 + latent_channels_in: int = 4 + block_out_channels: Tuple[int] = (128, 256, 512, 512) + layers_per_block: int = 2 + norm_num_groups: int = 32 + scaling_factor: float = 0.18215 + weight_files: List[str] = field(default_factory=lambda: []) + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class ModelArgs(AutoencoderConfig): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, dict): + self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + if not self.shard.is_first_layer(): + self.vision_config = None + + +class Attention(nn.Module): + """A single head unmasked attention for use with the VAE.""" + + def __init__(self, dims: int, norm_groups: int = 32): + super().__init__() + + self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True) + self.query_proj = nn.Linear(dims, dims) + self.key_proj = nn.Linear(dims, dims) + self.value_proj = nn.Linear(dims, dims) + self.out_proj = nn.Linear(dims, dims) + + def __call__(self, x): + B, H, W, C = x.shape + + y = self.group_norm(x) + + queries = self.query_proj(y).reshape(B, H * W, C) + keys = self.key_proj(y).reshape(B, H * W, C) + values = self.value_proj(y).reshape(B, H * W, C) + + scale = 1 / math.sqrt(queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 2, 1) + attn = mx.softmax(scores, axis=-1) + y = (attn @ values).reshape(B, H, W, C) + + y = self.out_proj(y) + x = x + y + + return x + + +class EncoderDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + resnet_groups: int = 32, + add_downsample=True, + add_upsample=True, + ): + super().__init__() + + # Add the resnet blocks + self.resnets = [ + ResnetBlock2D( + in_channels=in_channels if i == 0 else out_channels, + 
out_channels=out_channels, + groups=resnet_groups, + ) + for i in range(num_layers) + ] + + # Add an optional downsampling layer + if add_downsample: + self.downsample = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=2, padding=0 + ) + + # or upsampling layer + if add_upsample: + self.upsample = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x): + for resnet in self.resnets: + x = resnet(x) + if "downsample" in self: + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.downsample(x) + + if "upsample" in self: + x = self.upsample(upsample_nearest(x)) + return x + + +class Encoder(nn.Module): + """Implements the encoder side of the Autoencoder.""" + + def __init__( + self, + in_channels: int, + latent_channels_out: int, + block_out_channels: List[int] = [64], + layers_per_block: int = 2, + resnet_groups: int = 32, + layers_range: List[int] = [], + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + ): + super().__init__() + self.layers_range = layers_range + self.shard = shard + if self.shard.is_first_layer(): + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1 + ) + + channels = [block_out_channels[0]] + list(block_out_channels) + self.down_blocks = [] + current_layer = 1 + for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])): + if current_layer in self.layers_range: + self.down_blocks.append( + EncoderDecoderBlock2D( + in_channels, + out_channels, + num_layers=layers_per_block, + resnet_groups=resnet_groups, + add_downsample=i < len(block_out_channels) - 1, + add_upsample=False, + ) + ) + else: + self.down_blocks.append(IdentityBlock()) + current_layer += 1 + + if self.shard.is_last_layer(): + self.mid_blocks = [ + ResnetBlock2D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + groups=resnet_groups, + ), + Attention(block_out_channels[-1], resnet_groups), + ResnetBlock2D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + groups=resnet_groups, + ), + ] + + self.conv_norm_out = nn.GroupNorm( + resnet_groups, block_out_channels[-1], pytorch_compatible=True + ) + self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1) + + def __call__(self, x): + if self.shard.is_first_layer(): + x = self.conv_in(x) + + for l in self.down_blocks: + x = l(x) + + if self.shard.is_last_layer(): + x = self.mid_blocks[0](x) + x = self.mid_blocks[1](x) + x = self.mid_blocks[2](x) + + x = self.conv_norm_out(x) + x = nn.silu(x) + x = self.conv_out(x) + + return x + + +class Decoder(nn.Module): + """Implements the decoder side of the Autoencoder.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + shard: Shard, + layer_range: List[int], + block_out_channels: List[int] = [64], + layers_per_block: int = 2, + resnet_groups: int = 32, + ): + super().__init__() + self.out_channels = out_channels + self.layers_range = layer_range + if 0 in layer_range: + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 + ) + + if 0 in layer_range: + self.mid_blocks = [ + ResnetBlock2D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + groups=resnet_groups, + ), + Attention(block_out_channels[-1], resnet_groups), + ResnetBlock2D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + groups=resnet_groups, + ), + ] + + channels = list(reversed(block_out_channels)) + 
channels = [channels[0]] + channels + + self.up_blocks = [] + current_layer = 1 + + for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])): + if current_layer in layer_range: + self.up_blocks.append( + EncoderDecoderBlock2D( + in_channels, + out_channels, + num_layers=layers_per_block, + resnet_groups=resnet_groups, + add_downsample=False, + add_upsample=i < len(block_out_channels) - 1, + ) + ) + else: + self.up_blocks.append(IdentityBlock()) + current_layer += 1 + if 4 in layer_range: + self.conv_norm_out = nn.GroupNorm( + resnet_groups, block_out_channels[0], pytorch_compatible=True + ) + self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1) + + + def __call__(self, x): + if 0 in self.layers_range: + x = self.conv_in(x) + x = self.mid_blocks[0](x) + x = self.mid_blocks[1](x) + x = self.mid_blocks[2](x) + + for l in self.up_blocks: + x = l(x) + if 4 in self.layers_range: + x = self.conv_norm_out(x) + x = nn.silu(x) + x = self.conv_out(x) + return x + + +class Autoencoder(nn.Module): + """The autoencoder that allows us to perform diffusion in the latent space.""" + + def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str): + super().__init__() + self.shard = shard + self.start_layer = shard.start_layer + self.end_layer = shard.end_layer + self.layers_range = list(range(self.start_layer, self.end_layer+1)) + self.latent_channels = config.latent_channels_in + self.scaling_factor = config.scaling_factor + self.model_shard = model_shard + if self.model_shard == "vae_encoder": + self.encoder = Encoder( + config.in_channels, + config.latent_channels_out, + config.block_out_channels, + config.layers_per_block, + resnet_groups=config.norm_num_groups, + layers_range=self.layers_range, + shard=shard + ) + if self.shard.is_last_layer(): + self.quant_proj = nn.Linear( + config.latent_channels_out, config.latent_channels_out + ) + if self.model_shard == "vae_decoder": + self.decoder = Decoder( + config.latent_channels_in, + config.out_channels, + shard, + self.layers_range, + config.block_out_channels, + config.layers_per_block + 1, + resnet_groups=config.norm_num_groups, + ) + if self.shard.is_first_layer(): + self.post_quant_proj = nn.Linear( + config.latent_channels_in, config.latent_channels_in + ) + + def decode(self, z): + if self.shard.is_first_layer(): + z = z / self.scaling_factor + z=self.post_quant_proj(z) + return self.decoder(z) + + def encode(self, x): + x = self.encoder(x) + if self.shard.is_last_layer(): + x = self.quant_proj(x) + mean, logvar = x.split(2, axis=-1) + mean = mean * self.scaling_factor + logvar = logvar + 2 * math.log(self.scaling_factor) + x = mean + return x + + def __call__(self, x, key=None): + mean, logvar = self.encode(x) + z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean + x_hat = self.decode(z) + + return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar) + + def sanitize(self, weights): + shard = self.shard + layers = self.layers_range + sanitized_weights = {} + for key, value in weights.items(): + + if "downsamplers" in key: + key = key.replace("downsamplers.0.conv", "downsample") + if "upsamplers" in key: + key = key.replace("upsamplers.0.conv", "upsample") + + # Map attention layers + if "key" in key: + key = key.replace("key", "key_proj") + if "proj_attn" in key: + key = key.replace("proj_attn", "out_proj") + if "query" in key: + key = key.replace("query", "query_proj") + if "value" in key: + key = key.replace("value", "value_proj") + + # Map the mid block + if 
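One step in `encode` above deserves a comment: after splitting the encoder output into mean and logvar, the latent is rescaled by `scaling_factor` (0.18215 by default), so the mean is multiplied by s while the log-variance gains 2*log(s), since scaling a Gaussian by s multiplies its variance by s squared; `decode` divides the scaling back out before running the decoder. A small numerical check of that identity in plain numpy (not exo code):

```python
import math
import numpy as np

s, mean, logvar = 0.18215, 1.5, -0.7                # s matches the default scaling_factor
sigma = math.exp(0.5 * logvar)                      # std implied by logvar
scaled = s * np.random.default_rng(0).normal(mean, sigma, 200_000)

# Scaling by s shifts the log-variance by exactly 2*log(s).
assert abs(np.log(scaled.var()) - (logvar + 2 * math.log(s))) < 0.02
print(round(scaled.mean(), 3), round(s * mean, 3))  # means agree up to sampling noise
```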
"mid_block.resnets.0" in key: + key = key.replace("mid_block.resnets.0", "mid_blocks.0") + if "mid_block.attentions.0" in key: + key = key.replace("mid_block.attentions.0", "mid_blocks.1") + if "mid_block.resnets.1" in key: + key = key.replace("mid_block.resnets.1", "mid_blocks.2") + + # Map the quant/post_quant layers + if "quant_conv" in key: + key = key.replace("quant_conv", "quant_proj") + value = value.squeeze() + + # Map the conv_shortcut to linear + if "conv_shortcut.weight" in key: + value = value.squeeze() + + if len(value.shape) == 4: + value = value.transpose(0, 2, 3, 1) + value = value.reshape(-1).reshape(value.shape) + + + if "post_quant_conv" in key : + key = key.replace("quant_conv", "quant_proj") + value = value.squeeze() + + if 'decoder' in key and self.model_shard == "vae_decoder": + if key.startswith("decoder.mid_blocks."): + if 0 in layers: + sanitized_weights[key] = value + if "conv_in" in key and 0 in layers: + sanitized_weights[key] = value + if key.startswith("decoder.up_blocks."): + layer_num = int(key.split(".")[2])+1 + if layer_num in layers: + sanitized_weights[key] = value + if key.startswith("decoder.conv_norm_out") and 4 in layers: + sanitized_weights[key] = value + if key.startswith("decoder.conv_out") and 4 in layers: + sanitized_weights[key] = value + if self.model_shard == "vae_decoder": + if key.startswith("post_quant_proj") and 0 in layers: + sanitized_weights[key] = value + if self.model_shard == "vae_encoder": + if key.startswith("encoder."): + if "conv_in" in key and shard.is_first_layer(): + sanitized_weights[key] = value + if key.startswith("encoder.down_blocks."): + layer_num = int(key.split(".")[2])+1 + if layer_num in layers: + sanitized_weights[key] = value + if key.startswith("encoder.mid_blocks.") and shard.is_last_layer(): + sanitized_weights[key] = value + if "conv_norm_out" in key and shard.is_last_layer(): + sanitized_weights[key] = value + if "conv_out" in key and shard.is_last_layer(): + sanitized_weights[key] = value + if key.startswith("quant_proj") and shard.is_last_layer(): + sanitized_weights[key] = value + return sanitized_weights + diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py index bbe4d435..51bde44a 100644 --- a/exo/inference/mlx/sharded_inference_engine.py +++ b/exo/inference/mlx/sharded_inference_engine.py @@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader import asyncio from collections import OrderedDict from mlx_lm.models.cache import make_prompt_cache +from concurrent.futures import ThreadPoolExecutor class MLXDynamicShardInferenceEngine(InferenceEngine): def __init__(self, shard_downloader: ShardDownloader): @@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): self.caches = OrderedDict() self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1) self.sampler = make_sampler(*self.sampler_params) + self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx") + self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer") + + async def _eval_mlx(self, *args): + loop = asyncio.get_running_loop() + await loop.run_in_executor(self._mlx_thread, mx.eval, *args) async def poll_state(self, request_id: str, max_caches=2): if request_id in self.caches: @@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): logits = mx.array(x) logits = logits[:, -1, :] logprobs = logits - mx.logsumexp(logits, keepdims=True) - return np.asarray(self.sampler(logprobs), dtype=int) 
+ result = self.sampler(logprobs) + await self._eval_mlx(result) + return np.asarray(result, dtype=int) async def encode(self, shard: Shard, prompt: str) -> np.ndarray: await self.ensure_shard(shard) - tokens = self.tokenizer.encode(prompt) - return np.asarray(tokens) + loop = asyncio.get_running_loop() + return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt)) async def decode(self, shard: Shard, tokens) -> str: await self.ensure_shard(shard) - return self.tokenizer.decode(tokens) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens) async def save_checkpoint(self, shard: Shard, path: str): await self.ensure_shard(shard) @@ -56,13 +66,18 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): async def load_checkpoint(self, shard: Shard, path: str): await self.ensure_shard(shard) self.model.load_weights(path) - - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray: + + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray: await self.ensure_shard(shard) - state = await self.poll_state(request_id) + loop = asyncio.get_running_loop() + state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {} x = mx.array(input_data) - output_data = np.array(self.model(x, **state), copy=False) - return output_data + if self.model.model_type != 'StableDiffusionPipeline': + output_data = self.model(x, **state, **inference_state) + else: + output_data, inference_state = self.model(x, **state, **inference_state) + output_data = np.array(output_data, copy=False) + return output_data, inference_state async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"): await self.ensure_shard(shard) @@ -87,26 +102,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): return True async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5): - loop = asyncio.get_running_loop() - nothin = await self.ensure_train(shard, loss, opt, lr) + await self.ensure_train(shard, loss, opt, lr) + def train_step(inp, tar, lng): lval, grad = self.session['LVaG'](self.model, inp, tar, lng) gradlayers = grad['model']['layers'] self.session['opt'].update(self.model, grad) - mx.eval(self.model.parameters(), self.session['opt'].state, lval) - return lval, gradlayers + return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval) x = mx.array(inputs) y = mx.array(targets) l = mx.array(lengths) - score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l) - #print(f"{score=}") + score, gradients, eval_args = train_step(x, y, l) + await self._eval_mlx(*eval_args) layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l] - #print(layers[0]) - - return score, np.array(layers[0]['input_layernorm'], copy=False) + first_layer = np.array(layers[0]['input_layernorm'], copy=False) + await self._eval_mlx(first_layer) + return score, first_layer async def ensure_shard(self, shard: Shard): if self.shard == shard: @@ -121,3 +135,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): self.caches = OrderedDict() self.session = {} + async def cleanup(self): + self._mlx_thread.shutdown(wait=True) + diff --git a/exo/inference/mlx/sharded_utils.py 
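The `_eval_mlx` helper introduced in this file is the heart of the non-blocking change: lazy MLX computations are forced on a dedicated single-worker executor so `mx.eval` never stalls the asyncio event loop (the `test_non_blocking.py` file later in this diff exercises that behaviour). A self-contained sketch of the same pattern, with a sleep standing in for `mx.eval`:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

_worker = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

def slow_eval(x):
    time.sleep(0.2)  # stand-in for a blocking mx.eval call
    return x

async def eval_off_loop(x):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_worker, slow_eval, x)

async def main():
    ticks = 0
    async def heartbeat():
        nonlocal ticks
        while True:
            ticks += 1
            await asyncio.sleep(0.01)
    hb = asyncio.create_task(heartbeat())
    await eval_off_loop(42)
    hb.cancel()
    print(f"event loop kept ticking: {ticks} beats during the blocking call")

asyncio.run(main())
```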
b/exo/inference/mlx/sharded_utils.py index fca15a1f..34f29604 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -62,8 +62,16 @@ def _get_classes(config: dict): def load_config(model_path: Path) -> dict: try: - with open(model_path/"config.json", "r") as f: - config = json.load(f) + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path, "r") as f: + config = json.load(f) + return config + + model_index_path = model_path / "model_index.json" + if model_index_path.exists(): + config = load_model_index(model_path, model_index_path) + return config except FileNotFoundError: logging.error(f"Config file not found in {model_path}") raise @@ -110,6 +118,24 @@ def load_model_shard( # Try weight for back-compat weight_files = glob.glob(str(model_path/"weight*.safetensors")) + model_class, model_args_class = _get_classes(config=config) + + class ShardedModel(model_class): + def __init__(self, args): + super().__init__(args) + self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers) + + def __call__(self, x, *args, **kwargs): + y = super().__call__(x, *args, **kwargs) + return y + + model_args = model_args_class.from_dict(config) + model = ShardedModel(model_args) + + if config.get("model_index", False): + model.load() + return model + if not weight_files: logging.error(f"No safetensors found in {model_path}") raise FileNotFoundError(f"No safetensors found in {model_path}") @@ -129,19 +155,7 @@ def load_model_shard( weights.update(mx.load(wf)) - model_class, model_args_class = _get_classes(config=config) - - class ShardedModel(model_class): - def __init__(self, args): - super().__init__(args) - self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers) - - def __call__(self, x, *args, **kwargs): - y = super().__call__(x, *args, **kwargs) - return y - - model_args = model_args_class.from_dict(config) - model = ShardedModel(model_args) + if hasattr(model, "sanitize"): weights = model.sanitize(weights) @@ -186,6 +200,9 @@ async def load_shard( processor.eos_token_id = processor.tokenizer.eos_token_id processor.encode = processor.tokenizer.encode return model, processor + elif hasattr(model, "tokenizer"): + tokenizer = model.tokenizer + return model, tokenizer else: tokenizer = await resolve_tokenizer(model_path) return model, tokenizer @@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str): return img else: raise ValueError("Invalid image_str format. 
Must be a URL or a base64 encoded image.") + +# loading a combined config for all models in the index +def load_model_index(model_path: Path, model_index_path: Path): + models_config = {} + with open(model_index_path, "r") as f: + model_index = json.load(f) + models_config["model_index"] = True + models_config["model_type"] = model_index["_class_name"] + models_config["models"] = {} + for model in model_index.keys(): + model_config_path = glob.glob(str(model_path / model / "*config.json")) + if len(model_config_path)>0: + with open(model_config_path[0], "r") as f: + model_config = { } + model_config["model_type"] = model + model_config["config"] = json.load(f) + model_config["path"] = model_path / model + if model_config["path"]/"*model.safetensors": + model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))}) + model_config["path"] = str(model_path / model) + m = {} + m[model] = model_config + models_config.update(m) + return models_config diff --git a/exo/inference/mlx/test_non_blocking.py b/exo/inference/mlx/test_non_blocking.py new file mode 100644 index 00000000..64eedfdd --- /dev/null +++ b/exo/inference/mlx/test_non_blocking.py @@ -0,0 +1,81 @@ +import asyncio +import time +import numpy as np +from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard +from exo.models import build_base_shard +from collections import deque +from statistics import mean, median + +async def test_non_blocking(): + # Setup + shard_downloader = HFShardDownloader() + engine = MLXDynamicShardInferenceEngine(shard_downloader) + _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine") + shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers) + await engine.ensure_shard(shard) + + queue = asyncio.Queue() + measurements = deque(maxlen=1000000) + running = True + + async def mlx_worker(): + try: + start_time = time.time() + count = 0 + while running and (time.time() - start_time) < 5: # Hard time limit + start = time.perf_counter_ns() + await engine.infer_prompt("req1", shard, "test prompt") + duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms + count += 1 + print(f"MLX operation {count} took: {duration:.3f}ms") + except asyncio.CancelledError: + pass + finally: + print(f"\nTotal MLX operations completed: {count}") + print(f"Average rate: {count/5:.1f} ops/second") + + async def latency_producer(): + try: + start_time = time.perf_counter_ns() + count = 0 + while running: + await queue.put(time.perf_counter_ns()) + count += 1 + await asyncio.sleep(0) # Yield to event loop without delay + duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds + print(f"\nProducer iterations: {count}") + print(f"Producer rate: {count/duration:.1f} iterations/second") + except asyncio.CancelledError: + pass + + async def latency_consumer(): + try: + while running: + timestamp = await queue.get() + latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms + measurements.append(latency) + queue.task_done() + except asyncio.CancelledError: + pass + + tasks = [ + asyncio.create_task(mlx_worker()), + asyncio.create_task(latency_producer()), + asyncio.create_task(latency_consumer()) + ] + + try: + await asyncio.wait_for(asyncio.gather(*tasks), timeout=6) + except asyncio.TimeoutError: + print("\nTest timed out") + finally: + running = False + for 
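Stepping back briefly: the `load_model_index` helper added to sharded_utils.py above merges a diffusers-style `model_index.json` into a single config dict that `load_config` can return. Roughly the shape it produces, with placeholder paths and values (illustrative only, not captured output):

```python
# Illustrative shape only; keys mirror load_model_index, values are placeholders.
models_config = {
    "model_index": True,
    "model_type": "StableDiffusionPipeline",  # copied from "_class_name"
    "models": {},
    "unet": {
        "model_type": "unet",
        "config": {
            # contents of unet/config.json, plus any discovered weight files
            "weight_files": ["<model_path>/unet/<...>model.safetensors"],
        },
        "path": "<model_path>/unet",
    },
    # ...one top-level entry per sub-folder named in model_index.json (vae, text_encoder, ...)
}
```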
task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + print(f"\nFinal measurement count: {len(measurements)}") + +if __name__ == "__main__": + asyncio.run(test_non_blocking()) diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 86abd76b..214cfd3d 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state from .losses import length_masked_ce_loss from collections import OrderedDict import asyncio - +from typing import Optional Tensor.no_grad = True # default settings TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85)) @@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model) safe_save(state_dict, path) - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray: await self.ensure_shard(shard) def wrap_infer(): x = Tensor(input_data) @@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): self.states[request_id].start += x.shape[1] return out.realize() output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer) - return output_data.numpy() + return output_data.numpy(), inference_state async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss): def step(x, y, l): diff --git a/exo/inference/tokenizers.py b/exo/inference/tokenizers.py index 6b1439fc..4dccaf66 100644 --- a/exo/inference/tokenizers.py +++ b/exo/inference/tokenizers.py @@ -14,7 +14,7 @@ class DummyTokenizer: self.eos_token_id = 69 self.vocab_size = 1000 - def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True): + def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): return "dummy_tokenized_prompt" def encode(self, text): diff --git a/exo/main.py b/exo/main.py index 677fb294..e07a03b3 100644 --- a/exo/main.py +++ b/exo/main.py @@ -103,6 +103,7 @@ parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailsca parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name") parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)") parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)") +parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API") args = parser.parse_args() print(f"Selected inference engine: {args.inference_engine}") @@ -182,11 +183,12 @@ api = ChatGPTAPI( inference_engine.__class__.__name__, response_timeout=args.chatgpt_api_response_timeout, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None, - default_model=args.default_model + default_model=args.default_model, + system_prompt=args.system_prompt +) +node.on_token.register("update_topology_viz").on_next( + lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and 
inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None ) -# node.on_token.register("update_topology_viz").on_next( -# lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None -# ) def preemptively_start_download(request_id: str, opaque_status: str): try: diff --git a/exo/models.py b/exo/models.py index 0f984d48..8ffc1a5e 100644 --- a/exo/models.py +++ b/exo/models.py @@ -92,14 +92,17 @@ model_cards = { "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, }, ### qwen "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, }, + "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, }, "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, }, + "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, }, "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, }, - "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, }, - "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, }, - "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, }, "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, }, + "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, }, "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, }, "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, }, + "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, }, + "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, }, + "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, }, "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, }, "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, }, ### nemotron @@ -108,6 +111,11 @@ model_cards = { # gemma "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, }, "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, }, + # stable diffusion + "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } }, + # phi + "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, }, + "phi-4": { "layers": 40, "repo": { 
"MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, }, # dummy "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, }, } @@ -133,18 +141,24 @@ pretty_name = { "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite", "deepseek-coder-v2.5": "Deepseek Coder V2.5", "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)", + "qwen-2.5-1.5b": "Qwen 2.5 1.5B", "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B", + "qwen-2.5-3b": "Qwen 2.5 3B", "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B", - "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B", - "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B", - "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B", "qwen-2.5-7b": "Qwen 2.5 7B", + "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B", "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)", "qwen-2.5-14b": "Qwen 2.5 14B", + "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B", + "qwen-2.5-32b": "Qwen 2.5 32B", + "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B", "qwen-2.5-72b": "Qwen 2.5 72B", "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)", + "phi-3.5-mini": "Phi-3.5 Mini", + "phi-4": "Phi-4", "llama-3-8b": "Llama 3 8B", "llama-3-70b": "Llama 3 70B", + "stable-diffusion-2-1-base": "Stable Diffusion 2.1", } def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]: diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index f0ef31db..eea315b8 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -11,7 +11,8 @@ from exo.inference.shard import Shard from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from exo.helpers import DEBUG - +import json +import mlx.core as mx class GRPCPeerHandle(PeerHandle): def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities): @@ -90,7 +91,7 @@ class GRPCPeerHandle(PeerHandle): traceback.print_exc() return False - async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None: + async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.PromptRequest( prompt=prompt, shard=node_service_pb2.Shard( @@ -100,10 +101,11 @@ class GRPCPeerHandle(PeerHandle): n_layers=shard.n_layers, ), request_id=request_id, + inference_state=self.serialize_inference_state(inference_state) ) await self.stub.SendPrompt(request) - async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None: + async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.TensorRequest( shard=node_service_pb2.Shard( model_id=shard.model_id, @@ -113,8 +115,14 @@ class GRPCPeerHandle(PeerHandle): ), tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), request_id=request_id, + inference_state=self.serialize_inference_state(inference_state) ) - await self.stub.SendTensor(request) + response =await self.stub.SendTensor(request) + + if not response.tensor_data or not response.shape or not response.dtype: + return None + + return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]: request = 
node_service_pb2.ExampleRequest( @@ -173,10 +181,44 @@ class GRPCPeerHandle(PeerHandle): topology.add_edge(node_id, conn.to_id, conn.description) return topology - async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None: - request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished) - await self.stub.SendNewToken(request) + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + tensor = None + if isinstance(result, np.ndarray): + tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype)) + result = [] + request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished) + await self.stub.SendResult(request) async def send_opaque_status(self, request_id: str, status: str) -> None: request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status) await self.stub.SendOpaqueStatus(request) + + def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState: + proto_inference_state = node_service_pb2.InferenceState() + other_data = {} + for k, v in inference_state.items(): + if isinstance(v, mx.array): + np_array = np.array(v) + tensor_data = node_service_pb2.Tensor( + tensor_data=np_array.tobytes(), + shape=list(np_array.shape), + dtype=str(np_array.dtype) + ) + proto_inference_state.tensor_data[k].CopyFrom(tensor_data) + elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v): + tensor_list = node_service_pb2.TensorList() + for tensor in v: + np_array = np.array(tensor) + tensor_data = node_service_pb2.Tensor( + tensor_data=np_array.tobytes(), + shape=list(np_array.shape), + dtype=str(np_array.dtype) + ) + tensor_list.tensors.append(tensor_data) + proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list) + else: + # For non-tensor data, we'll still use JSON + other_data[k] = v + if other_data: + proto_inference_state.other_data_json = json.dumps(other_data) + return proto_inference_state diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index ec37d768..da67d9c6 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -8,6 +8,8 @@ from . 
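`serialize_inference_state` above moves `mx.array` values over the wire as raw bytes plus shape and dtype, with everything else JSON-encoded; the matching `deserialize_inference_state` in grpc_server.py below reverses it. A tiny round-trip sketch of the tensor part, using a plain dict in place of the protobuf `Tensor` message (field names taken from node_service.proto):

```python
import numpy as np

def to_wire(arr: np.ndarray) -> dict:
    # Same fields as node_service_pb2.Tensor: tensor_data, shape, dtype.
    return {"tensor_data": arr.tobytes(), "shape": list(arr.shape), "dtype": str(arr.dtype)}

def from_wire(msg: dict) -> np.ndarray:
    return np.frombuffer(msg["tensor_data"], dtype=np.dtype(msg["dtype"])).reshape(msg["shape"])

x = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(from_wire(to_wire(x)), x)
```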
import node_service_pb2_grpc from exo import DEBUG from exo.inference.shard import Shard from exo.orchestration import Node +import json +import mlx.core as mx class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): @@ -58,9 +60,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): ) prompt = request.prompt request_id = request.request_id - await self.node.process_prompt(shard, prompt, request_id) - if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}") - return node_service_pb2.Empty() + inference_state = self.deserialize_inference_state(request.inference_state) + result = await self.node.process_prompt(shard, prompt, request_id, inference_state) + if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() async def SendTensor(self, request, context): shard = Shard( @@ -71,9 +75,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): ) tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape) request_id = request.request_id - await self.node.process_tensor(shard, tensor, request_id) - if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}") - return node_service_pb2.Empty() + + inference_state = self.deserialize_inference_state(request.inference_state) + + result = await self.node.process_tensor(shard, tensor, request_id, inference_state) + if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() async def SendExample(self, request, context): shard = Shard( @@ -127,8 +135,12 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): request_id = request.request_id token = request.token is_finished = request.is_finished - if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}") - self.node.on_token.trigger_all(request_id, token, is_finished) + img = request.tensor + if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}") + result = list(result) + if len(img.tensor_data) > 0: + result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape) + self.node.on_token.trigger_all(request_id, result, is_finished) return node_service_pb2.Empty() async def SendOpaqueStatus(self, request, context): @@ -140,3 +152,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): async def HealthCheck(self, request, context): return node_service_pb2.HealthCheckResponse(is_healthy=True) + + def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict: + inference_state = {} + + for k, tensor_data in inference_state_proto.tensor_data.items(): + np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape) + inference_state[k] = mx.array(np_array) + + for k, tensor_list in inference_state_proto.tensor_list_data.items(): + inference_state[k] = [ + mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) + for tensor in tensor_list.tensors + ] + + if 
inference_state_proto.other_data_json: + other_data = json.loads(inference_state_proto.other_data_json) + inference_state.update(other_data) + + return inference_state diff --git a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index b99f5e66..06a35f2f 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -23,12 +23,14 @@ message PromptRequest { Shard shard = 1; string prompt = 2; optional string request_id = 3; + optional InferenceState inference_state = 4; } message TensorRequest { Shard shard = 1; Tensor tensor = 2; optional string request_id = 3; + optional InferenceState inference_state = 4; } message ExampleRequest { @@ -51,6 +53,16 @@ message Tensor { string dtype = 3; } +message TensorList { + repeated Tensor tensors = 1; +} + +message InferenceState { + map tensor_data = 1; + map tensor_list_data = 2; + string other_data_json = 3; +} + message CollectTopologyRequest { repeated string visited = 1; int32 max_depth = 2; @@ -85,8 +97,9 @@ message DeviceCapabilities { message SendNewTokenRequest { string request_id = 1; - int32 token = 2; - bool is_finished = 3; + repeated int32 result = 2; + optional Tensor tensor = 3; + bool is_finished = 4; } message SendOpaqueStatusRequest { diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index 7379eb69..6ff71086 100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -24,55 +24,67 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 
\x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"M\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 
\x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x84\x01\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None + _globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None + _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001' + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001' _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001' 
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001' _globals['_SHARD']._serialized_start=36 _globals['_SHARD']._serialized_end=119 - _globals['_PROMPTREQUEST']._serialized_start=121 - _globals['_PROMPTREQUEST']._serialized_end=228 - _globals['_TENSORREQUEST']._serialized_start=231 - _globals['_TENSORREQUEST']._serialized_end=360 - _globals['_EXAMPLEREQUEST']._serialized_start=363 - _globals['_EXAMPLEREQUEST']._serialized_end=585 - _globals['_LOSS']._serialized_start=587 - _globals['_LOSS']._serialized_end=659 - _globals['_TENSOR']._serialized_start=661 - _globals['_TENSOR']._serialized_end=720 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=722 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=782 - _globals['_TOPOLOGY']._serialized_start=785 - _globals['_TOPOLOGY']._serialized_end=1065 - _globals['_TOPOLOGY_NODESENTRY']._serialized_start=906 - _globals['_TOPOLOGY_NODESENTRY']._serialized_end=984 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=986 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1065 - _globals['_PEERCONNECTION']._serialized_start=1067 - _globals['_PEERCONNECTION']._serialized_end=1140 - _globals['_PEERCONNECTIONS']._serialized_start=1142 - _globals['_PEERCONNECTIONS']._serialized_end=1210 - _globals['_DEVICEFLOPS']._serialized_start=1212 - _globals['_DEVICEFLOPS']._serialized_end=1267 - _globals['_DEVICECAPABILITIES']._serialized_start=1269 - _globals['_DEVICECAPABILITIES']._serialized_end=1376 - _globals['_SENDNEWTOKENREQUEST']._serialized_start=1378 - _globals['_SENDNEWTOKENREQUEST']._serialized_end=1455 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1457 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1518 - _globals['_HEALTHCHECKREQUEST']._serialized_start=1520 - _globals['_HEALTHCHECKREQUEST']._serialized_end=1540 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=1542 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=1583 - _globals['_EMPTY']._serialized_start=1585 - _globals['_EMPTY']._serialized_end=1592 - _globals['_NODESERVICE']._serialized_start=1595 - _globals['_NODESERVICE']._serialized_end=2132 + _globals['_PROMPTREQUEST']._serialized_start=122 + _globals['_PROMPTREQUEST']._serialized_end=309 + _globals['_TENSORREQUEST']._serialized_start=312 + _globals['_TENSORREQUEST']._serialized_end=521 + _globals['_EXAMPLEREQUEST']._serialized_start=524 + _globals['_EXAMPLEREQUEST']._serialized_end=746 + _globals['_LOSS']._serialized_start=748 + _globals['_LOSS']._serialized_end=820 + _globals['_TENSOR']._serialized_start=822 + _globals['_TENSOR']._serialized_end=881 + _globals['_TENSORLIST']._serialized_start=883 + _globals['_TENSORLIST']._serialized_end=934 + _globals['_INFERENCESTATE']._serialized_start=937 + _globals['_INFERENCESTATE']._serialized_end=1275 + _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123 + _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194 + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196 + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337 + _globals['_TOPOLOGY']._serialized_start=1340 + _globals['_TOPOLOGY']._serialized_end=1620 + _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461 + _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541 + 
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620 + _globals['_PEERCONNECTION']._serialized_start=1622 + _globals['_PEERCONNECTION']._serialized_end=1695 + _globals['_PEERCONNECTIONS']._serialized_start=1697 + _globals['_PEERCONNECTIONS']._serialized_end=1765 + _globals['_DEVICEFLOPS']._serialized_start=1767 + _globals['_DEVICEFLOPS']._serialized_end=1822 + _globals['_DEVICECAPABILITIES']._serialized_start=1824 + _globals['_DEVICECAPABILITIES']._serialized_end=1931 + _globals['_SENDNEWTOKENREQUEST']._serialized_start=1934 + _globals['_SENDNEWTOKENREQUEST']._serialized_end=2066 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2068 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2129 + _globals['_HEALTHCHECKREQUEST']._serialized_start=2131 + _globals['_HEALTHCHECKREQUEST']._serialized_end=2151 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=2153 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=2194 + _globals['_EMPTY']._serialized_start=2196 + _globals['_EMPTY']._serialized_end=2203 + _globals['_NODESERVICE']._serialized_start=2206 + _globals['_NODESERVICE']._serialized_end=2743 # @@protoc_insertion_point(module_scope) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 8287605e..35a8fabe 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -1,7 +1,9 @@ +import os import asyncio -from exo.networking.discovery import Discovery -from typing import Dict, List, Callable +from typing import Dict, List, Callable, Optional +from concurrent.futures import ThreadPoolExecutor +from exo.networking.discovery import Discovery from exo.topology.device_capabilities import DeviceCapabilities from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig from exo.helpers import DEBUG_DISCOVERY @@ -13,28 +15,25 @@ class ManualDiscovery(Discovery): self, network_config_path: str, node_id: str, - create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], + create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle], ): - self.topology = NetworkTopology.from_path(network_config_path) + self.network_config_path = network_config_path + self.node_id = node_id self.create_peer_handle = create_peer_handle - if node_id not in self.topology.peers: - raise ValueError( - f"Node ID {node_id} not found in network config file {network_config_path}. 
Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" - ) - self.listen_task = None - self.known_peers: Dict[str, PeerHandle] = {} - self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers - self.peers_in_network.pop(node_id) + + self._cached_peers: Dict[str, PeerConfig] = {} + self._last_modified_time: Optional[float] = None + self._file_executor = ThreadPoolExecutor(max_workers=1) async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) async def stop(self) -> None: - if self.listen_task: - self.listen_task.cancel() + if self.listen_task: self.listen_task.cancel() + self._file_executor.shutdown(wait=True) async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: if wait_for_peers > 0: @@ -47,7 +46,9 @@ class ManualDiscovery(Discovery): async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") while True: - for peer_id, peer_config in self.peers_in_network.items(): + peers_from_config = await self._get_peers() + new_known_peers = {} + for peer_id, peer_config in peers_from_config.items(): try: if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") peer = self.known_peers.get(peer_id) @@ -57,15 +58,43 @@ class ManualDiscovery(Discovery): is_healthy = await peer.health_check() if is_healthy: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") - self.known_peers[peer_id] = peer - else: - if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") - try: - del self.known_peers[peer_id] - except KeyError: - pass + new_known_peers[peer_id] = peer + elif DEBUG_DISCOVERY >= 2: + print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.") except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") await asyncio.sleep(5.0) if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") + + async def _get_peers(self): + try: + loop = asyncio.get_running_loop() + current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path) + + if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time): + return self._cached_peers + + topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path) + + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file " + f"{self.network_config_path}. Please run with `node_id` set to " + f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) + + peers_in_network = topology.peers + peers_in_network.pop(self.node_id) + + self._cached_peers = peers_in_network + self._last_modified_time = current_mtime + + return peers_in_network + + except Exception as e: + if DEBUG_DISCOVERY >= 2: + print(f"Error when loading network config file from {self.network_config_path}. " + f"Please update the config file in order to successfully discover peers. 
" + f"Exception: {e}") + return self._cached_peers diff --git a/exo/networking/manual/test_data/test_config.json b/exo/networking/manual/test_data/test_config.json index b50ef635..54eced72 100644 --- a/exo/networking/manual/test_data/test_config.json +++ b/exo/networking/manual/test_data/test_config.json @@ -29,4 +29,4 @@ } } } -} +} \ No newline at end of file diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 69f45fa1..317fba9d 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -1,3 +1,4 @@ +import json import asyncio import unittest from unittest import mock @@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.peer1 = mock.AsyncMock() self.peer1.connect = mock.AsyncMock() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1) - _ = self.discovery1.start() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, + ) + await self.discovery1.start() async def asyncTearDown(self): await self.discovery1.stop() @@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase): self.peer2 = mock.AsyncMock() self.peer1.connect = mock.AsyncMock() self.peer2.connect = mock.AsyncMock() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1) - self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2) + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2, + ) await self.discovery1.start() await self.discovery2.start() @@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port) await self.server1.start() await self.server2.start() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) - self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + ) await self.discovery1.start() await self.discovery2.start() @@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): self.assertFalse(await peers1[0].is_connected()) self.assertFalse(await peers2[0].is_connected()) + async def test_dynamic_config_update(self): + 
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(initial_peers), 1) + + # Save original config for cleanup + with open(root_path, "r") as f: + original_config = json.load(f) + + try: + updated_config = { + "peers": { + **original_config["peers"], + "node3": { + "address": "localhost", + "port": 50053, + "device_capabilities": { + "model": "Unknown Model", + "chip": "Unknown Chip", + "memory": 0, + "flops": {"fp32": 0, "fp16": 0, "int8": 0}, + }, + }, + } + } + + with open(root_path, "w") as f: + json.dump(updated_config, f, indent=2) + + node3 = mock.AsyncMock(spec=Node) + server3 = GRPCServer(node3, "localhost", 50053) + await server3.start() + + try: + # Wait for the config to be reloaded + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) + self.assertEqual(len(updated_peers), 2) + + for peer in updated_peers: + await peer.connect() + self.assertTrue(await peer.is_connected()) + + finally: + await server3.stop() + + finally: + # Restore the original config file + with open(root_path, "w") as f: + json.dump(original_config, f, indent=2) + + # Wait for the config to be reloaded again + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(updated_peers), 1) + if __name__ == "__main__": asyncio.run(unittest.main()) diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index ebf9b673..00453deb 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -118,44 +118,50 @@ class Node: shard, result: np.ndarray, request_id: Optional[str] = None, + inference_state: Optional[dict] = None, ): - if request_id not in self.buffered_token_output: - self.buffered_token_output[request_id] = ([], False) - is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - - if shard.is_last_layer() and not is_finished: - self.token_count += 1 - if self.token_count == 1: - self.first_token_time = time.perf_counter_ns() - if self.token_count % 20 == 0: - print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}") - - token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) - await self.inference_engine.ensure_shard(shard) - self.buffered_token_output[request_id][0].append(token.item()) - is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") - forward = token.reshape(1, -1) - self.trigger_on_token_callbacks(request_id, token.item(), is_finished) - asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished)) + if shard.model_id != 'stable-diffusion-2-1-base': + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if shard.is_last_layer() and not is_finished: + token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) + await self.inference_engine.ensure_shard(shard) + self.buffered_token_output[request_id][0].append(token.item()) + is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or 
len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") + asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id])) + forward = token.reshape(1, -1) + intermediate_result = self.buffered_token_output[request_id][0] + else: + forward = result else: + await self.inference_engine.ensure_shard(shard) + is_finished = inference_state.get("is_finished", False) + intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result) forward = result + if shard.is_last_layer(): + self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished) + asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished)) if is_finished: - self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + if shard.model_id != 'stable-diffusion-2-1-base': + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) self.outstanding_requests.pop(request_id) else: self.outstanding_requests[request_id] = "waiting" - asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1))) + asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state)) + + return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result - return np.array(self.buffered_token_output[request_id][0]) async def process_prompt( self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, - ) -> None: + inference_state: Optional[dict] = {}, + ) -> Optional[np.ndarray]: shard = self.get_current_shard(base_shard) start_time = time.perf_counter_ns() asyncio.create_task( @@ -172,7 +178,8 @@ class Node: }), ) ) - await self._process_prompt(base_shard, prompt, request_id) + start_time = time.perf_counter_ns() + resp = await self._process_prompt(base_shard, prompt, request_id, inference_state) end_time = time.perf_counter_ns() elapsed_time_ns = end_time - start_time asyncio.create_task( @@ -192,7 +199,7 @@ class Node: ) if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}") - async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]: + async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]: if request_id is None: request_id = str(uuid.uuid4()) shard = self.get_current_shard(base_shard) @@ -201,12 +208,13 @@ class Node: if not shard.is_first_layer(): if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") self.outstanding_requests[request_id] = "waiting" - await self.forward_prompt(shard, prompt, request_id, 0) + resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state) return None - - self.outstanding_requests[request_id] = "processing" - result = await self.inference_engine.infer_prompt(request_id, shard, prompt) - await self.process_inference_result(shard, result, request_id) + else: + self.outstanding_requests[request_id] = "processing" + result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state) + ret = 
await self.process_inference_result(shard, result, request_id, inference_state) + return result async def enqueue_example( self, @@ -350,10 +358,11 @@ class Node: base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, - ) -> None: + inference_state: Optional[dict] = None, + ) -> Optional[np.ndarray]: shard = self.get_current_shard(base_shard) start_time = time.perf_counter_ns() - await self._process_tensor(shard, tensor, request_id) + resp = await self._process_tensor(shard, tensor, request_id, inference_state) end_time = time.perf_counter_ns() elapsed_time_ns = end_time - start_time if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}") @@ -363,15 +372,17 @@ class Node: base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, - ) -> None: + inference_state: Optional[dict] = None, + ) -> Optional[np.ndarray]: if request_id is None: request_id = str(uuid.uuid4()) shard = self.get_current_shard(base_shard) try: self.outstanding_requests[request_id] = "processing" - result = await self.inference_engine.infer_tensor(request_id, shard, tensor) - await self.process_inference_result(shard, result, request_id) + result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state) + ret = await self.process_inference_result(shard, result, request_id, inference_state) + return ret except Exception as e: self.outstanding_requests.pop(request_id) print(f"Error processing tensor for shard {shard}: {e}") @@ -404,19 +415,20 @@ class Node: prompt: str, request_id: str, target_index: int, + inference_state: Optional[dict] = None, ) -> None: if DEBUG >= 1: print(f"target partition index: {target_index}") target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id next_shard = self.get_current_shard(base_shard, target_index) if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}") if target_id == self.id: - await self.process_prompt(next_shard, prompt, request_id) + await self.process_prompt(next_shard, prompt, request_id, inference_state) else: target_peer = next((p for p in self.peers if p.id() == target_id), None) if not target_peer: raise ValueError(f"Peer for {target_index} not found") if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}") - await target_peer.send_prompt(next_shard, prompt, request_id=request_id) + await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state) async def forward_tensor( self, @@ -424,19 +436,20 @@ class Node: tensor: np.ndarray, request_id: str, target_index: int, + inference_state: Optional[dict] = None, ) -> None: if DEBUG >= 1: print(f"target partition index: {target_index}") target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id next_shard = self.get_current_shard(base_shard, target_index) if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. 
target shard: {next_shard}") if target_id == self.id: - await self.process_tensor(next_shard, tensor, request_id) + await self.process_tensor(next_shard, tensor, request_id, inference_state) else: target_peer = next((p for p in self.peers if p.id() == target_id), None) if not target_peer: raise ValueError(f"Peer for {target_index} not found") if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}") - await target_peer.send_tensor(next_shard, tensor, request_id=request_id) + await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state) def get_partition_index(self, offset: int = 0): if not self.partitioning_strategy: @@ -604,3 +617,12 @@ class Node: @property def current_topology(self) -> Topology: return self.topology + + def handle_stable_diffusion(self, inference_state, result): + if inference_state['is_step_finished']: + inference_state['step']+=1 + progress = [inference_state['step'],inference_state['total_steps']] + intermediate_result = result + if progress[0] == progress[1]: + intermediate_result = result + return intermediate_result, inference_state diff --git a/exo/orchestration/tracing.py b/exo/orchestration/tracing.py new file mode 100644 index 00000000..4466fc7d --- /dev/null +++ b/exo/orchestration/tracing.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Any +from opentelemetry import trace, context +from opentelemetry.trace import Status, StatusCode, SpanContext +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from contextlib import contextmanager +import time +from threading import Lock + +@dataclass +class TraceContext: + request_id: str + sequence_number: int + current_span: Optional[trace.Span] = None + trace_parent: Optional[str] = None + token_group_span: Optional[trace.Span] = None + token_count: int = 0 + token_group_size: int = 10 # Default group size + request_span: Optional[trace.Span] = None # Track the main request span + +class Tracer: + def __init__(self): + self.tracer = trace.get_tracer("exo") + self.contexts: Dict[str, TraceContext] = {} + self._lock = Lock() + self.propagator = TraceContextTextMapPropagator() + + def get_context(self, request_id: str) -> Optional[TraceContext]: + with self._lock: + return self.contexts.get(request_id) + + def set_context(self, request_id: str, context: TraceContext): + with self._lock: + self.contexts[request_id] = context + + def inject_context(self, span: trace.Span) -> str: + """Inject current span context into carrier for propagation""" + carrier = {} + ctx = trace.set_span_in_context(span) + self.propagator.inject(carrier, context=ctx) + return carrier.get("traceparent", "") + + def extract_context(self, trace_parent: str) -> Optional[context.Context]: + """Extract span context from carrier""" + if not trace_parent: + return None + carrier = {"traceparent": trace_parent} + return self.propagator.extract(carrier) + + def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext: + """Create a new context with the given trace parent""" + parent_ctx = self.extract_context(trace_parent) + if parent_ctx: + # Create a new request span that links to the parent context + request_span = self.tracer.start_span( + "request", + context=parent_ctx, + attributes={ + "request_id": request_id, + "sequence_number": sequence_number + } + ) + return TraceContext( + request_id=request_id, + sequence_number=sequence_number, + 
request_span=request_span, + current_span=request_span, + trace_parent=trace_parent + ) + return TraceContext(request_id=request_id, sequence_number=sequence_number) + + def handle_token(self, context: TraceContext, token: int, is_finished: bool = False): + """Handle token generation and manage token group spans""" + context.token_count += 1 + + # Start a new token group span if needed + if not context.token_group_span and context.request_span: + group_number = (context.token_count - 1) // context.token_group_size + 1 + + # Create token group span as child of request span + parent_ctx = trace.set_span_in_context(context.request_span) + context.token_group_span = self.tracer.start_span( + f"token_group_{group_number}", + context=parent_ctx, + attributes={ + "request_id": context.request_id, + "group.number": group_number, + "group.start_token": context.token_count, + "group.max_tokens": context.token_group_size + } + ) + + # Add token to current group span + if context.token_group_span: + relative_pos = ((context.token_count - 1) % context.token_group_size) + 1 + context.token_group_span.set_attribute(f"token.{relative_pos}", token) + context.token_group_span.set_attribute("token.count", relative_pos) + + # End current group span if we've reached the group size or if generation is finished + if context.token_count % context.token_group_size == 0 or is_finished: + context.token_group_span.set_attribute("token.final_count", relative_pos) + context.token_group_span.end() + context.token_group_span = None + + @contextmanager + def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None): + """Start a new span with proper parent context""" + attributes = { + "request_id": context.request_id, + "sequence_number": context.sequence_number + } + if extra_attributes: + attributes.update(extra_attributes) + + # Use request span as parent if available + parent_ctx = None + if context.request_span: + parent_ctx = trace.set_span_in_context(context.request_span) + elif context.trace_parent: + parent_ctx = self.extract_context(context.trace_parent) + if parent_ctx and not context.request_span: + # Create a new request span that links to the parent context + context.request_span = self.tracer.start_span( + "request", + context=parent_ctx, + attributes={ + "request_id": context.request_id, + "sequence_number": context.sequence_number + } + ) + parent_ctx = trace.set_span_in_context(context.request_span) + elif context.current_span: + parent_ctx = trace.set_span_in_context(context.current_span) + + # Create span with parent context if it exists + if parent_ctx: + span = self.tracer.start_span( + name, + context=parent_ctx, + attributes=attributes + ) + else: + span = self.tracer.start_span( + name, + attributes=attributes + ) + + # Update context with current span + prev_span = context.current_span + context.current_span = span + + try: + start_time = time.perf_counter() + yield span + duration = time.perf_counter() - start_time + span.set_attribute("duration_s", duration) + span.set_status(Status(StatusCode.OK)) + except Exception as e: + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + finally: + span.end() + context.current_span = prev_span + +# Global tracer instance +tracer = Tracer() \ No newline at end of file diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index 4e0617e4..013d0d63 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -197,7 +197,25 @@ const div = document.createElement('div'); div.className = 
`message message-role-${role}`; try { - div.innerHTML = DOMPurify.sanitize(marked.parse(content)); + if (content.includes('![Generated Image]')) { + const imageUrl = content.match(/\((.*?)\)/)[1]; + const img = document.createElement('img'); + img.src = imageUrl; + img.alt = 'Generated Image'; + img.onclick = async () => { + try { + const response = await fetch(img.src); + const blob = await response.blob(); + const file = new File([blob], 'image.png', { type: 'image/png' }); + handleImageUpload({ target: { files: [file] } }); + } catch (error) { + console.error('Error fetching image:', error); + } + }; + div.appendChild(img); + } else { + div.innerHTML = DOMPurify.sanitize(marked.parse(content)); + } } catch (e) { console.log(content); console.error(e); @@ -281,7 +299,7 @@
- diff --git a/exo/tinychat/index.js b/exo/tinychat/index.js index 48c5a23c..5aa6c4a9 100644 --- a/exo/tinychat/index.js +++ b/exo/tinychat/index.js @@ -231,53 +231,110 @@ document.addEventListener("alpine:init", () => { }; } }); - const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url')); - if (containsImage) { - // Map all messages with string content to object with type text - apiMessages = apiMessages.map(msg => { - if (typeof msg.content === 'string') { - return { - ...msg, - content: [ - { - type: "text", - text: msg.content - } - ] - }; - } - return msg; + + if (this.cstate.selectedModel === "stable-diffusion-2-1-base") { + // Send a request to the image generation endpoint + console.log(apiMessages[apiMessages.length - 1].content) + console.log(this.cstate.selectedModel) + console.log(this.endpoint) + const response = await fetch(`${this.endpoint}/image/generations`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + "model": 'stable-diffusion-2-1-base', + "prompt": apiMessages[apiMessages.length - 1].content, + "image_url": this.imageUrl + }), }); - } - - - // start receiving server sent events - let gottenFirstChunk = false; - for await ( - const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages) - ) { - if (!gottenFirstChunk) { - this.cstate.messages.push({ role: "assistant", content: "" }); - gottenFirstChunk = true; + + if (!response.ok) { + throw new Error("Failed to fetch"); } - - // add chunk to the last message - this.cstate.messages[this.cstate.messages.length - 1].content += chunk; - - // calculate performance tracking - tokens += 1; - this.total_tokens += 1; - if (start_time === 0) { - start_time = Date.now(); - this.time_till_first = start_time - prefill_start; - } else { - const diff = Date.now() - start_time; - if (diff > 0) { - this.tokens_per_second = tokens / (diff / 1000); + const reader = response.body.getReader(); + let done = false; + let gottenFirstChunk = false; + + while (!done) { + const { value, done: readerDone } = await reader.read(); + done = readerDone; + const decoder = new TextDecoder(); + + if (value) { + // Assume non-binary data (text) comes first + const chunk = decoder.decode(value, { stream: true }); + const parsed = JSON.parse(chunk); + console.log(parsed) + + if (parsed.progress) { + if (!gottenFirstChunk) { + this.cstate.messages.push({ role: "assistant", content: "" }); + gottenFirstChunk = true; + } + this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress; + } + else if (parsed.images) { + if (!gottenFirstChunk) { + this.cstate.messages.push({ role: "assistant", content: "" }); + gottenFirstChunk = true; + } + const imageUrl = parsed.images[0].url; + console.log(imageUrl) + this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`; + } } } } + + else{ + const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url')); + if (containsImage) { + // Map all messages with string content to object with type text + apiMessages = apiMessages.map(msg => { + if (typeof msg.content === 'string') { + return { + ...msg, + content: [ + { + type: "text", + text: msg.content + } + ] + }; + } + return msg; + }); + } + console.log(apiMessages) + //start receiving server sent events + let gottenFirstChunk = false; + for await ( + const chunk of 
this.openaiChatCompletion(this.cstate.selectedModel, apiMessages) + ) { + if (!gottenFirstChunk) { + this.cstate.messages.push({ role: "assistant", content: "" }); + gottenFirstChunk = true; + } + + // add chunk to the last message + this.cstate.messages[this.cstate.messages.length - 1].content += chunk; + + // calculate performance tracking + tokens += 1; + this.total_tokens += 1; + if (start_time === 0) { + start_time = Date.now(); + this.time_till_first = start_time - prefill_start; + } else { + const diff = Date.now() - start_time; + if (diff > 0) { + this.tokens_per_second = tokens / (diff / 1000); + } + } + } + } // Clean the cstate before adding it to histories const cleanedCstate = JSON.parse(JSON.stringify(this.cstate)); cleanedCstate.messages = cleanedCstate.messages.map(msg => { diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 1519612f..734fe69d 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -91,25 +91,70 @@ class TopologyViz: content = [] requests = list(self.requests.values())[-3:] # Get the 3 most recent requests max_width = self.console.width - 6 # Full width minus padding and icon - max_lines = 13 # Maximum number of lines for the entire panel content + + # Calculate available height for content + panel_height = 15 # Fixed panel height + available_lines = panel_height - 2 # Subtract 2 for panel borders + lines_per_entry = available_lines // len(requests) if requests else 0 for (prompt, output) in reversed(requests): prompt_icon, output_icon = "💬️", "🤖" - # Process prompt - prompt_lines = prompt.split('\n') - if len(prompt_lines) > max_lines // 2: - prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...'] - prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") - prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white") + # Calculate max lines for prompt and output + max_prompt_lines = lines_per_entry // 3 # Allocate 1/3 for prompt + max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing + + # Process prompt + prompt_lines = [] + for line in prompt.split('\n'): + words = line.split() + current_line = [] + current_length = 0 + + for word in words: + if current_length + len(word) + 1 <= max_width: + current_line.append(word) + current_length += len(word) + 1 + else: + if current_line: + prompt_lines.append(' '.join(current_line)) + current_line = [word] + current_length = len(word) + + if current_line: + prompt_lines.append(' '.join(current_line)) + + if len(prompt_lines) > max_prompt_lines: + prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...'] + + prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") + prompt_text.append('\n'.join(prompt_lines), style="white") + + # Process output - same word-aware wrapping + output_lines = [] + for line in output.split('\n'): + words = line.split() + current_line = [] + current_length = 0 + + for word in words: + if current_length + len(word) + 1 <= max_width: + current_line.append(word) + current_length += len(word) + 1 + else: + if current_line: + output_lines.append(' '.join(current_line)) + current_line = [word] + current_length = len(word) + + if current_line: + output_lines.append(' '.join(current_line)) + + if len(output_lines) > max_output_lines: + output_lines = output_lines[:max_output_lines - 1] + ['...'] - # Process output - output_lines = output.split('\n') - remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing - if len(output_lines) > remaining_lines: - output_lines = 
output_lines[:remaining_lines - 1] + ['...'] output_text = Text(f"\n{output_icon} ", style="bold bright_magenta") - output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white") + output_text.append('\n'.join(output_lines), style="white") content.append(prompt_text) content.append(output_text) @@ -119,8 +164,8 @@ class TopologyViz: Group(*content), title="", border_style="cyan", - height=15, # Increased height to accommodate multiple lines - expand=True # Allow the panel to expand to full width + height=panel_height, + expand=True ) def _generate_main_layout(self) -> str: diff --git a/install.sh b/install.sh index f136c317..a5fffec6 100755 --- a/install.sh +++ b/install.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash if command -v python3.12 &>/dev/null; then echo "Python 3.12 is installed, proceeding with python3.12..." diff --git a/scripts/compile_grpc.sh b/scripts/compile_grpc.sh index b9b87204..b0333bb5 100755 --- a/scripts/compile_grpc.sh +++ b/scripts/compile_grpc.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash source ./install.sh pushd exo/networking/grpc python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto diff --git a/test/reconnect.sh b/test/reconnect.sh index 7921d4e7..1e9a2add 100755 --- a/test/reconnect.sh +++ b/test/reconnect.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash echo "Starting node 1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 & diff --git a/test/test_tokenizers.py b/test/test_tokenizers.py index ebae91a1..3635357f 100644 --- a/test/test_tokenizers.py +++ b/test/test_tokenizers.py @@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False): strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id])) assert text == strip_tokens(decoded) == strip_tokens(reconstructed) -ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"] +ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"] ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")") models = [] for model_id in model_cards: