mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge branch 'main' into runners2
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -171,3 +171,5 @@ cython_debug/
|
|||||||
|
|
||||||
**/*.xcodeproj/*
|
**/*.xcodeproj/*
|
||||||
.aider*
|
.aider*
|
||||||
|
|
||||||
|
exo/tinychat/images/*.png
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
|
|||||||
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
|
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
|
||||||
[](https://www.gnu.org/licenses/gpl-3.0)
|
[](https://www.gnu.org/licenses/gpl-3.0)
|
||||||
|
|
||||||
|
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
|
|||||||
|
|
||||||
### Wide Model Support
|
### Wide Model Support
|
||||||
|
|
||||||
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
|
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
|
||||||
|
|
||||||
### Dynamic Model Partitioning
|
### Dynamic Model Partitioning
|
||||||
|
|
||||||
@@ -100,13 +102,13 @@ source install.sh
|
|||||||
|
|
||||||
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
|
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
|
||||||
|
|
||||||
1. Upgrade to the latest version of MacOS 15.
|
1. Upgrade to the latest version of macOS Sequoia.
|
||||||
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
|
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
|
||||||
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
### Example Usage on Multiple MacOS Devices
|
### Example Usage on Multiple macOS Devices
|
||||||
|
|
||||||
#### Device 1:
|
#### Device 1:
|
||||||
|
|
||||||
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
|
### Example Usage on Multiple Heterogenous Devices (macOS + Linux)
|
||||||
|
|
||||||
#### Device 1 (MacOS):
|
#### Device 1 (macOS):
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
exo
|
exo
|
||||||
@@ -244,7 +246,7 @@ python3 format.py ./exo
|
|||||||
|
|
||||||
## Known Issues
|
## Known Issues
|
||||||
|
|
||||||
- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
|
- On certain versions of Python on macOS, certificates may not installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typicall as follows:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
/Applications/Python 3.x/Install Certificates.command
|
/Applications/Python 3.x/Install Certificates.command
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
# Get the total memory in MB
|
# Get the total memory in MB
|
||||||
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
|
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
|
||||||
|
|||||||
111
examples/function_calling.py
Normal file
111
examples/function_calling.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import requests
|
||||||
|
|
||||||
|
def get_current_weather(location: str, unit: str = "celsius"):
|
||||||
|
"""Mock weather data function"""
|
||||||
|
# Hardcoded response for demo purposes
|
||||||
|
return {
|
||||||
|
"location": location,
|
||||||
|
"temperature": 22 if unit == "celsius" else 72,
|
||||||
|
"unit": unit,
|
||||||
|
"forecast": "Sunny with light clouds"
|
||||||
|
}
|
||||||
|
|
||||||
|
def try_parse_tool_calls(content: str):
|
||||||
|
"""Try parse the tool calls."""
|
||||||
|
tool_calls = []
|
||||||
|
offset = 0
|
||||||
|
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
|
||||||
|
if i == 0:
|
||||||
|
offset = m.start()
|
||||||
|
try:
|
||||||
|
func = json.loads(m.group(1))
|
||||||
|
tool_calls.append({"type": "function", "function": func})
|
||||||
|
if isinstance(func["arguments"], str):
|
||||||
|
func["arguments"] = json.loads(func["arguments"])
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
|
||||||
|
pass
|
||||||
|
if tool_calls:
|
||||||
|
if offset > 0 and content[:offset].strip():
|
||||||
|
c = content[:offset]
|
||||||
|
else:
|
||||||
|
c = ""
|
||||||
|
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
|
||||||
|
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
|
||||||
|
|
||||||
|
def chat_completion(messages):
|
||||||
|
"""Send chat completion request to local server"""
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:52415/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "qwen-2.5-1.5b",
|
||||||
|
"messages": messages,
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"tool_choice": "auto"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Initial conversation
|
||||||
|
messages = [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hi there, what's the weather in Boston?"
|
||||||
|
}]
|
||||||
|
|
||||||
|
# Get initial response
|
||||||
|
response = chat_completion(messages)
|
||||||
|
print(f"First response: {response}")
|
||||||
|
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
|
||||||
|
messages.append(assistant_message)
|
||||||
|
|
||||||
|
# If there are tool calls, execute them and continue conversation
|
||||||
|
if "tool_calls" in assistant_message:
|
||||||
|
for tool_call in assistant_message["tool_calls"]:
|
||||||
|
if tool_call["function"]["name"] == "get_current_weather":
|
||||||
|
args = tool_call["function"]["arguments"]
|
||||||
|
weather_data = get_current_weather(**args)
|
||||||
|
|
||||||
|
# Add tool response to messages
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"content": json.dumps(weather_data),
|
||||||
|
"name": tool_call["function"]["name"]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get final response with weather data
|
||||||
|
response = chat_completion(messages)
|
||||||
|
print(f"Final response: {response}")
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response["choices"][0]["message"]["content"]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Print full conversation
|
||||||
|
for msg in messages:
|
||||||
|
print(f"\n{msg['role'].upper()}: {msg['content']}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,18 +5,24 @@ import json
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from typing import List, Literal, Union, Dict
|
from typing import List, Literal, Union, Dict, Optional
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import aiohttp_cors
|
import aiohttp_cors
|
||||||
import traceback
|
import traceback
|
||||||
import signal
|
import signal
|
||||||
from exo import DEBUG, VERSION
|
from exo import DEBUG, VERSION
|
||||||
from exo.download.download_progress import RepoProgressEvent
|
from exo.download.download_progress import RepoProgressEvent
|
||||||
from exo.helpers import PrefixDict, shutdown
|
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
|
||||||
from exo.inference.tokenizers import resolve_tokenizer
|
from exo.inference.tokenizers import resolve_tokenizer
|
||||||
from exo.orchestration import Node
|
from exo.orchestration import Node
|
||||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
import mlx.core as mx
|
||||||
|
import tempfile
|
||||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||||
import shutil
|
import shutil
|
||||||
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
|
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
|
||||||
@@ -24,23 +30,28 @@ from exo.apputil import create_animation_mp4
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
|
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
|
||||||
self.role = role
|
self.role = role
|
||||||
self.content = content
|
self.content = content
|
||||||
|
self.tools = tools
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {"role": self.role, "content": self.content}
|
data = {"role": self.role, "content": self.content}
|
||||||
|
if self.tools:
|
||||||
|
data["tools"] = self.tools
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest:
|
class ChatCompletionRequest:
|
||||||
def __init__(self, model: str, messages: List[Message], temperature: float):
|
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
self.tools = tools
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
|
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
|
||||||
|
|
||||||
|
|
||||||
def generate_completion(
|
def generate_completion(
|
||||||
@@ -120,20 +131,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
|
|||||||
return remapped_messages
|
return remapped_messages
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(tokenizer, _messages: List[Message]):
|
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
|
||||||
messages = remap_messages(_messages)
|
messages = remap_messages(_messages)
|
||||||
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
|
chat_template_args = {
|
||||||
for message in messages:
|
"conversation": [m.to_dict() for m in messages],
|
||||||
if not isinstance(message.content, list):
|
"tokenize": False,
|
||||||
continue
|
"add_generation_prompt": True
|
||||||
|
}
|
||||||
|
if tools: chat_template_args["tools"] = tools
|
||||||
|
|
||||||
|
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||||
|
print(f"!!! Prompt: {prompt}")
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data: dict):
|
def parse_message(data: dict):
|
||||||
if "role" not in data or "content" not in data:
|
if "role" not in data or "content" not in data:
|
||||||
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
|
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
|
||||||
return Message(data["role"], data["content"])
|
return Message(data["role"], data["content"], data.get("tools"))
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_request(data: dict, default_model: str):
|
def parse_chat_request(data: dict, default_model: str):
|
||||||
@@ -141,6 +156,7 @@ def parse_chat_request(data: dict, default_model: str):
|
|||||||
data.get("model", default_model),
|
data.get("model", default_model),
|
||||||
[parse_message(msg) for msg in data["messages"]],
|
[parse_message(msg) for msg in data["messages"]],
|
||||||
data.get("temperature", 0.0),
|
data.get("temperature", 0.0),
|
||||||
|
data.get("tools", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -151,7 +167,7 @@ class PromptSession:
|
|||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
|
|
||||||
class ChatGPTAPI:
|
class ChatGPTAPI:
|
||||||
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
|
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
|
||||||
self.node = node
|
self.node = node
|
||||||
self.inference_engine_classname = inference_engine_classname
|
self.inference_engine_classname = inference_engine_classname
|
||||||
self.response_timeout = response_timeout
|
self.response_timeout = response_timeout
|
||||||
@@ -166,6 +182,7 @@ class ChatGPTAPI:
|
|||||||
# Get the callback system and register our handler
|
# Get the callback system and register our handler
|
||||||
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
||||||
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
|
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
cors = aiohttp_cors.setup(self.app)
|
cors = aiohttp_cors.setup(self.app)
|
||||||
cors_options = aiohttp_cors.ResourceOptions(
|
cors_options = aiohttp_cors.ResourceOptions(
|
||||||
@@ -180,6 +197,7 @@ class ChatGPTAPI:
|
|||||||
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||||
|
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
|
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
|
||||||
@@ -191,10 +209,12 @@ class ChatGPTAPI:
|
|||||||
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
|
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
|
||||||
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
|
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
|
||||||
|
|
||||||
|
|
||||||
if "__compiled__" not in globals():
|
if "__compiled__" not in globals():
|
||||||
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
||||||
self.app.router.add_get("/", self.handle_root)
|
self.app.router.add_get("/", self.handle_root)
|
||||||
self.app.router.add_static("/", self.static_dir, name="static")
|
self.app.router.add_static("/", self.static_dir, name="static")
|
||||||
|
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
|
||||||
|
|
||||||
self.app.middlewares.append(self.timeout_middleware)
|
self.app.middlewares.append(self.timeout_middleware)
|
||||||
self.app.middlewares.append(self.log_request)
|
self.app.middlewares.append(self.log_request)
|
||||||
@@ -241,7 +261,7 @@ class ChatGPTAPI:
|
|||||||
)
|
)
|
||||||
await response.prepare(request)
|
await response.prepare(request)
|
||||||
|
|
||||||
for model_name, pretty in pretty_name.items():
|
async def process_model(model_name, pretty):
|
||||||
if model_name in model_cards:
|
if model_name in model_cards:
|
||||||
model_info = model_cards[model_name]
|
model_info = model_cards[model_name]
|
||||||
|
|
||||||
@@ -269,6 +289,12 @@ class ChatGPTAPI:
|
|||||||
|
|
||||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||||
|
|
||||||
|
# Process all models in parallel
|
||||||
|
await asyncio.gather(*[
|
||||||
|
process_model(model_name, pretty)
|
||||||
|
for model_name, pretty in pretty_name.items()
|
||||||
|
])
|
||||||
|
|
||||||
await response.write(b"data: [DONE]\n\n")
|
await response.write(b"data: [DONE]\n\n")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -281,7 +307,8 @@ class ChatGPTAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def handle_get_models(self, request):
|
async def handle_get_models(self, request):
|
||||||
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
|
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
|
||||||
|
return web.json_response({"object": "list", "data": models_list})
|
||||||
|
|
||||||
async def handle_post_chat_token_encode(self, request):
|
async def handle_post_chat_token_encode(self, request):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
@@ -294,7 +321,7 @@ class ChatGPTAPI:
|
|||||||
shard = build_base_shard(model, self.inference_engine_classname)
|
shard = build_base_shard(model, self.inference_engine_classname)
|
||||||
messages = [parse_message(msg) for msg in data.get("messages", [])]
|
messages = [parse_message(msg) for msg in data.get("messages", [])]
|
||||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||||
prompt = build_prompt(tokenizer, messages)
|
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
|
||||||
tokens = tokenizer.encode(prompt)
|
tokens = tokenizer.encode(prompt)
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"length": len(prompt),
|
"length": len(prompt),
|
||||||
@@ -314,13 +341,13 @@ class ChatGPTAPI:
|
|||||||
|
|
||||||
async def handle_post_chat_completions(self, request):
|
async def handle_post_chat_completions(self, request):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
|
||||||
stream = data.get("stream", False)
|
stream = data.get("stream", False)
|
||||||
chat_request = parse_chat_request(data, self.default_model)
|
chat_request = parse_chat_request(data, self.default_model)
|
||||||
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
|
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
|
||||||
chat_request.model = self.default_model
|
chat_request.model = self.default_model
|
||||||
if not chat_request.model or chat_request.model not in model_cards:
|
if not chat_request.model or chat_request.model not in model_cards:
|
||||||
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
||||||
chat_request.model = self.default_model
|
chat_request.model = self.default_model
|
||||||
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
|
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
|
||||||
if not shard:
|
if not shard:
|
||||||
@@ -331,34 +358,26 @@ class ChatGPTAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||||
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
|
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
|
||||||
|
|
||||||
prompt = build_prompt(tokenizer, chat_request.messages)
|
# Add system prompt if set
|
||||||
|
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
|
||||||
|
chat_request.messages.insert(0, Message("system", self.system_prompt))
|
||||||
|
|
||||||
|
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
if self.on_chat_completion_request:
|
if self.on_chat_completion_request:
|
||||||
try:
|
try:
|
||||||
self.on_chat_completion_request(request_id, chat_request, prompt)
|
self.on_chat_completion_request(request_id, chat_request, prompt)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG >= 2: traceback.print_exc()
|
if DEBUG >= 2: traceback.print_exc()
|
||||||
# request_id = None
|
|
||||||
# match = self.prompts.find_longest_prefix(prompt)
|
|
||||||
# if match and len(prompt) > len(match[1].prompt):
|
|
||||||
# if DEBUG >= 2:
|
|
||||||
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
|
|
||||||
# request_id = match[1].request_id
|
|
||||||
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
|
||||||
# # remove the matching prefix from the prompt
|
|
||||||
# prompt = prompt[len(match[1].prompt):]
|
|
||||||
# else:
|
|
||||||
# request_id = str(uuid.uuid4())
|
|
||||||
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
|
||||||
|
|
||||||
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
|
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
|
||||||
|
|
||||||
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
response = web.StreamResponse(
|
response = web.StreamResponse(
|
||||||
@@ -374,10 +393,12 @@ class ChatGPTAPI:
|
|||||||
try:
|
try:
|
||||||
# Stream tokens while waiting for inference to complete
|
# Stream tokens while waiting for inference to complete
|
||||||
while True:
|
while True:
|
||||||
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
|
||||||
token, is_finished = await asyncio.wait_for(
|
token, is_finished = await asyncio.wait_for(
|
||||||
self.token_queues[request_id].get(),
|
self.token_queues[request_id].get(),
|
||||||
timeout=self.response_timeout
|
timeout=self.response_timeout
|
||||||
)
|
)
|
||||||
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
|
||||||
|
|
||||||
finish_reason = None
|
finish_reason = None
|
||||||
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
|
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
|
||||||
@@ -408,10 +429,13 @@ class ChatGPTAPI:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
|
||||||
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG >= 2: traceback.print_exc()
|
if DEBUG >= 2:
|
||||||
|
print(f"[ChatGPTAPI] Error processing prompt: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"detail": f"Error processing prompt: {str(e)}"},
|
{"detail": f"Error processing prompt: {str(e)}"},
|
||||||
status=500
|
status=500
|
||||||
@@ -420,6 +444,7 @@ class ChatGPTAPI:
|
|||||||
finally:
|
finally:
|
||||||
# Clean up the queue for this request
|
# Clean up the queue for this request
|
||||||
if request_id in self.token_queues:
|
if request_id in self.token_queues:
|
||||||
|
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
|
||||||
del self.token_queues[request_id]
|
del self.token_queues[request_id]
|
||||||
else:
|
else:
|
||||||
tokens = []
|
tokens = []
|
||||||
@@ -441,6 +466,85 @@ class ChatGPTAPI:
|
|||||||
if DEBUG >= 2: traceback.print_exc()
|
if DEBUG >= 2: traceback.print_exc()
|
||||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_post_image_generations(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
||||||
|
stream = data.get("stream", False)
|
||||||
|
model = data.get("model", "")
|
||||||
|
prompt = data.get("prompt", "")
|
||||||
|
image_url = data.get("image_url", "")
|
||||||
|
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
|
||||||
|
shard = build_base_shard(model, self.inference_engine_classname)
|
||||||
|
if DEBUG >= 2: print(f"shard: {shard}")
|
||||||
|
if not shard:
|
||||||
|
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
||||||
|
callback = self.node.on_token.register(callback_id)
|
||||||
|
try:
|
||||||
|
if image_url != "" and image_url != None:
|
||||||
|
img = self.base64_decode(image_url)
|
||||||
|
else:
|
||||||
|
img = None
|
||||||
|
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
|
||||||
|
|
||||||
|
|
||||||
|
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
|
||||||
|
await response.prepare(request)
|
||||||
|
|
||||||
|
def get_progress_bar(current_step, total_steps, bar_length=50):
|
||||||
|
# Calculate the percentage of completion
|
||||||
|
percent = float(current_step) / total_steps
|
||||||
|
# Calculate the number of hashes to display
|
||||||
|
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
|
||||||
|
spaces = ' ' * (bar_length - len(arrow))
|
||||||
|
|
||||||
|
# Create the progress bar string
|
||||||
|
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
|
||||||
|
return progress_bar
|
||||||
|
|
||||||
|
async def stream_image(_request_id: str, result, is_finished: bool):
|
||||||
|
if isinstance(result, list):
|
||||||
|
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
|
||||||
|
|
||||||
|
elif isinstance(result, np.ndarray):
|
||||||
|
im = Image.fromarray(np.array(result))
|
||||||
|
images_folder = get_exo_images_dir()
|
||||||
|
# Save the image to a file
|
||||||
|
image_filename = f"{_request_id}.png"
|
||||||
|
image_path = images_folder / image_filename
|
||||||
|
im.save(image_path)
|
||||||
|
image_url = request.app.router['static_images'].url_for(filename=image_filename)
|
||||||
|
base_url = f"{request.scheme}://{request.host}"
|
||||||
|
# Construct the full URL correctly
|
||||||
|
full_image_url = base_url + str(image_url)
|
||||||
|
|
||||||
|
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||||
|
if is_finished:
|
||||||
|
await response.write_eof()
|
||||||
|
|
||||||
|
|
||||||
|
stream_task = None
|
||||||
|
def on_result(_request_id: str, result, is_finished: bool):
|
||||||
|
nonlocal stream_task
|
||||||
|
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
|
||||||
|
return _request_id == request_id and is_finished
|
||||||
|
|
||||||
|
await callback.wait(on_result, timeout=self.response_timeout*10)
|
||||||
|
|
||||||
|
if stream_task:
|
||||||
|
# Wait for the stream task to complete before returning
|
||||||
|
await stream_task
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if DEBUG >= 2: traceback.print_exc()
|
||||||
|
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||||
|
|
||||||
async def handle_delete_model(self, request):
|
async def handle_delete_model(self, request):
|
||||||
try:
|
try:
|
||||||
model_name = request.match_info.get('model_name')
|
model_name = request.match_info.get('model_name')
|
||||||
@@ -553,7 +657,7 @@ class ChatGPTAPI:
|
|||||||
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
||||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||||
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
|
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||||
|
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@@ -585,3 +689,19 @@ class ChatGPTAPI:
|
|||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(runner, host, port)
|
site = web.TCPSite(runner, host, port)
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
|
def base64_decode(self, base64_string):
|
||||||
|
#decode and reshape image
|
||||||
|
if base64_string.startswith('data:image'):
|
||||||
|
base64_string = base64_string.split(',')[1]
|
||||||
|
image_data = base64.b64decode(base64_string)
|
||||||
|
img = Image.open(BytesIO(image_data))
|
||||||
|
W, H = (dim - dim % 64 for dim in (img.width, img.height))
|
||||||
|
if W != img.width or H != img.height:
|
||||||
|
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||||
|
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||||
|
img = mx.array(np.array(img))
|
||||||
|
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
||||||
|
img = img[None]
|
||||||
|
return img
|
||||||
|
|
||||||
|
|||||||
@@ -303,6 +303,10 @@ async def download_repo_files(
|
|||||||
await f.write(json.dumps(file_list))
|
await f.write(json.dumps(file_list))
|
||||||
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
||||||
|
|
||||||
|
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
|
||||||
|
if model_index_exists:
|
||||||
|
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
|
||||||
|
|
||||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
||||||
total_files = len(filtered_file_list)
|
total_files = len(filtered_file_list)
|
||||||
total_bytes = sum(file["size"] for file in filtered_file_list)
|
total_bytes = sum(file["size"] for file in filtered_file_list)
|
||||||
|
|||||||
@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
|
|||||||
print(f"No snapshot directory found for {self.current_repo_id}")
|
print(f"No snapshot directory found for {self.current_repo_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if not await aios.path.exists(snapshot_dir/"model_index.json"):
|
||||||
# Get the weight map to know what files we need
|
# Get the weight map to know what files we need
|
||||||
weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
||||||
if not weight_map:
|
if not weight_map:
|
||||||
if DEBUG >= 2:
|
if DEBUG >= 2:
|
||||||
print(f"No weight map found for {self.current_repo_id}")
|
print(f"No weight map found for {self.current_repo_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Get all files needed for this shard
|
||||||
|
patterns = get_allow_patterns(weight_map, self.current_shard)
|
||||||
|
else:
|
||||||
|
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
|
||||||
|
|
||||||
# Get all files needed for this shard
|
|
||||||
patterns = get_allow_patterns(weight_map, self.current_shard)
|
|
||||||
|
|
||||||
# Check download status for all relevant files
|
# Check download status for all relevant files
|
||||||
status = {}
|
status = {}
|
||||||
|
|||||||
@@ -351,3 +351,20 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
|
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
|
||||||
return "Unknown Model", "Unknown Chip", 0
|
return "Unknown Model", "Unknown Chip", 0
|
||||||
|
|
||||||
|
def get_exo_home() -> Path:
|
||||||
|
if os.name == "nt": # Check if the OS is Windows
|
||||||
|
docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
|
||||||
|
else:
|
||||||
|
docs_folder = Path.home() / "Documents"
|
||||||
|
exo_folder = docs_folder / "Exo"
|
||||||
|
if not exo_folder.exists():
|
||||||
|
exo_folder.mkdir()
|
||||||
|
return exo_folder
|
||||||
|
|
||||||
|
def get_exo_images_dir() -> Path:
|
||||||
|
exo_home = get_exo_home()
|
||||||
|
images_dir = exo_home / "Images"
|
||||||
|
if not images_dir.exists():
|
||||||
|
images_dir.mkdir()
|
||||||
|
return images_dir
|
||||||
|
|||||||
@@ -39,11 +39,15 @@ class InferenceEngine(ABC):
|
|||||||
async def clear_session(self):
|
async def clear_session(self):
|
||||||
self.session.empty()
|
self.session.empty()
|
||||||
|
|
||||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
|
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
|
||||||
tokens = await self.encode(shard, prompt)
|
tokens = await self.encode(shard, prompt)
|
||||||
x = tokens.reshape(1, -1)
|
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||||
output_data = await self.infer_tensor(request_id, shard, x)
|
x = tokens.reshape(1, -1)
|
||||||
return output_data
|
else:
|
||||||
|
x = tokens
|
||||||
|
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||||
|
|
||||||
|
return output_data, inference_state
|
||||||
|
|
||||||
inference_engine_classes = {
|
inference_engine_classes = {
|
||||||
"mlx": "MLXDynamicShardInferenceEngine",
|
"mlx": "MLXDynamicShardInferenceEngine",
|
||||||
|
|||||||
307
exo/inference/mlx/models/StableDiffusionPipeline.py
Normal file
307
exo/inference/mlx/models/StableDiffusionPipeline.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .sd_models.vae import ModelArgs as VAEArgs
|
||||||
|
from .sd_models.vae import Autoencoder
|
||||||
|
from .sd_models.tokenizer import load_tokenizer
|
||||||
|
from .sd_models.clip import CLIPTextModel
|
||||||
|
from .sd_models.clip import ModelArgs as CLIPArgs
|
||||||
|
from .sd_models.unet import UNetConfig, UNetModel
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from exo.inference.shard import Shard
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiffusionConfig:
|
||||||
|
beta_schedule: str = "scaled_linear"
|
||||||
|
beta_start: float = 0.00085
|
||||||
|
beta_end: float = 0.012
|
||||||
|
num_train_steps: int = 1000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||||
|
|
||||||
|
|
||||||
|
#Sampler
|
||||||
|
def _linspace(a, b, num):
|
||||||
|
x = mx.arange(0, num) / (num - 1)
|
||||||
|
return (b - a) * x + a
|
||||||
|
|
||||||
|
|
||||||
|
def _interp(y, x_new):
|
||||||
|
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||||
|
x_low = x_new.astype(mx.int32)
|
||||||
|
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||||
|
|
||||||
|
y_low = y[x_low]
|
||||||
|
y_high = y[x_high]
|
||||||
|
delta_x = x_new - x_low
|
||||||
|
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||||
|
|
||||||
|
return y_new
|
||||||
|
|
||||||
|
class SimpleEulerSampler:
|
||||||
|
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||||
|
|
||||||
|
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: DiffusionConfig):
|
||||||
|
# Compute the noise schedule
|
||||||
|
if config.beta_schedule == "linear":
|
||||||
|
betas = _linspace(
|
||||||
|
config.beta_start, config.beta_end, config.num_train_steps
|
||||||
|
)
|
||||||
|
elif config.beta_schedule == "scaled_linear":
|
||||||
|
betas = _linspace(
|
||||||
|
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||||
|
).square()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||||
|
|
||||||
|
alphas = 1 - betas
|
||||||
|
alphas_cumprod = mx.cumprod(alphas)
|
||||||
|
|
||||||
|
self._sigmas = mx.concatenate(
|
||||||
|
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_time(self):
|
||||||
|
return len(self._sigmas) - 1
|
||||||
|
|
||||||
|
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||||
|
noise = mx.random.normal(shape, key=key)
|
||||||
|
return (
|
||||||
|
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||||
|
).astype(dtype)
|
||||||
|
|
||||||
|
def add_noise(self, x, t, key=None):
|
||||||
|
noise = mx.random.normal(x.shape, key=key)
|
||||||
|
s = self.sigmas(t)
|
||||||
|
return (x + noise * s) * (s.square() + 1).rsqrt()
|
||||||
|
|
||||||
|
def sigmas(self, t):
|
||||||
|
return _interp(self._sigmas, t)
|
||||||
|
|
||||||
|
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
|
||||||
|
start_time = start_time or (len(self._sigmas) - 1)
|
||||||
|
assert 0 < start_time <= (len(self._sigmas) - 1)
|
||||||
|
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
|
||||||
|
return list(zip(steps, steps[1:]))
|
||||||
|
|
||||||
|
def current_timestep(self, step, total_steps, start_time=None):
|
||||||
|
if step < total_steps:
|
||||||
|
steps = self.timesteps(total_steps, start_time)
|
||||||
|
return steps[step]
|
||||||
|
else:
|
||||||
|
return mx.array(0),mx.array(0)
|
||||||
|
|
||||||
|
def step(self, eps_pred, x_t, t, t_prev):
|
||||||
|
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||||
|
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||||
|
|
||||||
|
dt = sigma_prev - sigma
|
||||||
|
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||||
|
|
||||||
|
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||||
|
|
||||||
|
return x_t_prev
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShardConfig:
|
||||||
|
model_id:str
|
||||||
|
start_layer:int
|
||||||
|
end_layer:int
|
||||||
|
n_layers:int
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StableDiffusionConfig:
|
||||||
|
model_type:str
|
||||||
|
vae:VAEArgs
|
||||||
|
text_encoder:CLIPArgs
|
||||||
|
scheduler:DiffusionConfig
|
||||||
|
unet:UNetConfig
|
||||||
|
shard:ShardConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(StableDiffusionConfig):
|
||||||
|
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if isinstance(self.shard, dict):
|
||||||
|
self.shard = Shard(**self.shard)
|
||||||
|
|
||||||
|
if not isinstance(self.shard, Shard):
|
||||||
|
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.model_type = config.model_type
|
||||||
|
self.config = config
|
||||||
|
self.model_path = config.vae['path'].split('/vae')[0]
|
||||||
|
self.shard = config.shard
|
||||||
|
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
|
||||||
|
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
|
||||||
|
if self.shard_clip.start_layer != -1:
|
||||||
|
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
|
||||||
|
else:
|
||||||
|
self.text_encoder = nn.Identity()
|
||||||
|
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
|
||||||
|
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
|
||||||
|
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||||
|
if self.shard_unet.start_layer!=-1:
|
||||||
|
self.config_unet = UNetConfig.from_dict(config.unet['config'])
|
||||||
|
self.unet = UNetModel(self.config_unet, self.shard_unet)
|
||||||
|
else:
|
||||||
|
self.unet = nn.Identity()
|
||||||
|
self.config_vae=VAEArgs.from_dict(config.vae['config'])
|
||||||
|
if self.shard_encoder.start_layer != -1:
|
||||||
|
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Identity()
|
||||||
|
if self.shard_decoder.start_layer != -1:
|
||||||
|
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
|
||||||
|
else:
|
||||||
|
self.decoder = nn.Identity()
|
||||||
|
|
||||||
|
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
|
||||||
|
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||||
|
is_finished = False
|
||||||
|
is_step_finished = False
|
||||||
|
if t.item()==1000:
|
||||||
|
if self.shard_clip.start_layer == 0:
|
||||||
|
conditioning = x
|
||||||
|
if self.shard_clip.start_layer != -1:
|
||||||
|
conditioning, mask= self.text_encoder(conditioning,mask)
|
||||||
|
seed = int(time.time())
|
||||||
|
mx.random.seed(seed)
|
||||||
|
if image is None:
|
||||||
|
if self.shard_encoder.is_last_layer():
|
||||||
|
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
|
||||||
|
x_t_prev=x
|
||||||
|
start_step = self.sampler.max_time
|
||||||
|
else:
|
||||||
|
if self.shard_encoder.start_layer != -1:
|
||||||
|
image= self.encoder.encode(image)
|
||||||
|
if self.shard_encoder.is_last_layer():
|
||||||
|
start_step = self.sampler.max_time*strength
|
||||||
|
total_steps = int(total_steps*strength)
|
||||||
|
image = mx.broadcast_to(image, (1,) + image.shape[1:])
|
||||||
|
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
|
||||||
|
image = None
|
||||||
|
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||||
|
# Perform the denoising loop
|
||||||
|
if self.shard_unet.start_layer != -1:
|
||||||
|
with tqdm(total=total_steps,initial=step+1) as pbar:
|
||||||
|
if step<total_steps:
|
||||||
|
x = x_t_prev
|
||||||
|
if self.shard_unet.is_first_layer():
|
||||||
|
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
|
||||||
|
else:
|
||||||
|
x_t_unet = x
|
||||||
|
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||||
|
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
|
||||||
|
if self.shard_unet.is_last_layer():
|
||||||
|
if cfg_weight > 1:
|
||||||
|
eps_text, eps_neg = x.split(2)
|
||||||
|
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||||
|
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
|
||||||
|
x_t_prev=x
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
if self.shard_decoder.is_last_layer():
|
||||||
|
is_step_finished=True
|
||||||
|
if self.shard_decoder.start_layer != -1:
|
||||||
|
x=self.decoder.decode(x)
|
||||||
|
if self.shard_decoder.is_last_layer():
|
||||||
|
x = mx.clip(x / 2 + 0.5, 0, 1)
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||||
|
x = x.reshape(1 * H, B // 1 * W, C)
|
||||||
|
x = (x * 255).astype(mx.uint8)
|
||||||
|
if t_prev.item() ==0:
|
||||||
|
is_finished=True
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
|
||||||
|
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
if self.shard_encoder.start_layer != -1:
|
||||||
|
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||||
|
vae_weights = self.encoder.sanitize(vae_weights)
|
||||||
|
self.encoder.load_weights(list(vae_weights.items()), strict=True)
|
||||||
|
if self.shard_decoder.start_layer != -1:
|
||||||
|
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||||
|
vae_weights = self.decoder.sanitize(vae_weights)
|
||||||
|
self.decoder.load_weights(list(vae_weights.items()), strict=True)
|
||||||
|
if self.shard_clip.start_layer != -1:
|
||||||
|
clip_weights = mx.load(self.config_clip.weight_files[0])
|
||||||
|
clip_weights = self.text_encoder.sanitize(clip_weights)
|
||||||
|
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
|
||||||
|
if self.shard_unet.start_layer !=-1:
|
||||||
|
unet_weights = mx.load(self.config_unet.weight_files[0])
|
||||||
|
unet_weights = self.unet.sanitize(unet_weights)
|
||||||
|
self.unet.load_weights(list(unet_weights.items()), strict=True)
|
||||||
|
|
||||||
|
def model_shards(shard:ShardConfig):
|
||||||
|
def create_shard(shard, model_ranges):
|
||||||
|
start_layer = shard.start_layer
|
||||||
|
end_layer = shard.end_layer
|
||||||
|
|
||||||
|
shards = {}
|
||||||
|
|
||||||
|
for model_name, (range_start, range_end) in model_ranges.items():
|
||||||
|
if start_layer < range_end and end_layer >= range_start:
|
||||||
|
# Calculate the overlap with the model range
|
||||||
|
overlap_start = max(start_layer, range_start)
|
||||||
|
overlap_end = min(end_layer, range_end - 1)
|
||||||
|
|
||||||
|
# Adjust the layers relative to the model's range
|
||||||
|
relative_start = overlap_start - range_start
|
||||||
|
relative_end = overlap_end - range_start
|
||||||
|
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
|
||||||
|
else:
|
||||||
|
# If no overlap, create a zero-layer shard
|
||||||
|
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
|
||||||
|
|
||||||
|
return shards
|
||||||
|
|
||||||
|
# Define the ranges for different models
|
||||||
|
model_ranges = {
|
||||||
|
'clip': (0, 12),
|
||||||
|
'vae_encoder':(12,17),
|
||||||
|
'unet':(17,26),
|
||||||
|
'vae_decoder': (26, 31) # Example range for unet
|
||||||
|
}
|
||||||
|
|
||||||
|
# Call the function and get the shards for all models
|
||||||
|
shards = create_shard(shard, model_ranges)
|
||||||
|
|
||||||
|
# Access individual shards
|
||||||
|
shard_clip = shards['clip']
|
||||||
|
shard_encoder = shards['vae_encoder']
|
||||||
|
shard_unet = shards['unet']
|
||||||
|
shard_decoder = shards['vae_decoder']
|
||||||
|
|
||||||
|
return shard_clip, shard_encoder, shard_unet, shard_decoder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
117
exo/inference/mlx/models/phi3.py
Normal file
117
exo/inference/mlx/models/phi3.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_lm.models.base import create_attention_mask
|
||||||
|
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
|
||||||
|
|
||||||
|
from ...shard import Shard
|
||||||
|
from .base import IdentityBlock
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(ModelArgs):
|
||||||
|
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if isinstance(self.shard, Shard):
|
||||||
|
return
|
||||||
|
if not isinstance(self.shard, dict):
|
||||||
|
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||||
|
|
||||||
|
self.shard = Shard(**self.shard)
|
||||||
|
|
||||||
|
class Phi3Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
assert self.vocab_size > 0
|
||||||
|
|
||||||
|
if self.args.shard.is_first_layer():
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
|
||||||
|
self.layers = []
|
||||||
|
for i in range(self.num_hidden_layers):
|
||||||
|
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||||
|
self.layers.append(TransformerBlock(args=args))
|
||||||
|
else:
|
||||||
|
self.layers.append(IdentityBlock())
|
||||||
|
|
||||||
|
if self.args.shard.is_last_layer():
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
if self.args.shard.is_first_layer():
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
else:
|
||||||
|
h = inputs
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if h.shape[1] > 1:
|
||||||
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
if self.args.shard.is_last_layer():
|
||||||
|
h = self.norm(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = Phi3Model(args)
|
||||||
|
if self.args.shard.is_last_layer():
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, cache)
|
||||||
|
if self.args.shard.is_last_layer():
|
||||||
|
out = self.lm_head(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
shard_state_dict = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
if "self_attn.rope.inv_freq" in key:
|
||||||
|
continue
|
||||||
|
if key.startswith('model.layers.'):
|
||||||
|
layer_num = int(key.split('.')[2])
|
||||||
|
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||||
|
shard_state_dict[key] = value
|
||||||
|
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||||
|
shard_state_dict[key] = value
|
||||||
|
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
|
||||||
|
shard_state_dict[key] = value
|
||||||
|
|
||||||
|
return shard_state_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.args.hidden_size // self.args.num_attention_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_kv_heads(self):
|
||||||
|
return self.args.num_key_value_heads
|
||||||
@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
|
|||||||
from ...shard import Shard
|
from ...shard import Shard
|
||||||
from .base import IdentityBlock
|
from .base import IdentityBlock
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs(ModelArgs):
|
class ModelArgs(ModelArgs):
|
||||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__() # Ensure parent initializations are respected
|
super().__post_init__()
|
||||||
|
|
||||||
if isinstance(self.shard, Shard):
|
if isinstance(self.shard, Shard):
|
||||||
return
|
return
|
||||||
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
|
|||||||
|
|
||||||
self.shard = Shard(**self.shard)
|
self.shard = Shard(**self.shard)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2Model(nn.Module):
|
class Qwen2Model(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
|
|||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
self.num_hidden_layers = args.num_hidden_layers
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
assert self.vocab_size > 0
|
assert self.vocab_size > 0
|
||||||
|
|
||||||
if self.args.shard.is_first_layer():
|
if self.args.shard.is_first_layer():
|
||||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
|
||||||
self.layers = []
|
self.layers = []
|
||||||
for i in range(self.num_hidden_layers):
|
for i in range(self.num_hidden_layers):
|
||||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||||
self.layers.append(TransformerBlock(args=args))
|
self.layers.append(TransformerBlock(args=args))
|
||||||
else:
|
else:
|
||||||
self.layers.append(IdentityBlock())
|
self.layers.append(IdentityBlock())
|
||||||
|
|
||||||
if self.args.shard.is_last_layer():
|
if self.args.shard.is_last_layer():
|
||||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
|||||||
191
exo/inference/mlx/models/sd_models/clip.py
Normal file
191
exo/inference/mlx/models/sd_models/clip.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from dataclasses import field, dataclass
|
||||||
|
from exo.inference.shard import Shard
|
||||||
|
from exo.inference.mlx.models.base import IdentityBlock
|
||||||
|
|
||||||
|
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPTextModelConfig:
|
||||||
|
num_layers: int = 23
|
||||||
|
model_dims: int = 1024
|
||||||
|
num_heads: int = 16
|
||||||
|
max_length: int = 77
|
||||||
|
vocab_size: int = 49408
|
||||||
|
projection_dim: Optional[int] = None
|
||||||
|
hidden_act: str = "quick_gelu"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config):
|
||||||
|
return ModelArgs(
|
||||||
|
num_layers=config["num_hidden_layers"],
|
||||||
|
model_dims=config["hidden_size"],
|
||||||
|
num_heads=config["num_attention_heads"],
|
||||||
|
max_length=config["max_position_embeddings"],
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
|
||||||
|
hidden_act=config.get("hidden_act", "quick_gelu"),
|
||||||
|
weight_files=config.get("weight_files", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(CLIPTextModelConfig):
|
||||||
|
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
weight_files: List[str] = field(default_factory=lambda: [])
|
||||||
|
def __post_init__(self):
|
||||||
|
if isinstance(self.shard, dict):
|
||||||
|
self.shard = Shard(**self.shard)
|
||||||
|
|
||||||
|
if not isinstance(self.shard, Shard):
|
||||||
|
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||||
|
|
||||||
|
if not self.shard.is_first_layer():
|
||||||
|
self.vision_config = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPOutput:
|
||||||
|
pooled_output: Optional[mx.array] = None
|
||||||
|
last_hidden_state: Optional[mx.array] = None
|
||||||
|
hidden_states: Optional[List[mx.array]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(nn.Module):
|
||||||
|
"""The transformer encoder layer from CLIP."""
|
||||||
|
|
||||||
|
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||||
|
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||||
|
|
||||||
|
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||||
|
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||||
|
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||||
|
|
||||||
|
self.act = _ACTIVATIONS[activation]
|
||||||
|
|
||||||
|
def __call__(self, x, attn_mask=None):
|
||||||
|
|
||||||
|
y = self.layer_norm1(x)
|
||||||
|
y = self.attention(y, y, y, attn_mask)
|
||||||
|
x = y + x
|
||||||
|
|
||||||
|
y = self.layer_norm2(x)
|
||||||
|
y = self.linear1(y)
|
||||||
|
y = self.act(y)
|
||||||
|
y = self.linear2(y)
|
||||||
|
x = y + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextModel(nn.Module):
|
||||||
|
"""Implements the text encoder transformer from CLIP."""
|
||||||
|
|
||||||
|
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shard = shard
|
||||||
|
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||||
|
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||||
|
self.layers = []
|
||||||
|
for i in range(math.ceil(config.num_layers/2)):
|
||||||
|
if 2*i in self.layers_range:
|
||||||
|
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||||
|
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
|
||||||
|
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||||
|
else:
|
||||||
|
self.layers.append(IdentityBlock())
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||||
|
|
||||||
|
if config.projection_dim is not None:
|
||||||
|
self.text_projection = nn.Linear(
|
||||||
|
config.model_dims, config.projection_dim, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_mask(self, N, dtype):
|
||||||
|
indices = mx.arange(N)
|
||||||
|
mask = indices[:, None] < indices[None]
|
||||||
|
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None):
|
||||||
|
# Extract some shapes
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
B, N = x.shape
|
||||||
|
eos_tokens = x.argmax(-1)
|
||||||
|
|
||||||
|
# Compute the embeddings
|
||||||
|
x = self.token_embedding(x)
|
||||||
|
|
||||||
|
x = x + self.position_embedding.weight[:N]
|
||||||
|
# Compute the features from the transformer
|
||||||
|
mask = self._get_mask(N, x.dtype)
|
||||||
|
|
||||||
|
for l in self.layers:
|
||||||
|
x = l(x, mask)
|
||||||
|
# Apply the final layernorm and return
|
||||||
|
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return x, mask
|
||||||
|
def sanitize(self, weights):
|
||||||
|
sanitized_weights = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
if "position_ids" in key:
|
||||||
|
continue
|
||||||
|
if key.startswith("text_model."):
|
||||||
|
key = key[11:]
|
||||||
|
if key.startswith("embeddings."):
|
||||||
|
key = key[11:]
|
||||||
|
if key.startswith("encoder."):
|
||||||
|
key = key[8:]
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "self_attn." in key:
|
||||||
|
key = key.replace("self_attn.", "attention.")
|
||||||
|
if "q_proj." in key:
|
||||||
|
key = key.replace("q_proj.", "query_proj.")
|
||||||
|
if "k_proj." in key:
|
||||||
|
key = key.replace("k_proj.", "key_proj.")
|
||||||
|
if "v_proj." in key:
|
||||||
|
key = key.replace("v_proj.", "value_proj.")
|
||||||
|
|
||||||
|
# Map ffn layers
|
||||||
|
if "mlp.fc1" in key:
|
||||||
|
key = key.replace("mlp.fc1", "linear1")
|
||||||
|
if "mlp.fc2" in key:
|
||||||
|
key = key.replace("mlp.fc2", "linear2")
|
||||||
|
|
||||||
|
if key.startswith("layers."):
|
||||||
|
layer_num = int(key.split(".")[1])
|
||||||
|
if layer_num not in self.layers_range:
|
||||||
|
continue
|
||||||
|
if not self.shard.is_first_layer() and "embedding" in key:
|
||||||
|
continue
|
||||||
|
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
|
||||||
|
continue
|
||||||
|
if not self.shard.is_last_layer() and key.startswith("text_projection"):
|
||||||
|
continue
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
return sanitized_weights
|
||||||
131
exo/inference/mlx/models/sd_models/tokenizer.py
Normal file
131
exo/inference/mlx/models/sd_models/tokenizer.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
|
||||||
|
|
||||||
|
import regex
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||||
|
|
||||||
|
def __init__(self, bpe_ranks, vocab):
|
||||||
|
self.bpe_ranks = bpe_ranks
|
||||||
|
self.vocab = vocab
|
||||||
|
self.pat = regex.compile(
|
||||||
|
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||||
|
regex.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos(self):
|
||||||
|
return "<|startoftext|>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_token(self):
|
||||||
|
return self.vocab[self.bos]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos(self):
|
||||||
|
return "<|endoftext|>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token(self):
|
||||||
|
return self.vocab[self.eos]
|
||||||
|
|
||||||
|
def bpe(self, text):
|
||||||
|
if text in self._cache:
|
||||||
|
return self._cache[text]
|
||||||
|
|
||||||
|
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||||
|
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||||
|
|
||||||
|
if not unique_bigrams:
|
||||||
|
return unigrams
|
||||||
|
|
||||||
|
# In every iteration try to merge the two most likely bigrams. If none
|
||||||
|
# was merged we are done.
|
||||||
|
#
|
||||||
|
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||||
|
while unique_bigrams:
|
||||||
|
bigram = min(
|
||||||
|
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||||
|
)
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
|
||||||
|
new_unigrams = []
|
||||||
|
skip = False
|
||||||
|
for a, b in zip(unigrams, unigrams[1:]):
|
||||||
|
if skip:
|
||||||
|
skip = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (a, b) == bigram:
|
||||||
|
new_unigrams.append(a + b)
|
||||||
|
skip = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_unigrams.append(a)
|
||||||
|
|
||||||
|
if not skip:
|
||||||
|
new_unigrams.append(b)
|
||||||
|
|
||||||
|
unigrams = new_unigrams
|
||||||
|
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||||
|
|
||||||
|
self._cache[text] = unigrams
|
||||||
|
|
||||||
|
return unigrams
|
||||||
|
|
||||||
|
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||||
|
if isinstance(text, list):
|
||||||
|
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||||
|
|
||||||
|
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||||
|
# a much more thorough job here but this should suffice for 95% of
|
||||||
|
# cases.
|
||||||
|
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||||
|
tokens = regex.findall(self.pat, clean_text)
|
||||||
|
|
||||||
|
# Split the tokens according to the byte-pair merge file
|
||||||
|
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||||
|
|
||||||
|
# Map to token ids and return
|
||||||
|
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||||
|
if prepend_bos:
|
||||||
|
tokens = [self.bos_token] + tokens
|
||||||
|
if append_eos:
|
||||||
|
tokens.append(self.eos_token)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def encode(self, prompt):
|
||||||
|
tokens = [self.tokenize(prompt)]
|
||||||
|
negative_text = ""
|
||||||
|
if negative_text is not None:
|
||||||
|
tokens += [self.tokenize(negative_text)]
|
||||||
|
lengths = [len(t) for t in tokens]
|
||||||
|
N = max(lengths)
|
||||||
|
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def load_tokenizer(
|
||||||
|
model_path: str,
|
||||||
|
vocab_key: str = "tokenizer_vocab",
|
||||||
|
merges_key: str = "tokenizer_merges",
|
||||||
|
):
|
||||||
|
|
||||||
|
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
|
||||||
|
with open(vocab_file, encoding="utf-8") as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
|
||||||
|
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
|
||||||
|
with open(merges_file, encoding="utf-8") as f:
|
||||||
|
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||||
|
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||||
|
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||||
|
|
||||||
|
return Tokenizer(bpe_ranks, vocab)
|
||||||
|
|
||||||
629
exo/inference/mlx/models/sd_models/unet.py
Normal file
629
exo/inference/mlx/models/sd_models/unet.py
Normal file
@@ -0,0 +1,629 @@
|
|||||||
|
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Tuple, Optional, List
|
||||||
|
from exo.inference.shard import Shard
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UNetConfig:
|
||||||
|
in_channels: int = 4
|
||||||
|
out_channels: int = 4
|
||||||
|
conv_in_kernel: int = 3
|
||||||
|
conv_out_kernel: int = 3
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||||
|
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||||
|
mid_block_layers: int = 2
|
||||||
|
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||||
|
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||||
|
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||||
|
norm_num_groups: int = 32
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
)
|
||||||
|
up_block_types: Tuple[str] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
)
|
||||||
|
addition_embed_type: Optional[str] = None
|
||||||
|
addition_time_embed_dim: Optional[int] = None
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None
|
||||||
|
weight_files: List[str] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls,config):
|
||||||
|
n_blocks = len(config['block_out_channels'])
|
||||||
|
return UNetConfig(
|
||||||
|
in_channels=config["in_channels"],
|
||||||
|
out_channels=config["out_channels"],
|
||||||
|
block_out_channels=config["block_out_channels"],
|
||||||
|
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||||
|
transformer_layers_per_block=config.get(
|
||||||
|
"transformer_layers_per_block", (1,) * 4
|
||||||
|
),
|
||||||
|
num_attention_heads=(
|
||||||
|
[config["attention_head_dim"]] * n_blocks
|
||||||
|
if isinstance(config["attention_head_dim"], int)
|
||||||
|
else config["attention_head_dim"]
|
||||||
|
),
|
||||||
|
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||||
|
norm_num_groups=config["norm_num_groups"],
|
||||||
|
down_block_types=config["down_block_types"],
|
||||||
|
up_block_types=config["up_block_types"][::-1],
|
||||||
|
addition_embed_type=config.get("addition_embed_type", None),
|
||||||
|
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
|
||||||
|
projection_class_embeddings_input_dim=config.get(
|
||||||
|
"projection_class_embeddings_input_dim", None
|
||||||
|
),
|
||||||
|
weight_files=config.get("weight_files", [])
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_nearest(x, scale: int = 2):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||||
|
x = x.reshape(B, H * scale, W * scale, C)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.linear_1(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.linear_2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
hidden_dims: Optional[int] = None,
|
||||||
|
memory_dims: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(model_dims)
|
||||||
|
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||||
|
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
memory_dims = memory_dims or model_dims
|
||||||
|
self.norm2 = nn.LayerNorm(model_dims)
|
||||||
|
self.attn2 = nn.MultiHeadAttention(
|
||||||
|
model_dims, num_heads, key_input_dims=memory_dims
|
||||||
|
)
|
||||||
|
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
hidden_dims = hidden_dims or 4 * model_dims
|
||||||
|
self.norm3 = nn.LayerNorm(model_dims)
|
||||||
|
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||||
|
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||||
|
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||||
|
|
||||||
|
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||||
|
# Self attention
|
||||||
|
y = self.norm1(x)
|
||||||
|
y = self.attn1(y, y, y, attn_mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
# Cross attention
|
||||||
|
y = self.norm2(x)
|
||||||
|
y = self.attn2(y, memory, memory, memory_mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
y = self.norm3(x)
|
||||||
|
y_a = self.linear1(y)
|
||||||
|
y_b = self.linear2(y)
|
||||||
|
y = y_a * nn.gelu(y_b)
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2D(nn.Module):
|
||||||
|
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
model_dims: int,
|
||||||
|
encoder_dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
norm_num_groups: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||||
|
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||||
|
self.transformer_blocks = [
|
||||||
|
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||||
|
|
||||||
|
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||||
|
# Save the input to add to the output
|
||||||
|
input_x = x
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
# Perform the input norm and projection
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
# Apply the transformer
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
|
||||||
|
# Apply the output projection and reshape
|
||||||
|
x = self.proj_out(x)
|
||||||
|
x = x.reshape(B, H, W, C)
|
||||||
|
|
||||||
|
return x + input_x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
groups: int = 32,
|
||||||
|
temb_channels: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_channels = out_channels or in_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if temb_channels is not None:
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||||
|
|
||||||
|
def __call__(self, x, temb=None):
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
temb = self.time_emb_proj(nn.silu(temb))
|
||||||
|
y = self.norm1(x.astype(mx.float32)).astype(dtype)
|
||||||
|
|
||||||
|
y = nn.silu(y)
|
||||||
|
|
||||||
|
y = self.conv1(y)
|
||||||
|
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
y = y + temb[:, None, None, :]
|
||||||
|
y = self.norm2(y.astype(mx.float32)).astype(dtype)
|
||||||
|
y = nn.silu(y)
|
||||||
|
y = self.conv2(y)
|
||||||
|
|
||||||
|
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UNetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
temb_channels: int,
|
||||||
|
prev_out_channels: Optional[int] = None,
|
||||||
|
num_layers: int = 1,
|
||||||
|
transformer_layers_per_block: int = 1,
|
||||||
|
num_attention_heads: int = 8,
|
||||||
|
cross_attention_dim=1280,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
add_downsample=True,
|
||||||
|
add_upsample=True,
|
||||||
|
add_cross_attention=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Prepare the in channels list for the resnets
|
||||||
|
if prev_out_channels is None:
|
||||||
|
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||||
|
else:
|
||||||
|
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||||
|
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||||
|
in_channels_list = [
|
||||||
|
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add resnet blocks that also process the time embedding
|
||||||
|
self.resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=ic,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
groups=resnet_groups,
|
||||||
|
)
|
||||||
|
for ic in in_channels_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add optional cross attention layers
|
||||||
|
if add_cross_attention:
|
||||||
|
self.attentions = [
|
||||||
|
Transformer2D(
|
||||||
|
in_channels=out_channels,
|
||||||
|
model_dims=out_channels,
|
||||||
|
num_heads=num_attention_heads,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
encoder_dims=cross_attention_dim,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add an optional downsampling layer
|
||||||
|
if add_downsample:
|
||||||
|
self.downsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# or upsampling layer
|
||||||
|
if add_upsample:
|
||||||
|
self.upsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
encoder_x=None,
|
||||||
|
temb=None,
|
||||||
|
attn_mask=None,
|
||||||
|
encoder_attn_mask=None,
|
||||||
|
residual_hidden_states=None,
|
||||||
|
):
|
||||||
|
output_states = []
|
||||||
|
|
||||||
|
for i in range(len(self.resnets)):
|
||||||
|
if residual_hidden_states is not None:
|
||||||
|
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||||
|
|
||||||
|
x = self.resnets[i](x, temb)
|
||||||
|
|
||||||
|
if "attentions" in self:
|
||||||
|
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
if "downsample" in self:
|
||||||
|
x = self.downsample(x)
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
if "upsample" in self:
|
||||||
|
x = self.upsample(upsample_nearest(x))
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
return x, output_states
|
||||||
|
|
||||||
|
|
||||||
|
class UNetModel(nn.Module):
|
||||||
|
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||||
|
|
||||||
|
def __init__(self, config: UNetConfig, shard: Shard):
|
||||||
|
super().__init__()
|
||||||
|
self.shard = shard
|
||||||
|
self.start_layer = shard.start_layer
|
||||||
|
self.end_layer = shard.end_layer
|
||||||
|
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||||
|
if shard.is_first_layer():
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
config.in_channels,
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.conv_in_kernel,
|
||||||
|
padding=(config.conv_in_kernel - 1) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
max_freq=1,
|
||||||
|
min_freq=math.exp(
|
||||||
|
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||||
|
),
|
||||||
|
scale=1.0,
|
||||||
|
cos_first=True,
|
||||||
|
full_turns=False,
|
||||||
|
)
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.block_out_channels[0] * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.addition_embed_type == "text_time":
|
||||||
|
self.add_time_proj = nn.SinusoidalPositionalEncoding(
|
||||||
|
config.addition_time_embed_dim,
|
||||||
|
max_freq=1,
|
||||||
|
min_freq=math.exp(
|
||||||
|
-math.log(10000)
|
||||||
|
+ 2 * math.log(10000) / config.addition_time_embed_dim
|
||||||
|
),
|
||||||
|
scale=1.0,
|
||||||
|
cos_first=True,
|
||||||
|
full_turns=False,
|
||||||
|
)
|
||||||
|
self.add_embedding = TimestepEmbedding(
|
||||||
|
config.projection_class_embeddings_input_dim,
|
||||||
|
config.block_out_channels[0] * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make the downsampling blocks
|
||||||
|
block_channels = [config.block_out_channels[0]] + list(
|
||||||
|
config.block_out_channels
|
||||||
|
)
|
||||||
|
self.down_blocks = []
|
||||||
|
|
||||||
|
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
|
||||||
|
if i in self.layers_range:
|
||||||
|
self.down_blocks.append(
|
||||||
|
UNetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
num_layers=config.layers_per_block[i],
|
||||||
|
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||||
|
num_attention_heads=config.num_attention_heads[i],
|
||||||
|
cross_attention_dim=config.cross_attention_dim[i],
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||||
|
add_upsample=False,
|
||||||
|
add_cross_attention="CrossAttn" in config.down_block_types[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.down_blocks.append(nn.Identity())
|
||||||
|
|
||||||
|
|
||||||
|
# Make the middle block
|
||||||
|
if 4 in self.layers_range:
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
out_channels=config.block_out_channels[-1],
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
groups=config.norm_num_groups,
|
||||||
|
),
|
||||||
|
Transformer2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
model_dims=config.block_out_channels[-1],
|
||||||
|
num_heads=config.num_attention_heads[-1],
|
||||||
|
num_layers=config.transformer_layers_per_block[-1],
|
||||||
|
encoder_dims=config.cross_attention_dim[-1],
|
||||||
|
),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
out_channels=config.block_out_channels[-1],
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
groups=config.norm_num_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the upsampling blocks
|
||||||
|
block_channels = (
|
||||||
|
[config.block_out_channels[0]]
|
||||||
|
+ list(config.block_out_channels)
|
||||||
|
+ [config.block_out_channels[-1]]
|
||||||
|
)
|
||||||
|
|
||||||
|
total_items = len(block_channels) - 3
|
||||||
|
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
|
||||||
|
|
||||||
|
self.up_blocks = []
|
||||||
|
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
|
||||||
|
i = total_items - rev_i
|
||||||
|
if rev_i+5 in self.layers_range:
|
||||||
|
self.up_blocks.append(
|
||||||
|
UNetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
prev_out_channels=prev_out_channels,
|
||||||
|
num_layers=config.layers_per_block[i] + 1,
|
||||||
|
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||||
|
num_attention_heads=config.num_attention_heads[i],
|
||||||
|
cross_attention_dim=config.cross_attention_dim[i],
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
add_downsample=False,
|
||||||
|
add_upsample=(i > 0),
|
||||||
|
add_cross_attention="CrossAttn" in config.up_block_types[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.up_blocks.append(nn.Identity())
|
||||||
|
|
||||||
|
|
||||||
|
if shard.is_last_layer():
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
config.norm_num_groups,
|
||||||
|
config.block_out_channels[0],
|
||||||
|
pytorch_compatible=True,
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.out_channels,
|
||||||
|
config.conv_out_kernel,
|
||||||
|
padding=(config.conv_out_kernel - 1) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
encoder_x,
|
||||||
|
attn_mask=None,
|
||||||
|
encoder_attn_mask=None,
|
||||||
|
text_time=None,
|
||||||
|
residuals=None,
|
||||||
|
):
|
||||||
|
# Compute the time embeddings
|
||||||
|
|
||||||
|
temb = self.timesteps(timestep).astype(x.dtype)
|
||||||
|
temb = self.time_embedding(temb)
|
||||||
|
|
||||||
|
# Add the extra text_time conditioning
|
||||||
|
if text_time is not None:
|
||||||
|
text_emb, time_ids = text_time
|
||||||
|
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
|
||||||
|
emb = mx.concatenate([text_emb, emb], axis=-1)
|
||||||
|
emb = self.add_embedding(emb)
|
||||||
|
temb = temb + emb
|
||||||
|
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
# Preprocess the input
|
||||||
|
x = self.conv_in(x)
|
||||||
|
residuals = [x]
|
||||||
|
# Run the downsampling part of the unet
|
||||||
|
|
||||||
|
for i in range(len(self.down_blocks)):
|
||||||
|
if i in self.layers_range:
|
||||||
|
x, res = self.down_blocks[i](
|
||||||
|
x,
|
||||||
|
encoder_x=encoder_x,
|
||||||
|
temb=temb,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
encoder_attn_mask=encoder_attn_mask,
|
||||||
|
)
|
||||||
|
residuals.extend(res)
|
||||||
|
else:
|
||||||
|
x= self.down_blocks[i](x)
|
||||||
|
|
||||||
|
if 4 in self.layers_range:
|
||||||
|
# Run the middle part of the unet
|
||||||
|
x = self.mid_blocks[0](x, temb)
|
||||||
|
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
x = self.mid_blocks[2](x, temb)
|
||||||
|
|
||||||
|
# Run the upsampling part of the unet
|
||||||
|
for i in range(len(self.up_blocks)):
|
||||||
|
if i+5 in self.layers_range:
|
||||||
|
x, _ = self.up_blocks[i](
|
||||||
|
x,
|
||||||
|
encoder_x=encoder_x,
|
||||||
|
temb=temb,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
encoder_attn_mask=encoder_attn_mask,
|
||||||
|
residual_hidden_states=residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x= self.up_blocks[i](x)
|
||||||
|
|
||||||
|
# Postprocess the output
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
dtype = x.dtype
|
||||||
|
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
return x, residuals
|
||||||
|
def sanitize(self, weights):
|
||||||
|
sanitized_weights = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
k1=""
|
||||||
|
k2=""
|
||||||
|
if "downsamplers" in key:
|
||||||
|
key = key.replace("downsamplers.0.conv", "downsample")
|
||||||
|
if "upsamplers" in key:
|
||||||
|
key = key.replace("upsamplers.0.conv", "upsample")
|
||||||
|
|
||||||
|
# Map the mid block
|
||||||
|
if "mid_block.resnets.0" in key:
|
||||||
|
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||||
|
if "mid_block.attentions.0" in key:
|
||||||
|
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||||
|
if "mid_block.resnets.1" in key:
|
||||||
|
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "to_k" in key:
|
||||||
|
key = key.replace("to_k", "key_proj")
|
||||||
|
if "to_out.0" in key:
|
||||||
|
key = key.replace("to_out.0", "out_proj")
|
||||||
|
if "to_q" in key:
|
||||||
|
key = key.replace("to_q", "query_proj")
|
||||||
|
if "to_v" in key:
|
||||||
|
key = key.replace("to_v", "value_proj")
|
||||||
|
|
||||||
|
# Map transformer ffn
|
||||||
|
if "ff.net.2" in key:
|
||||||
|
key = key.replace("ff.net.2", "linear3")
|
||||||
|
if "ff.net.0" in key:
|
||||||
|
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||||
|
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||||
|
v1, v2 = mx.split(value, 2)
|
||||||
|
|
||||||
|
|
||||||
|
if "conv_shortcut.weight" in key:
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
# Transform the weights from 1x1 convs to linear
|
||||||
|
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
if len(value.shape) == 4:
|
||||||
|
value = value.transpose(0, 2, 3, 1)
|
||||||
|
value = value.reshape(-1).reshape(value.shape)
|
||||||
|
|
||||||
|
if key.startswith("conv_in") :
|
||||||
|
if 0 not in self.layers_range:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.startswith("down_blocks"):
|
||||||
|
layer_num = int(key.split(".")[1])
|
||||||
|
if layer_num not in self.layers_range:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.startswith("mid_block"):
|
||||||
|
if 4 not in self.layers_range:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.startswith("up_blocks"):
|
||||||
|
layer_num = int(key.split(".")[1])
|
||||||
|
if (layer_num+5) not in self.layers_range:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
|
||||||
|
if 8 not in self.layers_range:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(k1)>0:
|
||||||
|
sanitized_weights[k1] = v1
|
||||||
|
sanitized_weights[k2] = v2
|
||||||
|
else:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
return sanitized_weights
|
||||||
429
exo/inference/mlx/models/sd_models/vae.py
Normal file
429
exo/inference/mlx/models/sd_models/vae.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .unet import ResnetBlock2D, upsample_nearest
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from exo.inference.shard import Shard
|
||||||
|
from typing import Tuple
|
||||||
|
import inspect
|
||||||
|
from ..base import IdentityBlock
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoencoderConfig:
|
||||||
|
in_channels: int = 3
|
||||||
|
out_channels: int = 3
|
||||||
|
latent_channels_out: int = 8
|
||||||
|
latent_channels_in: int = 4
|
||||||
|
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||||
|
layers_per_block: int = 2
|
||||||
|
norm_num_groups: int = 32
|
||||||
|
scaling_factor: float = 0.18215
|
||||||
|
weight_files: List[str] = field(default_factory=lambda: [])
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(AutoencoderConfig):
|
||||||
|
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if isinstance(self.shard, dict):
|
||||||
|
self.shard = Shard(**self.shard)
|
||||||
|
|
||||||
|
if not isinstance(self.shard, Shard):
|
||||||
|
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||||
|
|
||||||
|
if not self.shard.is_first_layer():
|
||||||
|
self.vision_config = None
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""A single head unmasked attention for use with the VAE."""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, norm_groups: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||||
|
self.query_proj = nn.Linear(dims, dims)
|
||||||
|
self.key_proj = nn.Linear(dims, dims)
|
||||||
|
self.value_proj = nn.Linear(dims, dims)
|
||||||
|
self.out_proj = nn.Linear(dims, dims)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
|
||||||
|
y = self.group_norm(x)
|
||||||
|
|
||||||
|
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||||
|
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||||
|
values = self.value_proj(y).reshape(B, H * W, C)
|
||||||
|
|
||||||
|
scale = 1 / math.sqrt(queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||||
|
attn = mx.softmax(scores, axis=-1)
|
||||||
|
y = (attn @ values).reshape(B, H, W, C)
|
||||||
|
|
||||||
|
y = self.out_proj(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderDecoderBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
add_downsample=True,
|
||||||
|
add_upsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Add the resnet blocks
|
||||||
|
self.resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels if i == 0 else out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
groups=resnet_groups,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add an optional downsampling layer
|
||||||
|
if add_downsample:
|
||||||
|
self.downsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# or upsampling layer
|
||||||
|
if add_upsample:
|
||||||
|
self.upsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
x = resnet(x)
|
||||||
|
if "downsample" in self:
|
||||||
|
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
if "upsample" in self:
|
||||||
|
x = self.upsample(upsample_nearest(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""Implements the encoder side of the Autoencoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
latent_channels_out: int,
|
||||||
|
block_out_channels: List[int] = [64],
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
layers_range: List[int] = [],
|
||||||
|
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers_range = layers_range
|
||||||
|
self.shard = shard
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||||
|
self.down_blocks = []
|
||||||
|
current_layer = 1
|
||||||
|
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||||
|
if current_layer in self.layers_range:
|
||||||
|
self.down_blocks.append(
|
||||||
|
EncoderDecoderBlock2D(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_groups=resnet_groups,
|
||||||
|
add_downsample=i < len(block_out_channels) - 1,
|
||||||
|
add_upsample=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.down_blocks.append(IdentityBlock())
|
||||||
|
current_layer += 1
|
||||||
|
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
Attention(block_out_channels[-1], resnet_groups),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
x = self.conv_in(x)
|
||||||
|
|
||||||
|
for l in self.down_blocks:
|
||||||
|
x = l(x)
|
||||||
|
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
x = self.mid_blocks[0](x)
|
||||||
|
x = self.mid_blocks[1](x)
|
||||||
|
x = self.mid_blocks[2](x)
|
||||||
|
|
||||||
|
x = self.conv_norm_out(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""Implements the decoder side of the Autoencoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
shard: Shard,
|
||||||
|
layer_range: List[int],
|
||||||
|
block_out_channels: List[int] = [64],
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.layers_range = layer_range
|
||||||
|
if 0 in layer_range:
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if 0 in layer_range:
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
Attention(block_out_channels[-1], resnet_groups),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
channels = list(reversed(block_out_channels))
|
||||||
|
channels = [channels[0]] + channels
|
||||||
|
|
||||||
|
self.up_blocks = []
|
||||||
|
current_layer = 1
|
||||||
|
|
||||||
|
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||||
|
if current_layer in layer_range:
|
||||||
|
self.up_blocks.append(
|
||||||
|
EncoderDecoderBlock2D(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_groups=resnet_groups,
|
||||||
|
add_downsample=False,
|
||||||
|
add_upsample=i < len(block_out_channels) - 1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.up_blocks.append(IdentityBlock())
|
||||||
|
current_layer += 1
|
||||||
|
if 4 in layer_range:
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if 0 in self.layers_range:
|
||||||
|
x = self.conv_in(x)
|
||||||
|
x = self.mid_blocks[0](x)
|
||||||
|
x = self.mid_blocks[1](x)
|
||||||
|
x = self.mid_blocks[2](x)
|
||||||
|
|
||||||
|
for l in self.up_blocks:
|
||||||
|
x = l(x)
|
||||||
|
if 4 in self.layers_range:
|
||||||
|
x = self.conv_norm_out(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Autoencoder(nn.Module):
|
||||||
|
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||||
|
|
||||||
|
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
|
||||||
|
super().__init__()
|
||||||
|
self.shard = shard
|
||||||
|
self.start_layer = shard.start_layer
|
||||||
|
self.end_layer = shard.end_layer
|
||||||
|
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||||
|
self.latent_channels = config.latent_channels_in
|
||||||
|
self.scaling_factor = config.scaling_factor
|
||||||
|
self.model_shard = model_shard
|
||||||
|
if self.model_shard == "vae_encoder":
|
||||||
|
self.encoder = Encoder(
|
||||||
|
config.in_channels,
|
||||||
|
config.latent_channels_out,
|
||||||
|
config.block_out_channels,
|
||||||
|
config.layers_per_block,
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
layers_range=self.layers_range,
|
||||||
|
shard=shard
|
||||||
|
)
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
self.quant_proj = nn.Linear(
|
||||||
|
config.latent_channels_out, config.latent_channels_out
|
||||||
|
)
|
||||||
|
if self.model_shard == "vae_decoder":
|
||||||
|
self.decoder = Decoder(
|
||||||
|
config.latent_channels_in,
|
||||||
|
config.out_channels,
|
||||||
|
shard,
|
||||||
|
self.layers_range,
|
||||||
|
config.block_out_channels,
|
||||||
|
config.layers_per_block + 1,
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
)
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
self.post_quant_proj = nn.Linear(
|
||||||
|
config.latent_channels_in, config.latent_channels_in
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
if self.shard.is_first_layer():
|
||||||
|
z = z / self.scaling_factor
|
||||||
|
z=self.post_quant_proj(z)
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
x = self.encoder(x)
|
||||||
|
if self.shard.is_last_layer():
|
||||||
|
x = self.quant_proj(x)
|
||||||
|
mean, logvar = x.split(2, axis=-1)
|
||||||
|
mean = mean * self.scaling_factor
|
||||||
|
logvar = logvar + 2 * math.log(self.scaling_factor)
|
||||||
|
x = mean
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __call__(self, x, key=None):
|
||||||
|
mean, logvar = self.encode(x)
|
||||||
|
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
|
||||||
|
x_hat = self.decode(z)
|
||||||
|
|
||||||
|
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
shard = self.shard
|
||||||
|
layers = self.layers_range
|
||||||
|
sanitized_weights = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
|
||||||
|
if "downsamplers" in key:
|
||||||
|
key = key.replace("downsamplers.0.conv", "downsample")
|
||||||
|
if "upsamplers" in key:
|
||||||
|
key = key.replace("upsamplers.0.conv", "upsample")
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "key" in key:
|
||||||
|
key = key.replace("key", "key_proj")
|
||||||
|
if "proj_attn" in key:
|
||||||
|
key = key.replace("proj_attn", "out_proj")
|
||||||
|
if "query" in key:
|
||||||
|
key = key.replace("query", "query_proj")
|
||||||
|
if "value" in key:
|
||||||
|
key = key.replace("value", "value_proj")
|
||||||
|
|
||||||
|
# Map the mid block
|
||||||
|
if "mid_block.resnets.0" in key:
|
||||||
|
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||||
|
if "mid_block.attentions.0" in key:
|
||||||
|
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||||
|
if "mid_block.resnets.1" in key:
|
||||||
|
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||||
|
|
||||||
|
# Map the quant/post_quant layers
|
||||||
|
if "quant_conv" in key:
|
||||||
|
key = key.replace("quant_conv", "quant_proj")
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
# Map the conv_shortcut to linear
|
||||||
|
if "conv_shortcut.weight" in key:
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
if len(value.shape) == 4:
|
||||||
|
value = value.transpose(0, 2, 3, 1)
|
||||||
|
value = value.reshape(-1).reshape(value.shape)
|
||||||
|
|
||||||
|
|
||||||
|
if "post_quant_conv" in key :
|
||||||
|
key = key.replace("quant_conv", "quant_proj")
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
if 'decoder' in key and self.model_shard == "vae_decoder":
|
||||||
|
if key.startswith("decoder.mid_blocks."):
|
||||||
|
if 0 in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if "conv_in" in key and 0 in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("decoder.up_blocks."):
|
||||||
|
layer_num = int(key.split(".")[2])+1
|
||||||
|
if layer_num in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("decoder.conv_norm_out") and 4 in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("decoder.conv_out") and 4 in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if self.model_shard == "vae_decoder":
|
||||||
|
if key.startswith("post_quant_proj") and 0 in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if self.model_shard == "vae_encoder":
|
||||||
|
if key.startswith("encoder."):
|
||||||
|
if "conv_in" in key and shard.is_first_layer():
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("encoder.down_blocks."):
|
||||||
|
layer_num = int(key.split(".")[2])+1
|
||||||
|
if layer_num in layers:
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if "conv_norm_out" in key and shard.is_last_layer():
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if "conv_out" in key and shard.is_last_layer():
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
if key.startswith("quant_proj") and shard.is_last_layer():
|
||||||
|
sanitized_weights[key] = value
|
||||||
|
return sanitized_weights
|
||||||
|
|
||||||
@@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from mlx_lm.models.cache import make_prompt_cache
|
from mlx_lm.models.cache import make_prompt_cache
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||||
def __init__(self, shard_downloader: ShardDownloader):
|
def __init__(self, shard_downloader: ShardDownloader):
|
||||||
@@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
self.caches = OrderedDict()
|
self.caches = OrderedDict()
|
||||||
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
|
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
|
||||||
self.sampler = make_sampler(*self.sampler_params)
|
self.sampler = make_sampler(*self.sampler_params)
|
||||||
|
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
||||||
|
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
||||||
|
|
||||||
|
async def _eval_mlx(self, *args):
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
|
||||||
|
|
||||||
async def poll_state(self, request_id: str, max_caches=2):
|
async def poll_state(self, request_id: str, max_caches=2):
|
||||||
if request_id in self.caches:
|
if request_id in self.caches:
|
||||||
@@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
logits = mx.array(x)
|
logits = mx.array(x)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
return np.asarray(self.sampler(logprobs), dtype=int)
|
result = self.sampler(logprobs)
|
||||||
|
await self._eval_mlx(result)
|
||||||
|
return np.asarray(result, dtype=int)
|
||||||
|
|
||||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
tokens = self.tokenizer.encode(prompt)
|
loop = asyncio.get_running_loop()
|
||||||
return np.asarray(tokens)
|
return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
|
||||||
|
|
||||||
async def decode(self, shard: Shard, tokens) -> str:
|
async def decode(self, shard: Shard, tokens) -> str:
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
return self.tokenizer.decode(tokens)
|
loop = asyncio.get_running_loop()
|
||||||
|
return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
|
||||||
|
|
||||||
async def save_checkpoint(self, shard: Shard, path: str):
|
async def save_checkpoint(self, shard: Shard, path: str):
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
@@ -57,12 +67,17 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
self.model.load_weights(path)
|
self.model.load_weights(path)
|
||||||
|
|
||||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
state = await self.poll_state(request_id)
|
loop = asyncio.get_running_loop()
|
||||||
|
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
||||||
x = mx.array(input_data)
|
x = mx.array(input_data)
|
||||||
output_data = np.array(self.model(x, **state), copy=False)
|
if self.model.model_type != 'StableDiffusionPipeline':
|
||||||
return output_data
|
output_data = self.model(x, **state, **inference_state)
|
||||||
|
else:
|
||||||
|
output_data, inference_state = self.model(x, **state, **inference_state)
|
||||||
|
output_data = np.array(output_data, copy=False)
|
||||||
|
return output_data, inference_state
|
||||||
|
|
||||||
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
@@ -87,26 +102,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
||||||
loop = asyncio.get_running_loop()
|
await self.ensure_train(shard, loss, opt, lr)
|
||||||
nothin = await self.ensure_train(shard, loss, opt, lr)
|
|
||||||
def train_step(inp, tar, lng):
|
def train_step(inp, tar, lng):
|
||||||
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
||||||
gradlayers = grad['model']['layers']
|
gradlayers = grad['model']['layers']
|
||||||
self.session['opt'].update(self.model, grad)
|
self.session['opt'].update(self.model, grad)
|
||||||
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
|
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
|
||||||
return lval, gradlayers
|
|
||||||
|
|
||||||
x = mx.array(inputs)
|
x = mx.array(inputs)
|
||||||
y = mx.array(targets)
|
y = mx.array(targets)
|
||||||
l = mx.array(lengths)
|
l = mx.array(lengths)
|
||||||
|
|
||||||
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
|
score, gradients, eval_args = train_step(x, y, l)
|
||||||
#print(f"{score=}")
|
await self._eval_mlx(*eval_args)
|
||||||
|
|
||||||
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
||||||
#print(layers[0])
|
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
||||||
|
await self._eval_mlx(first_layer)
|
||||||
return score, np.array(layers[0]['input_layernorm'], copy=False)
|
return score, first_layer
|
||||||
|
|
||||||
async def ensure_shard(self, shard: Shard):
|
async def ensure_shard(self, shard: Shard):
|
||||||
if self.shard == shard:
|
if self.shard == shard:
|
||||||
@@ -121,3 +135,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
self.caches = OrderedDict()
|
self.caches = OrderedDict()
|
||||||
self.session = {}
|
self.session = {}
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
self._mlx_thread.shutdown(wait=True)
|
||||||
|
|
||||||
|
|||||||
@@ -62,8 +62,16 @@ def _get_classes(config: dict):
|
|||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
try:
|
try:
|
||||||
with open(model_path/"config.json", "r") as f:
|
config_path = model_path / "config.json"
|
||||||
config = json.load(f)
|
if config_path.exists():
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return config
|
||||||
|
|
||||||
|
model_index_path = model_path / "model_index.json"
|
||||||
|
if model_index_path.exists():
|
||||||
|
config = load_model_index(model_path, model_index_path)
|
||||||
|
return config
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logging.error(f"Config file not found in {model_path}")
|
logging.error(f"Config file not found in {model_path}")
|
||||||
raise
|
raise
|
||||||
@@ -110,6 +118,24 @@ def load_model_shard(
|
|||||||
# Try weight for back-compat
|
# Try weight for back-compat
|
||||||
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
|
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
|
||||||
|
|
||||||
|
model_class, model_args_class = _get_classes(config=config)
|
||||||
|
|
||||||
|
class ShardedModel(model_class):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
|
||||||
|
|
||||||
|
def __call__(self, x, *args, **kwargs):
|
||||||
|
y = super().__call__(x, *args, **kwargs)
|
||||||
|
return y
|
||||||
|
|
||||||
|
model_args = model_args_class.from_dict(config)
|
||||||
|
model = ShardedModel(model_args)
|
||||||
|
|
||||||
|
if config.get("model_index", False):
|
||||||
|
model.load()
|
||||||
|
return model
|
||||||
|
|
||||||
if not weight_files:
|
if not weight_files:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
@@ -129,19 +155,7 @@ def load_model_shard(
|
|||||||
|
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
|
|
||||||
model_class, model_args_class = _get_classes(config=config)
|
|
||||||
|
|
||||||
class ShardedModel(model_class):
|
|
||||||
def __init__(self, args):
|
|
||||||
super().__init__(args)
|
|
||||||
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
|
|
||||||
|
|
||||||
def __call__(self, x, *args, **kwargs):
|
|
||||||
y = super().__call__(x, *args, **kwargs)
|
|
||||||
return y
|
|
||||||
|
|
||||||
model_args = model_args_class.from_dict(config)
|
|
||||||
model = ShardedModel(model_args)
|
|
||||||
|
|
||||||
if hasattr(model, "sanitize"):
|
if hasattr(model, "sanitize"):
|
||||||
weights = model.sanitize(weights)
|
weights = model.sanitize(weights)
|
||||||
@@ -186,6 +200,9 @@ async def load_shard(
|
|||||||
processor.eos_token_id = processor.tokenizer.eos_token_id
|
processor.eos_token_id = processor.tokenizer.eos_token_id
|
||||||
processor.encode = processor.tokenizer.encode
|
processor.encode = processor.tokenizer.encode
|
||||||
return model, processor
|
return model, processor
|
||||||
|
elif hasattr(model, "tokenizer"):
|
||||||
|
tokenizer = model.tokenizer
|
||||||
|
return model, tokenizer
|
||||||
else:
|
else:
|
||||||
tokenizer = await resolve_tokenizer(model_path)
|
tokenizer = await resolve_tokenizer(model_path)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
|
|||||||
return img
|
return img
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
|
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
|
||||||
|
|
||||||
|
# loading a combined config for all models in the index
|
||||||
|
def load_model_index(model_path: Path, model_index_path: Path):
|
||||||
|
models_config = {}
|
||||||
|
with open(model_index_path, "r") as f:
|
||||||
|
model_index = json.load(f)
|
||||||
|
models_config["model_index"] = True
|
||||||
|
models_config["model_type"] = model_index["_class_name"]
|
||||||
|
models_config["models"] = {}
|
||||||
|
for model in model_index.keys():
|
||||||
|
model_config_path = glob.glob(str(model_path / model / "*config.json"))
|
||||||
|
if len(model_config_path)>0:
|
||||||
|
with open(model_config_path[0], "r") as f:
|
||||||
|
model_config = { }
|
||||||
|
model_config["model_type"] = model
|
||||||
|
model_config["config"] = json.load(f)
|
||||||
|
model_config["path"] = model_path / model
|
||||||
|
if model_config["path"]/"*model.safetensors":
|
||||||
|
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
|
||||||
|
model_config["path"] = str(model_path / model)
|
||||||
|
m = {}
|
||||||
|
m[model] = model_config
|
||||||
|
models_config.update(m)
|
||||||
|
return models_config
|
||||||
|
|||||||
81
exo/inference/mlx/test_non_blocking.py
Normal file
81
exo/inference/mlx/test_non_blocking.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||||
|
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||||
|
from exo.inference.shard import Shard
|
||||||
|
from exo.models import build_base_shard
|
||||||
|
from collections import deque
|
||||||
|
from statistics import mean, median
|
||||||
|
|
||||||
|
async def test_non_blocking():
|
||||||
|
# Setup
|
||||||
|
shard_downloader = HFShardDownloader()
|
||||||
|
engine = MLXDynamicShardInferenceEngine(shard_downloader)
|
||||||
|
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
|
||||||
|
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
|
||||||
|
await engine.ensure_shard(shard)
|
||||||
|
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
measurements = deque(maxlen=1000000)
|
||||||
|
running = True
|
||||||
|
|
||||||
|
async def mlx_worker():
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
count = 0
|
||||||
|
while running and (time.time() - start_time) < 5: # Hard time limit
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
await engine.infer_prompt("req1", shard, "test prompt")
|
||||||
|
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
|
||||||
|
count += 1
|
||||||
|
print(f"MLX operation {count} took: {duration:.3f}ms")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
print(f"\nTotal MLX operations completed: {count}")
|
||||||
|
print(f"Average rate: {count/5:.1f} ops/second")
|
||||||
|
|
||||||
|
async def latency_producer():
|
||||||
|
try:
|
||||||
|
start_time = time.perf_counter_ns()
|
||||||
|
count = 0
|
||||||
|
while running:
|
||||||
|
await queue.put(time.perf_counter_ns())
|
||||||
|
count += 1
|
||||||
|
await asyncio.sleep(0) # Yield to event loop without delay
|
||||||
|
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
|
||||||
|
print(f"\nProducer iterations: {count}")
|
||||||
|
print(f"Producer rate: {count/duration:.1f} iterations/second")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def latency_consumer():
|
||||||
|
try:
|
||||||
|
while running:
|
||||||
|
timestamp = await queue.get()
|
||||||
|
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
|
||||||
|
measurements.append(latency)
|
||||||
|
queue.task_done()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(mlx_worker()),
|
||||||
|
asyncio.create_task(latency_producer()),
|
||||||
|
asyncio.create_task(latency_consumer())
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("\nTest timed out")
|
||||||
|
finally:
|
||||||
|
running = False
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
print(f"\nFinal measurement count: {len(measurements)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_non_blocking())
|
||||||
@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
|
|||||||
from .losses import length_masked_ce_loss
|
from .losses import length_masked_ce_loss
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
# default settings
|
# default settings
|
||||||
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
|
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
|
||||||
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
||||||
safe_save(state_dict, path)
|
safe_save(state_dict, path)
|
||||||
|
|
||||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
||||||
await self.ensure_shard(shard)
|
await self.ensure_shard(shard)
|
||||||
def wrap_infer():
|
def wrap_infer():
|
||||||
x = Tensor(input_data)
|
x = Tensor(input_data)
|
||||||
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|||||||
self.states[request_id].start += x.shape[1]
|
self.states[request_id].start += x.shape[1]
|
||||||
return out.realize()
|
return out.realize()
|
||||||
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
||||||
return output_data.numpy()
|
return output_data.numpy(), inference_state
|
||||||
|
|
||||||
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
||||||
def step(x, y, l):
|
def step(x, y, l):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class DummyTokenizer:
|
|||||||
self.eos_token_id = 69
|
self.eos_token_id = 69
|
||||||
self.vocab_size = 1000
|
self.vocab_size = 1000
|
||||||
|
|
||||||
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
|
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
|
||||||
return "dummy_tokenized_prompt"
|
return "dummy_tokenized_prompt"
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
|
|||||||
10
exo/main.py
10
exo/main.py
@@ -103,6 +103,7 @@ parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailsca
|
|||||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||||
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
||||||
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
|
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
|
||||||
|
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(f"Selected inference engine: {args.inference_engine}")
|
print(f"Selected inference engine: {args.inference_engine}")
|
||||||
|
|
||||||
@@ -182,11 +183,12 @@ api = ChatGPTAPI(
|
|||||||
inference_engine.__class__.__name__,
|
inference_engine.__class__.__name__,
|
||||||
response_timeout=args.chatgpt_api_response_timeout,
|
response_timeout=args.chatgpt_api_response_timeout,
|
||||||
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
|
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
|
||||||
default_model=args.default_model
|
default_model=args.default_model,
|
||||||
|
system_prompt=args.system_prompt
|
||||||
|
)
|
||||||
|
node.on_token.register("update_topology_viz").on_next(
|
||||||
|
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
|
||||||
)
|
)
|
||||||
# node.on_token.register("update_topology_viz").on_next(
|
|
||||||
# lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
|
|
||||||
# )
|
|
||||||
|
|
||||||
def preemptively_start_download(request_id: str, opaque_status: str):
|
def preemptively_start_download(request_id: str, opaque_status: str):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -92,14 +92,17 @@ model_cards = {
|
|||||||
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
|
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
|
||||||
### qwen
|
### qwen
|
||||||
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
|
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
|
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
|
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
|
|
||||||
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
|
|
||||||
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
|
|
||||||
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
|
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
|
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
|
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
|
||||||
|
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
|
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
|
||||||
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
|
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
|
||||||
### nemotron
|
### nemotron
|
||||||
@@ -108,6 +111,11 @@ model_cards = {
|
|||||||
# gemma
|
# gemma
|
||||||
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
|
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
|
||||||
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
|
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
|
||||||
|
# stable diffusion
|
||||||
|
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
|
||||||
|
# phi
|
||||||
|
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
|
||||||
|
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
|
||||||
# dummy
|
# dummy
|
||||||
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
|
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
|
||||||
}
|
}
|
||||||
@@ -133,18 +141,24 @@ pretty_name = {
|
|||||||
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
||||||
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
||||||
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
||||||
|
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
|
||||||
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
||||||
|
"qwen-2.5-3b": "Qwen 2.5 3B",
|
||||||
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
|
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
|
||||||
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
|
|
||||||
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
|
|
||||||
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
|
|
||||||
"qwen-2.5-7b": "Qwen 2.5 7B",
|
"qwen-2.5-7b": "Qwen 2.5 7B",
|
||||||
|
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
|
||||||
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
|
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
|
||||||
"qwen-2.5-14b": "Qwen 2.5 14B",
|
"qwen-2.5-14b": "Qwen 2.5 14B",
|
||||||
|
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
|
||||||
|
"qwen-2.5-32b": "Qwen 2.5 32B",
|
||||||
|
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
|
||||||
"qwen-2.5-72b": "Qwen 2.5 72B",
|
"qwen-2.5-72b": "Qwen 2.5 72B",
|
||||||
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
|
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
|
||||||
|
"phi-3.5-mini": "Phi-3.5 Mini",
|
||||||
|
"phi-4": "Phi-4",
|
||||||
"llama-3-8b": "Llama 3 8B",
|
"llama-3-8b": "Llama 3 8B",
|
||||||
"llama-3-70b": "Llama 3 70B",
|
"llama-3-70b": "Llama 3 70B",
|
||||||
|
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ from exo.inference.shard import Shard
|
|||||||
from exo.topology.topology import Topology
|
from exo.topology.topology import Topology
|
||||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||||
from exo.helpers import DEBUG
|
from exo.helpers import DEBUG
|
||||||
|
import json
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
class GRPCPeerHandle(PeerHandle):
|
class GRPCPeerHandle(PeerHandle):
|
||||||
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
|
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
|
||||||
@@ -90,7 +91,7 @@ class GRPCPeerHandle(PeerHandle):
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
|
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||||
request = node_service_pb2.PromptRequest(
|
request = node_service_pb2.PromptRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
shard=node_service_pb2.Shard(
|
shard=node_service_pb2.Shard(
|
||||||
@@ -100,10 +101,11 @@ class GRPCPeerHandle(PeerHandle):
|
|||||||
n_layers=shard.n_layers,
|
n_layers=shard.n_layers,
|
||||||
),
|
),
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
|
inference_state=self.serialize_inference_state(inference_state)
|
||||||
)
|
)
|
||||||
await self.stub.SendPrompt(request)
|
await self.stub.SendPrompt(request)
|
||||||
|
|
||||||
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
|
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||||
request = node_service_pb2.TensorRequest(
|
request = node_service_pb2.TensorRequest(
|
||||||
shard=node_service_pb2.Shard(
|
shard=node_service_pb2.Shard(
|
||||||
model_id=shard.model_id,
|
model_id=shard.model_id,
|
||||||
@@ -113,8 +115,14 @@ class GRPCPeerHandle(PeerHandle):
|
|||||||
),
|
),
|
||||||
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
|
inference_state=self.serialize_inference_state(inference_state)
|
||||||
)
|
)
|
||||||
await self.stub.SendTensor(request)
|
response =await self.stub.SendTensor(request)
|
||||||
|
|
||||||
|
if not response.tensor_data or not response.shape or not response.dtype:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||||
|
|
||||||
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||||
request = node_service_pb2.ExampleRequest(
|
request = node_service_pb2.ExampleRequest(
|
||||||
@@ -173,10 +181,44 @@ class GRPCPeerHandle(PeerHandle):
|
|||||||
topology.add_edge(node_id, conn.to_id, conn.description)
|
topology.add_edge(node_id, conn.to_id, conn.description)
|
||||||
return topology
|
return topology
|
||||||
|
|
||||||
async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
|
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||||
request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
|
tensor = None
|
||||||
await self.stub.SendNewToken(request)
|
if isinstance(result, np.ndarray):
|
||||||
|
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
|
||||||
|
result = []
|
||||||
|
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
|
||||||
|
await self.stub.SendResult(request)
|
||||||
|
|
||||||
async def send_opaque_status(self, request_id: str, status: str) -> None:
|
async def send_opaque_status(self, request_id: str, status: str) -> None:
|
||||||
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
|
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
|
||||||
await self.stub.SendOpaqueStatus(request)
|
await self.stub.SendOpaqueStatus(request)
|
||||||
|
|
||||||
|
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
|
||||||
|
proto_inference_state = node_service_pb2.InferenceState()
|
||||||
|
other_data = {}
|
||||||
|
for k, v in inference_state.items():
|
||||||
|
if isinstance(v, mx.array):
|
||||||
|
np_array = np.array(v)
|
||||||
|
tensor_data = node_service_pb2.Tensor(
|
||||||
|
tensor_data=np_array.tobytes(),
|
||||||
|
shape=list(np_array.shape),
|
||||||
|
dtype=str(np_array.dtype)
|
||||||
|
)
|
||||||
|
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
|
||||||
|
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
|
||||||
|
tensor_list = node_service_pb2.TensorList()
|
||||||
|
for tensor in v:
|
||||||
|
np_array = np.array(tensor)
|
||||||
|
tensor_data = node_service_pb2.Tensor(
|
||||||
|
tensor_data=np_array.tobytes(),
|
||||||
|
shape=list(np_array.shape),
|
||||||
|
dtype=str(np_array.dtype)
|
||||||
|
)
|
||||||
|
tensor_list.tensors.append(tensor_data)
|
||||||
|
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
|
||||||
|
else:
|
||||||
|
# For non-tensor data, we'll still use JSON
|
||||||
|
other_data[k] = v
|
||||||
|
if other_data:
|
||||||
|
proto_inference_state.other_data_json = json.dumps(other_data)
|
||||||
|
return proto_inference_state
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from . import node_service_pb2_grpc
|
|||||||
from exo import DEBUG
|
from exo import DEBUG
|
||||||
from exo.inference.shard import Shard
|
from exo.inference.shard import Shard
|
||||||
from exo.orchestration import Node
|
from exo.orchestration import Node
|
||||||
|
import json
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||||
@@ -58,9 +60,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
|||||||
)
|
)
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
await self.node.process_prompt(shard, prompt, request_id)
|
inference_state = self.deserialize_inference_state(request.inference_state)
|
||||||
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
|
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
|
||||||
return node_service_pb2.Empty()
|
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
|
||||||
|
tensor_data = result.tobytes() if result is not None else None
|
||||||
|
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||||
|
|
||||||
async def SendTensor(self, request, context):
|
async def SendTensor(self, request, context):
|
||||||
shard = Shard(
|
shard = Shard(
|
||||||
@@ -71,9 +75,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
|||||||
)
|
)
|
||||||
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
|
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
await self.node.process_tensor(shard, tensor, request_id)
|
|
||||||
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
|
inference_state = self.deserialize_inference_state(request.inference_state)
|
||||||
return node_service_pb2.Empty()
|
|
||||||
|
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
|
||||||
|
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
|
||||||
|
tensor_data = result.tobytes() if result is not None else None
|
||||||
|
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||||
|
|
||||||
async def SendExample(self, request, context):
|
async def SendExample(self, request, context):
|
||||||
shard = Shard(
|
shard = Shard(
|
||||||
@@ -127,8 +135,12 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
|||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
token = request.token
|
token = request.token
|
||||||
is_finished = request.is_finished
|
is_finished = request.is_finished
|
||||||
if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}")
|
img = request.tensor
|
||||||
self.node.on_token.trigger_all(request_id, token, is_finished)
|
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
|
||||||
|
result = list(result)
|
||||||
|
if len(img.tensor_data) > 0:
|
||||||
|
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
|
||||||
|
self.node.on_token.trigger_all(request_id, result, is_finished)
|
||||||
return node_service_pb2.Empty()
|
return node_service_pb2.Empty()
|
||||||
|
|
||||||
async def SendOpaqueStatus(self, request, context):
|
async def SendOpaqueStatus(self, request, context):
|
||||||
@@ -140,3 +152,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
|||||||
|
|
||||||
async def HealthCheck(self, request, context):
|
async def HealthCheck(self, request, context):
|
||||||
return node_service_pb2.HealthCheckResponse(is_healthy=True)
|
return node_service_pb2.HealthCheckResponse(is_healthy=True)
|
||||||
|
|
||||||
|
def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
|
||||||
|
inference_state = {}
|
||||||
|
|
||||||
|
for k, tensor_data in inference_state_proto.tensor_data.items():
|
||||||
|
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
|
||||||
|
inference_state[k] = mx.array(np_array)
|
||||||
|
|
||||||
|
for k, tensor_list in inference_state_proto.tensor_list_data.items():
|
||||||
|
inference_state[k] = [
|
||||||
|
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
|
||||||
|
for tensor in tensor_list.tensors
|
||||||
|
]
|
||||||
|
|
||||||
|
if inference_state_proto.other_data_json:
|
||||||
|
other_data = json.loads(inference_state_proto.other_data_json)
|
||||||
|
inference_state.update(other_data)
|
||||||
|
|
||||||
|
return inference_state
|
||||||
|
|||||||
@@ -23,12 +23,14 @@ message PromptRequest {
|
|||||||
Shard shard = 1;
|
Shard shard = 1;
|
||||||
string prompt = 2;
|
string prompt = 2;
|
||||||
optional string request_id = 3;
|
optional string request_id = 3;
|
||||||
|
optional InferenceState inference_state = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TensorRequest {
|
message TensorRequest {
|
||||||
Shard shard = 1;
|
Shard shard = 1;
|
||||||
Tensor tensor = 2;
|
Tensor tensor = 2;
|
||||||
optional string request_id = 3;
|
optional string request_id = 3;
|
||||||
|
optional InferenceState inference_state = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ExampleRequest {
|
message ExampleRequest {
|
||||||
@@ -51,6 +53,16 @@ message Tensor {
|
|||||||
string dtype = 3;
|
string dtype = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message TensorList {
|
||||||
|
repeated Tensor tensors = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InferenceState {
|
||||||
|
map<string, Tensor> tensor_data = 1;
|
||||||
|
map<string, TensorList> tensor_list_data = 2;
|
||||||
|
string other_data_json = 3;
|
||||||
|
}
|
||||||
|
|
||||||
message CollectTopologyRequest {
|
message CollectTopologyRequest {
|
||||||
repeated string visited = 1;
|
repeated string visited = 1;
|
||||||
int32 max_depth = 2;
|
int32 max_depth = 2;
|
||||||
@@ -85,8 +97,9 @@ message DeviceCapabilities {
|
|||||||
|
|
||||||
message SendNewTokenRequest {
|
message SendNewTokenRequest {
|
||||||
string request_id = 1;
|
string request_id = 1;
|
||||||
int32 token = 2;
|
repeated int32 result = 2;
|
||||||
bool is_finished = 3;
|
optional Tensor tensor = 3;
|
||||||
|
bool is_finished = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SendOpaqueStatusRequest {
|
message SendOpaqueStatusRequest {
|
||||||
|
|||||||
@@ -24,55 +24,67 @@ _sym_db = _symbol_database.Default()
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"M\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x84\x01\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
|
||||||
|
|
||||||
_globals = globals()
|
_globals = globals()
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
DESCRIPTOR._loaded_options = None
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
|
||||||
|
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
|
||||||
|
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
|
||||||
|
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
|
||||||
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
|
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
|
||||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
|
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
|
||||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
|
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
|
||||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
|
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
|
||||||
_globals['_SHARD']._serialized_start=36
|
_globals['_SHARD']._serialized_start=36
|
||||||
_globals['_SHARD']._serialized_end=119
|
_globals['_SHARD']._serialized_end=119
|
||||||
_globals['_PROMPTREQUEST']._serialized_start=121
|
_globals['_PROMPTREQUEST']._serialized_start=122
|
||||||
_globals['_PROMPTREQUEST']._serialized_end=228
|
_globals['_PROMPTREQUEST']._serialized_end=309
|
||||||
_globals['_TENSORREQUEST']._serialized_start=231
|
_globals['_TENSORREQUEST']._serialized_start=312
|
||||||
_globals['_TENSORREQUEST']._serialized_end=360
|
_globals['_TENSORREQUEST']._serialized_end=521
|
||||||
_globals['_EXAMPLEREQUEST']._serialized_start=363
|
_globals['_EXAMPLEREQUEST']._serialized_start=524
|
||||||
_globals['_EXAMPLEREQUEST']._serialized_end=585
|
_globals['_EXAMPLEREQUEST']._serialized_end=746
|
||||||
_globals['_LOSS']._serialized_start=587
|
_globals['_LOSS']._serialized_start=748
|
||||||
_globals['_LOSS']._serialized_end=659
|
_globals['_LOSS']._serialized_end=820
|
||||||
_globals['_TENSOR']._serialized_start=661
|
_globals['_TENSOR']._serialized_start=822
|
||||||
_globals['_TENSOR']._serialized_end=720
|
_globals['_TENSOR']._serialized_end=881
|
||||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=722
|
_globals['_TENSORLIST']._serialized_start=883
|
||||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=782
|
_globals['_TENSORLIST']._serialized_end=934
|
||||||
_globals['_TOPOLOGY']._serialized_start=785
|
_globals['_INFERENCESTATE']._serialized_start=937
|
||||||
_globals['_TOPOLOGY']._serialized_end=1065
|
_globals['_INFERENCESTATE']._serialized_end=1275
|
||||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=906
|
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
|
||||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=984
|
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
|
||||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=986
|
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
|
||||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1065
|
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
|
||||||
_globals['_PEERCONNECTION']._serialized_start=1067
|
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
|
||||||
_globals['_PEERCONNECTION']._serialized_end=1140
|
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
|
||||||
_globals['_PEERCONNECTIONS']._serialized_start=1142
|
_globals['_TOPOLOGY']._serialized_start=1340
|
||||||
_globals['_PEERCONNECTIONS']._serialized_end=1210
|
_globals['_TOPOLOGY']._serialized_end=1620
|
||||||
_globals['_DEVICEFLOPS']._serialized_start=1212
|
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
|
||||||
_globals['_DEVICEFLOPS']._serialized_end=1267
|
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
|
||||||
_globals['_DEVICECAPABILITIES']._serialized_start=1269
|
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
|
||||||
_globals['_DEVICECAPABILITIES']._serialized_end=1376
|
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
|
||||||
_globals['_SENDNEWTOKENREQUEST']._serialized_start=1378
|
_globals['_PEERCONNECTION']._serialized_start=1622
|
||||||
_globals['_SENDNEWTOKENREQUEST']._serialized_end=1455
|
_globals['_PEERCONNECTION']._serialized_end=1695
|
||||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1457
|
_globals['_PEERCONNECTIONS']._serialized_start=1697
|
||||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1518
|
_globals['_PEERCONNECTIONS']._serialized_end=1765
|
||||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=1520
|
_globals['_DEVICEFLOPS']._serialized_start=1767
|
||||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=1540
|
_globals['_DEVICEFLOPS']._serialized_end=1822
|
||||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=1542
|
_globals['_DEVICECAPABILITIES']._serialized_start=1824
|
||||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1583
|
_globals['_DEVICECAPABILITIES']._serialized_end=1931
|
||||||
_globals['_EMPTY']._serialized_start=1585
|
_globals['_SENDNEWTOKENREQUEST']._serialized_start=1934
|
||||||
_globals['_EMPTY']._serialized_end=1592
|
_globals['_SENDNEWTOKENREQUEST']._serialized_end=2066
|
||||||
_globals['_NODESERVICE']._serialized_start=1595
|
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2068
|
||||||
_globals['_NODESERVICE']._serialized_end=2132
|
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2129
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_start=2131
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_end=2151
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2153
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2194
|
||||||
|
_globals['_EMPTY']._serialized_start=2196
|
||||||
|
_globals['_EMPTY']._serialized_end=2203
|
||||||
|
_globals['_NODESERVICE']._serialized_start=2206
|
||||||
|
_globals['_NODESERVICE']._serialized_end=2743
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from exo.networking.discovery import Discovery
|
from typing import Dict, List, Callable, Optional
|
||||||
from typing import Dict, List, Callable
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from exo.networking.discovery import Discovery
|
||||||
from exo.topology.device_capabilities import DeviceCapabilities
|
from exo.topology.device_capabilities import DeviceCapabilities
|
||||||
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
||||||
from exo.helpers import DEBUG_DISCOVERY
|
from exo.helpers import DEBUG_DISCOVERY
|
||||||
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
|
|||||||
self,
|
self,
|
||||||
network_config_path: str,
|
network_config_path: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
|
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||||
):
|
):
|
||||||
self.topology = NetworkTopology.from_path(network_config_path)
|
self.network_config_path = network_config_path
|
||||||
|
self.node_id = node_id
|
||||||
self.create_peer_handle = create_peer_handle
|
self.create_peer_handle = create_peer_handle
|
||||||
|
|
||||||
if node_id not in self.topology.peers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.listen_task = None
|
self.listen_task = None
|
||||||
|
|
||||||
self.known_peers: Dict[str, PeerHandle] = {}
|
self.known_peers: Dict[str, PeerHandle] = {}
|
||||||
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
|
|
||||||
self.peers_in_network.pop(node_id)
|
self._cached_peers: Dict[str, PeerConfig] = {}
|
||||||
|
self._last_modified_time: Optional[float] = None
|
||||||
|
self._file_executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
|
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self.listen_task:
|
if self.listen_task: self.listen_task.cancel()
|
||||||
self.listen_task.cancel()
|
self._file_executor.shutdown(wait=True)
|
||||||
|
|
||||||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
||||||
if wait_for_peers > 0:
|
if wait_for_peers > 0:
|
||||||
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
|
|||||||
async def task_find_peers_from_config(self):
|
async def task_find_peers_from_config(self):
|
||||||
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
|
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
|
||||||
while True:
|
while True:
|
||||||
for peer_id, peer_config in self.peers_in_network.items():
|
peers_from_config = await self._get_peers()
|
||||||
|
new_known_peers = {}
|
||||||
|
for peer_id, peer_config in peers_from_config.items():
|
||||||
try:
|
try:
|
||||||
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
|
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
|
||||||
peer = self.known_peers.get(peer_id)
|
peer = self.known_peers.get(peer_id)
|
||||||
@@ -57,15 +58,43 @@ class ManualDiscovery(Discovery):
|
|||||||
is_healthy = await peer.health_check()
|
is_healthy = await peer.health_check()
|
||||||
if is_healthy:
|
if is_healthy:
|
||||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
||||||
self.known_peers[peer_id] = peer
|
new_known_peers[peer_id] = peer
|
||||||
else:
|
elif DEBUG_DISCOVERY >= 2:
|
||||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
|
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
|
||||||
try:
|
|
||||||
del self.known_peers[peer_id]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
|
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
|
||||||
await asyncio.sleep(5.0)
|
await asyncio.sleep(5.0)
|
||||||
|
|
||||||
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
|
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
|
||||||
|
|
||||||
|
async def _get_peers(self):
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
|
||||||
|
|
||||||
|
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
|
||||||
|
return self._cached_peers
|
||||||
|
|
||||||
|
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
|
||||||
|
|
||||||
|
if self.node_id not in topology.peers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Node ID {self.node_id} not found in network config file "
|
||||||
|
f"{self.network_config_path}. Please run with `node_id` set to "
|
||||||
|
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
peers_in_network = topology.peers
|
||||||
|
peers_in_network.pop(self.node_id)
|
||||||
|
|
||||||
|
self._cached_peers = peers_in_network
|
||||||
|
self._last_modified_time = current_mtime
|
||||||
|
|
||||||
|
return peers_in_network
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if DEBUG_DISCOVERY >= 2:
|
||||||
|
print(f"Error when loading network config file from {self.network_config_path}. "
|
||||||
|
f"Please update the config file in order to successfully discover peers. "
|
||||||
|
f"Exception: {e}")
|
||||||
|
return self._cached_peers
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
|||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.peer1 = mock.AsyncMock()
|
self.peer1 = mock.AsyncMock()
|
||||||
self.peer1.connect = mock.AsyncMock()
|
self.peer1.connect = mock.AsyncMock()
|
||||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
|
self.discovery1 = ManualDiscovery(
|
||||||
_ = self.discovery1.start()
|
root_path,
|
||||||
|
"node1",
|
||||||
|
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||||
|
)
|
||||||
|
await self.discovery1.start()
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
await self.discovery1.stop()
|
await self.discovery1.stop()
|
||||||
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.peer2 = mock.AsyncMock()
|
self.peer2 = mock.AsyncMock()
|
||||||
self.peer1.connect = mock.AsyncMock()
|
self.peer1.connect = mock.AsyncMock()
|
||||||
self.peer2.connect = mock.AsyncMock()
|
self.peer2.connect = mock.AsyncMock()
|
||||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
|
self.discovery1 = ManualDiscovery(
|
||||||
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
|
root_path,
|
||||||
|
"node1",
|
||||||
|
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||||
|
)
|
||||||
|
self.discovery2 = ManualDiscovery(
|
||||||
|
root_path,
|
||||||
|
"node2",
|
||||||
|
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
|
||||||
|
)
|
||||||
await self.discovery1.start()
|
await self.discovery1.start()
|
||||||
await self.discovery2.start()
|
await self.discovery2.start()
|
||||||
|
|
||||||
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
|
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
|
||||||
await self.server1.start()
|
await self.server1.start()
|
||||||
await self.server2.start()
|
await self.server2.start()
|
||||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
|
self.discovery1 = ManualDiscovery(
|
||||||
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
|
root_path,
|
||||||
|
"node1",
|
||||||
|
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||||
|
)
|
||||||
|
self.discovery2 = ManualDiscovery(
|
||||||
|
root_path,
|
||||||
|
"node2",
|
||||||
|
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||||
|
)
|
||||||
await self.discovery1.start()
|
await self.discovery1.start()
|
||||||
await self.discovery2.start()
|
await self.discovery2.start()
|
||||||
|
|
||||||
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertFalse(await peers1[0].is_connected())
|
self.assertFalse(await peers1[0].is_connected())
|
||||||
self.assertFalse(await peers2[0].is_connected())
|
self.assertFalse(await peers2[0].is_connected())
|
||||||
|
|
||||||
|
async def test_dynamic_config_update(self):
|
||||||
|
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||||
|
self.assertEqual(len(initial_peers), 1)
|
||||||
|
|
||||||
|
# Save original config for cleanup
|
||||||
|
with open(root_path, "r") as f:
|
||||||
|
original_config = json.load(f)
|
||||||
|
|
||||||
|
try:
|
||||||
|
updated_config = {
|
||||||
|
"peers": {
|
||||||
|
**original_config["peers"],
|
||||||
|
"node3": {
|
||||||
|
"address": "localhost",
|
||||||
|
"port": 50053,
|
||||||
|
"device_capabilities": {
|
||||||
|
"model": "Unknown Model",
|
||||||
|
"chip": "Unknown Chip",
|
||||||
|
"memory": 0,
|
||||||
|
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(root_path, "w") as f:
|
||||||
|
json.dump(updated_config, f, indent=2)
|
||||||
|
|
||||||
|
node3 = mock.AsyncMock(spec=Node)
|
||||||
|
server3 = GRPCServer(node3, "localhost", 50053)
|
||||||
|
await server3.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for the config to be reloaded
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
|
||||||
|
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
|
||||||
|
self.assertEqual(len(updated_peers), 2)
|
||||||
|
|
||||||
|
for peer in updated_peers:
|
||||||
|
await peer.connect()
|
||||||
|
self.assertTrue(await peer.is_connected())
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await server3.stop()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore the original config file
|
||||||
|
with open(root_path, "w") as f:
|
||||||
|
json.dump(original_config, f, indent=2)
|
||||||
|
|
||||||
|
# Wait for the config to be reloaded again
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
|
||||||
|
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||||
|
self.assertEqual(len(updated_peers), 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(unittest.main())
|
asyncio.run(unittest.main())
|
||||||
|
|||||||
@@ -118,44 +118,50 @@ class Node:
|
|||||||
shard,
|
shard,
|
||||||
result: np.ndarray,
|
result: np.ndarray,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
|
inference_state: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if request_id not in self.buffered_token_output:
|
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||||
self.buffered_token_output[request_id] = ([], False)
|
if request_id not in self.buffered_token_output:
|
||||||
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
self.buffered_token_output[request_id] = ([], False)
|
||||||
|
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
||||||
if shard.is_last_layer() and not is_finished:
|
if shard.is_last_layer() and not is_finished:
|
||||||
self.token_count += 1
|
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
|
||||||
if self.token_count == 1:
|
await self.inference_engine.ensure_shard(shard)
|
||||||
self.first_token_time = time.perf_counter_ns()
|
self.buffered_token_output[request_id][0].append(token.item())
|
||||||
if self.token_count % 20 == 0:
|
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
||||||
print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
|
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
||||||
|
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
|
||||||
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
|
forward = token.reshape(1, -1)
|
||||||
await self.inference_engine.ensure_shard(shard)
|
intermediate_result = self.buffered_token_output[request_id][0]
|
||||||
self.buffered_token_output[request_id][0].append(token.item())
|
else:
|
||||||
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
forward = result
|
||||||
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
|
||||||
forward = token.reshape(1, -1)
|
|
||||||
self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
|
|
||||||
asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
|
|
||||||
else:
|
else:
|
||||||
|
await self.inference_engine.ensure_shard(shard)
|
||||||
|
is_finished = inference_state.get("is_finished", False)
|
||||||
|
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
|
||||||
forward = result
|
forward = result
|
||||||
|
if shard.is_last_layer():
|
||||||
|
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
|
||||||
|
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
|
||||||
|
|
||||||
if is_finished:
|
if is_finished:
|
||||||
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
|
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||||
|
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
|
||||||
self.outstanding_requests.pop(request_id)
|
self.outstanding_requests.pop(request_id)
|
||||||
else:
|
else:
|
||||||
self.outstanding_requests[request_id] = "waiting"
|
self.outstanding_requests[request_id] = "waiting"
|
||||||
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
|
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
|
||||||
|
|
||||||
|
return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
|
||||||
|
|
||||||
return np.array(self.buffered_token_output[request_id][0])
|
|
||||||
|
|
||||||
async def process_prompt(
|
async def process_prompt(
|
||||||
self,
|
self,
|
||||||
base_shard: Shard,
|
base_shard: Shard,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
) -> None:
|
inference_state: Optional[dict] = {},
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
shard = self.get_current_shard(base_shard)
|
shard = self.get_current_shard(base_shard)
|
||||||
start_time = time.perf_counter_ns()
|
start_time = time.perf_counter_ns()
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -172,7 +178,8 @@ class Node:
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await self._process_prompt(base_shard, prompt, request_id)
|
start_time = time.perf_counter_ns()
|
||||||
|
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
|
||||||
end_time = time.perf_counter_ns()
|
end_time = time.perf_counter_ns()
|
||||||
elapsed_time_ns = end_time - start_time
|
elapsed_time_ns = end_time - start_time
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -192,7 +199,7 @@ class Node:
|
|||||||
)
|
)
|
||||||
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
|
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
|
||||||
|
|
||||||
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
|
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
shard = self.get_current_shard(base_shard)
|
shard = self.get_current_shard(base_shard)
|
||||||
@@ -201,12 +208,13 @@ class Node:
|
|||||||
if not shard.is_first_layer():
|
if not shard.is_first_layer():
|
||||||
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
|
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
|
||||||
self.outstanding_requests[request_id] = "waiting"
|
self.outstanding_requests[request_id] = "waiting"
|
||||||
await self.forward_prompt(shard, prompt, request_id, 0)
|
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
self.outstanding_requests[request_id] = "processing"
|
self.outstanding_requests[request_id] = "processing"
|
||||||
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
|
result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
|
||||||
await self.process_inference_result(shard, result, request_id)
|
ret = await self.process_inference_result(shard, result, request_id, inference_state)
|
||||||
|
return result
|
||||||
|
|
||||||
async def enqueue_example(
|
async def enqueue_example(
|
||||||
self,
|
self,
|
||||||
@@ -350,10 +358,11 @@ class Node:
|
|||||||
base_shard: Shard,
|
base_shard: Shard,
|
||||||
tensor: np.ndarray,
|
tensor: np.ndarray,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
) -> None:
|
inference_state: Optional[dict] = None,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
shard = self.get_current_shard(base_shard)
|
shard = self.get_current_shard(base_shard)
|
||||||
start_time = time.perf_counter_ns()
|
start_time = time.perf_counter_ns()
|
||||||
await self._process_tensor(shard, tensor, request_id)
|
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
|
||||||
end_time = time.perf_counter_ns()
|
end_time = time.perf_counter_ns()
|
||||||
elapsed_time_ns = end_time - start_time
|
elapsed_time_ns = end_time - start_time
|
||||||
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
|
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
|
||||||
@@ -363,15 +372,17 @@ class Node:
|
|||||||
base_shard: Shard,
|
base_shard: Shard,
|
||||||
tensor: np.ndarray,
|
tensor: np.ndarray,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
) -> None:
|
inference_state: Optional[dict] = None,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
shard = self.get_current_shard(base_shard)
|
shard = self.get_current_shard(base_shard)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.outstanding_requests[request_id] = "processing"
|
self.outstanding_requests[request_id] = "processing"
|
||||||
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
|
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
|
||||||
await self.process_inference_result(shard, result, request_id)
|
ret = await self.process_inference_result(shard, result, request_id, inference_state)
|
||||||
|
return ret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.outstanding_requests.pop(request_id)
|
self.outstanding_requests.pop(request_id)
|
||||||
print(f"Error processing tensor for shard {shard}: {e}")
|
print(f"Error processing tensor for shard {shard}: {e}")
|
||||||
@@ -404,19 +415,20 @@ class Node:
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
target_index: int,
|
target_index: int,
|
||||||
|
inference_state: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
||||||
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
||||||
next_shard = self.get_current_shard(base_shard, target_index)
|
next_shard = self.get_current_shard(base_shard, target_index)
|
||||||
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
|
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
|
||||||
if target_id == self.id:
|
if target_id == self.id:
|
||||||
await self.process_prompt(next_shard, prompt, request_id)
|
await self.process_prompt(next_shard, prompt, request_id, inference_state)
|
||||||
else:
|
else:
|
||||||
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
||||||
if not target_peer:
|
if not target_peer:
|
||||||
raise ValueError(f"Peer for {target_index} not found")
|
raise ValueError(f"Peer for {target_index} not found")
|
||||||
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
|
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
|
||||||
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
|
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
|
||||||
|
|
||||||
async def forward_tensor(
|
async def forward_tensor(
|
||||||
self,
|
self,
|
||||||
@@ -424,19 +436,20 @@ class Node:
|
|||||||
tensor: np.ndarray,
|
tensor: np.ndarray,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
target_index: int,
|
target_index: int,
|
||||||
|
inference_state: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
||||||
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
||||||
next_shard = self.get_current_shard(base_shard, target_index)
|
next_shard = self.get_current_shard(base_shard, target_index)
|
||||||
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
|
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
|
||||||
if target_id == self.id:
|
if target_id == self.id:
|
||||||
await self.process_tensor(next_shard, tensor, request_id)
|
await self.process_tensor(next_shard, tensor, request_id, inference_state)
|
||||||
else:
|
else:
|
||||||
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
||||||
if not target_peer:
|
if not target_peer:
|
||||||
raise ValueError(f"Peer for {target_index} not found")
|
raise ValueError(f"Peer for {target_index} not found")
|
||||||
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
|
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
|
||||||
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
|
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
|
||||||
|
|
||||||
def get_partition_index(self, offset: int = 0):
|
def get_partition_index(self, offset: int = 0):
|
||||||
if not self.partitioning_strategy:
|
if not self.partitioning_strategy:
|
||||||
@@ -604,3 +617,12 @@ class Node:
|
|||||||
@property
|
@property
|
||||||
def current_topology(self) -> Topology:
|
def current_topology(self) -> Topology:
|
||||||
return self.topology
|
return self.topology
|
||||||
|
|
||||||
|
def handle_stable_diffusion(self, inference_state, result):
|
||||||
|
if inference_state['is_step_finished']:
|
||||||
|
inference_state['step']+=1
|
||||||
|
progress = [inference_state['step'],inference_state['total_steps']]
|
||||||
|
intermediate_result = result
|
||||||
|
if progress[0] == progress[1]:
|
||||||
|
intermediate_result = result
|
||||||
|
return intermediate_result, inference_state
|
||||||
|
|||||||
166
exo/orchestration/tracing.py
Normal file
166
exo/orchestration/tracing.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
from opentelemetry import trace, context
|
||||||
|
from opentelemetry.trace import Status, StatusCode, SpanContext
|
||||||
|
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import time
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TraceContext:
|
||||||
|
request_id: str
|
||||||
|
sequence_number: int
|
||||||
|
current_span: Optional[trace.Span] = None
|
||||||
|
trace_parent: Optional[str] = None
|
||||||
|
token_group_span: Optional[trace.Span] = None
|
||||||
|
token_count: int = 0
|
||||||
|
token_group_size: int = 10 # Default group size
|
||||||
|
request_span: Optional[trace.Span] = None # Track the main request span
|
||||||
|
|
||||||
|
class Tracer:
|
||||||
|
def __init__(self):
|
||||||
|
self.tracer = trace.get_tracer("exo")
|
||||||
|
self.contexts: Dict[str, TraceContext] = {}
|
||||||
|
self._lock = Lock()
|
||||||
|
self.propagator = TraceContextTextMapPropagator()
|
||||||
|
|
||||||
|
def get_context(self, request_id: str) -> Optional[TraceContext]:
|
||||||
|
with self._lock:
|
||||||
|
return self.contexts.get(request_id)
|
||||||
|
|
||||||
|
def set_context(self, request_id: str, context: TraceContext):
|
||||||
|
with self._lock:
|
||||||
|
self.contexts[request_id] = context
|
||||||
|
|
||||||
|
def inject_context(self, span: trace.Span) -> str:
|
||||||
|
"""Inject current span context into carrier for propagation"""
|
||||||
|
carrier = {}
|
||||||
|
ctx = trace.set_span_in_context(span)
|
||||||
|
self.propagator.inject(carrier, context=ctx)
|
||||||
|
return carrier.get("traceparent", "")
|
||||||
|
|
||||||
|
def extract_context(self, trace_parent: str) -> Optional[context.Context]:
|
||||||
|
"""Extract span context from carrier"""
|
||||||
|
if not trace_parent:
|
||||||
|
return None
|
||||||
|
carrier = {"traceparent": trace_parent}
|
||||||
|
return self.propagator.extract(carrier)
|
||||||
|
|
||||||
|
def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
|
||||||
|
"""Create a new context with the given trace parent"""
|
||||||
|
parent_ctx = self.extract_context(trace_parent)
|
||||||
|
if parent_ctx:
|
||||||
|
# Create a new request span that links to the parent context
|
||||||
|
request_span = self.tracer.start_span(
|
||||||
|
"request",
|
||||||
|
context=parent_ctx,
|
||||||
|
attributes={
|
||||||
|
"request_id": request_id,
|
||||||
|
"sequence_number": sequence_number
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return TraceContext(
|
||||||
|
request_id=request_id,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
request_span=request_span,
|
||||||
|
current_span=request_span,
|
||||||
|
trace_parent=trace_parent
|
||||||
|
)
|
||||||
|
return TraceContext(request_id=request_id, sequence_number=sequence_number)
|
||||||
|
|
||||||
|
def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
|
||||||
|
"""Handle token generation and manage token group spans"""
|
||||||
|
context.token_count += 1
|
||||||
|
|
||||||
|
# Start a new token group span if needed
|
||||||
|
if not context.token_group_span and context.request_span:
|
||||||
|
group_number = (context.token_count - 1) // context.token_group_size + 1
|
||||||
|
|
||||||
|
# Create token group span as child of request span
|
||||||
|
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||||
|
context.token_group_span = self.tracer.start_span(
|
||||||
|
f"token_group_{group_number}",
|
||||||
|
context=parent_ctx,
|
||||||
|
attributes={
|
||||||
|
"request_id": context.request_id,
|
||||||
|
"group.number": group_number,
|
||||||
|
"group.start_token": context.token_count,
|
||||||
|
"group.max_tokens": context.token_group_size
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add token to current group span
|
||||||
|
if context.token_group_span:
|
||||||
|
relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
|
||||||
|
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
|
||||||
|
context.token_group_span.set_attribute("token.count", relative_pos)
|
||||||
|
|
||||||
|
# End current group span if we've reached the group size or if generation is finished
|
||||||
|
if context.token_count % context.token_group_size == 0 or is_finished:
|
||||||
|
context.token_group_span.set_attribute("token.final_count", relative_pos)
|
||||||
|
context.token_group_span.end()
|
||||||
|
context.token_group_span = None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
|
||||||
|
"""Start a new span with proper parent context"""
|
||||||
|
attributes = {
|
||||||
|
"request_id": context.request_id,
|
||||||
|
"sequence_number": context.sequence_number
|
||||||
|
}
|
||||||
|
if extra_attributes:
|
||||||
|
attributes.update(extra_attributes)
|
||||||
|
|
||||||
|
# Use request span as parent if available
|
||||||
|
parent_ctx = None
|
||||||
|
if context.request_span:
|
||||||
|
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||||
|
elif context.trace_parent:
|
||||||
|
parent_ctx = self.extract_context(context.trace_parent)
|
||||||
|
if parent_ctx and not context.request_span:
|
||||||
|
# Create a new request span that links to the parent context
|
||||||
|
context.request_span = self.tracer.start_span(
|
||||||
|
"request",
|
||||||
|
context=parent_ctx,
|
||||||
|
attributes={
|
||||||
|
"request_id": context.request_id,
|
||||||
|
"sequence_number": context.sequence_number
|
||||||
|
}
|
||||||
|
)
|
||||||
|
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||||
|
elif context.current_span:
|
||||||
|
parent_ctx = trace.set_span_in_context(context.current_span)
|
||||||
|
|
||||||
|
# Create span with parent context if it exists
|
||||||
|
if parent_ctx:
|
||||||
|
span = self.tracer.start_span(
|
||||||
|
name,
|
||||||
|
context=parent_ctx,
|
||||||
|
attributes=attributes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
span = self.tracer.start_span(
|
||||||
|
name,
|
||||||
|
attributes=attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update context with current span
|
||||||
|
prev_span = context.current_span
|
||||||
|
context.current_span = span
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
yield span
|
||||||
|
duration = time.perf_counter() - start_time
|
||||||
|
span.set_attribute("duration_s", duration)
|
||||||
|
span.set_status(Status(StatusCode.OK))
|
||||||
|
except Exception as e:
|
||||||
|
span.set_status(Status(StatusCode.ERROR, str(e)))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
span.end()
|
||||||
|
context.current_span = prev_span
|
||||||
|
|
||||||
|
# Global tracer instance
|
||||||
|
tracer = Tracer()
|
||||||
@@ -197,7 +197,25 @@
|
|||||||
const div = document.createElement('div');
|
const div = document.createElement('div');
|
||||||
div.className = `message message-role-${role}`;
|
div.className = `message message-role-${role}`;
|
||||||
try {
|
try {
|
||||||
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
|
if (content.includes('![Generated Image]')) {
|
||||||
|
const imageUrl = content.match(/\((.*?)\)/)[1];
|
||||||
|
const img = document.createElement('img');
|
||||||
|
img.src = imageUrl;
|
||||||
|
img.alt = 'Generated Image';
|
||||||
|
img.onclick = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch(img.src);
|
||||||
|
const blob = await response.blob();
|
||||||
|
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||||
|
handleImageUpload({ target: { files: [file] } });
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching image:', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
div.appendChild(img);
|
||||||
|
} else {
|
||||||
|
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
|
||||||
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log(content);
|
console.log(content);
|
||||||
console.error(e);
|
console.error(e);
|
||||||
@@ -281,7 +299,7 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="input">
|
<div class="input">
|
||||||
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
|
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
|
||||||
<i class="fas fa-image"></i>
|
<i class="fas fa-image"></i>
|
||||||
</button>
|
</button>
|
||||||
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
|
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
|
||||||
|
|||||||
@@ -231,53 +231,110 @@ document.addEventListener("alpine:init", () => {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
|
|
||||||
if (containsImage) {
|
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
|
||||||
// Map all messages with string content to object with type text
|
// Send a request to the image generation endpoint
|
||||||
apiMessages = apiMessages.map(msg => {
|
console.log(apiMessages[apiMessages.length - 1].content)
|
||||||
if (typeof msg.content === 'string') {
|
console.log(this.cstate.selectedModel)
|
||||||
return {
|
console.log(this.endpoint)
|
||||||
...msg,
|
const response = await fetch(`${this.endpoint}/image/generations`, {
|
||||||
content: [
|
method: "POST",
|
||||||
{
|
headers: {
|
||||||
type: "text",
|
"Content-Type": "application/json",
|
||||||
text: msg.content
|
},
|
||||||
}
|
body: JSON.stringify({
|
||||||
]
|
"model": 'stable-diffusion-2-1-base',
|
||||||
};
|
"prompt": apiMessages[apiMessages.length - 1].content,
|
||||||
}
|
"image_url": this.imageUrl
|
||||||
return msg;
|
}),
|
||||||
});
|
});
|
||||||
}
|
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
// start receiving server sent events
|
throw new Error("Failed to fetch");
|
||||||
let gottenFirstChunk = false;
|
|
||||||
for await (
|
|
||||||
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
|
|
||||||
) {
|
|
||||||
if (!gottenFirstChunk) {
|
|
||||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
|
||||||
gottenFirstChunk = true;
|
|
||||||
}
|
}
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
let done = false;
|
||||||
|
let gottenFirstChunk = false;
|
||||||
|
|
||||||
// add chunk to the last message
|
while (!done) {
|
||||||
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
|
const { value, done: readerDone } = await reader.read();
|
||||||
|
done = readerDone;
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
|
||||||
// calculate performance tracking
|
if (value) {
|
||||||
tokens += 1;
|
// Assume non-binary data (text) comes first
|
||||||
this.total_tokens += 1;
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
if (start_time === 0) {
|
const parsed = JSON.parse(chunk);
|
||||||
start_time = Date.now();
|
console.log(parsed)
|
||||||
this.time_till_first = start_time - prefill_start;
|
|
||||||
} else {
|
if (parsed.progress) {
|
||||||
const diff = Date.now() - start_time;
|
if (!gottenFirstChunk) {
|
||||||
if (diff > 0) {
|
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||||
this.tokens_per_second = tokens / (diff / 1000);
|
gottenFirstChunk = true;
|
||||||
|
}
|
||||||
|
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
|
||||||
|
}
|
||||||
|
else if (parsed.images) {
|
||||||
|
if (!gottenFirstChunk) {
|
||||||
|
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||||
|
gottenFirstChunk = true;
|
||||||
|
}
|
||||||
|
const imageUrl = parsed.images[0].url;
|
||||||
|
console.log(imageUrl)
|
||||||
|
this.cstate.messages[this.cstate.messages.length - 1].content = `})`;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
else{
|
||||||
|
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
|
||||||
|
if (containsImage) {
|
||||||
|
// Map all messages with string content to object with type text
|
||||||
|
apiMessages = apiMessages.map(msg => {
|
||||||
|
if (typeof msg.content === 'string') {
|
||||||
|
return {
|
||||||
|
...msg,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: msg.content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(apiMessages)
|
||||||
|
//start receiving server sent events
|
||||||
|
let gottenFirstChunk = false;
|
||||||
|
for await (
|
||||||
|
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
|
||||||
|
) {
|
||||||
|
if (!gottenFirstChunk) {
|
||||||
|
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||||
|
gottenFirstChunk = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add chunk to the last message
|
||||||
|
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
|
||||||
|
|
||||||
|
// calculate performance tracking
|
||||||
|
tokens += 1;
|
||||||
|
this.total_tokens += 1;
|
||||||
|
if (start_time === 0) {
|
||||||
|
start_time = Date.now();
|
||||||
|
this.time_till_first = start_time - prefill_start;
|
||||||
|
} else {
|
||||||
|
const diff = Date.now() - start_time;
|
||||||
|
if (diff > 0) {
|
||||||
|
this.tokens_per_second = tokens / (diff / 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// Clean the cstate before adding it to histories
|
// Clean the cstate before adding it to histories
|
||||||
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
|
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
|
||||||
cleanedCstate.messages = cleanedCstate.messages.map(msg => {
|
cleanedCstate.messages = cleanedCstate.messages.map(msg => {
|
||||||
|
|||||||
@@ -91,25 +91,70 @@ class TopologyViz:
|
|||||||
content = []
|
content = []
|
||||||
requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
|
requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
|
||||||
max_width = self.console.width - 6 # Full width minus padding and icon
|
max_width = self.console.width - 6 # Full width minus padding and icon
|
||||||
max_lines = 13 # Maximum number of lines for the entire panel content
|
|
||||||
|
# Calculate available height for content
|
||||||
|
panel_height = 15 # Fixed panel height
|
||||||
|
available_lines = panel_height - 2 # Subtract 2 for panel borders
|
||||||
|
lines_per_entry = available_lines // len(requests) if requests else 0
|
||||||
|
|
||||||
for (prompt, output) in reversed(requests):
|
for (prompt, output) in reversed(requests):
|
||||||
prompt_icon, output_icon = "💬️", "🤖"
|
prompt_icon, output_icon = "💬️", "🤖"
|
||||||
|
|
||||||
# Process prompt
|
# Calculate max lines for prompt and output
|
||||||
prompt_lines = prompt.split('\n')
|
max_prompt_lines = lines_per_entry // 3 # Allocate 1/3 for prompt
|
||||||
if len(prompt_lines) > max_lines // 2:
|
max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing
|
||||||
prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
|
|
||||||
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
# Process prompt
|
||||||
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
prompt_lines = []
|
||||||
|
for line in prompt.split('\n'):
|
||||||
|
words = line.split()
|
||||||
|
current_line = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
if current_length + len(word) + 1 <= max_width:
|
||||||
|
current_line.append(word)
|
||||||
|
current_length += len(word) + 1
|
||||||
|
else:
|
||||||
|
if current_line:
|
||||||
|
prompt_lines.append(' '.join(current_line))
|
||||||
|
current_line = [word]
|
||||||
|
current_length = len(word)
|
||||||
|
|
||||||
|
if current_line:
|
||||||
|
prompt_lines.append(' '.join(current_line))
|
||||||
|
|
||||||
|
if len(prompt_lines) > max_prompt_lines:
|
||||||
|
prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
|
||||||
|
|
||||||
|
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
||||||
|
prompt_text.append('\n'.join(prompt_lines), style="white")
|
||||||
|
|
||||||
|
# Process output - same word-aware wrapping
|
||||||
|
output_lines = []
|
||||||
|
for line in output.split('\n'):
|
||||||
|
words = line.split()
|
||||||
|
current_line = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
if current_length + len(word) + 1 <= max_width:
|
||||||
|
current_line.append(word)
|
||||||
|
current_length += len(word) + 1
|
||||||
|
else:
|
||||||
|
if current_line:
|
||||||
|
output_lines.append(' '.join(current_line))
|
||||||
|
current_line = [word]
|
||||||
|
current_length = len(word)
|
||||||
|
|
||||||
|
if current_line:
|
||||||
|
output_lines.append(' '.join(current_line))
|
||||||
|
|
||||||
|
if len(output_lines) > max_output_lines:
|
||||||
|
output_lines = output_lines[:max_output_lines - 1] + ['...']
|
||||||
|
|
||||||
# Process output
|
|
||||||
output_lines = output.split('\n')
|
|
||||||
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
|
|
||||||
if len(output_lines) > remaining_lines:
|
|
||||||
output_lines = output_lines[:remaining_lines - 1] + ['...']
|
|
||||||
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
||||||
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
|
output_text.append('\n'.join(output_lines), style="white")
|
||||||
|
|
||||||
content.append(prompt_text)
|
content.append(prompt_text)
|
||||||
content.append(output_text)
|
content.append(output_text)
|
||||||
@@ -119,8 +164,8 @@ class TopologyViz:
|
|||||||
Group(*content),
|
Group(*content),
|
||||||
title="",
|
title="",
|
||||||
border_style="cyan",
|
border_style="cyan",
|
||||||
height=15, # Increased height to accommodate multiple lines
|
height=panel_height,
|
||||||
expand=True # Allow the panel to expand to full width
|
expand=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_main_layout(self) -> str:
|
def _generate_main_layout(self) -> str:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
if command -v python3.12 &>/dev/null; then
|
if command -v python3.12 &>/dev/null; then
|
||||||
echo "Python 3.12 is installed, proceeding with python3.12..."
|
echo "Python 3.12 is installed, proceeding with python3.12..."
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
source ./install.sh
|
source ./install.sh
|
||||||
pushd exo/networking/grpc
|
pushd exo/networking/grpc
|
||||||
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
|
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
echo "Starting node 1"
|
echo "Starting node 1"
|
||||||
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
|
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
|
|||||||
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
|
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
|
||||||
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
|
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
|
||||||
|
|
||||||
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
|
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
|
||||||
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
|
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
|
||||||
models = []
|
models = []
|
||||||
for model_id in model_cards:
|
for model_id in model_cards:
|
||||||
|
|||||||
Reference in New Issue
Block a user