Merge pull request #474 from pranav4501/stable-stable-diffusion-mlx

Stable diffusion mlx
Authored by Alex Cheema on 2025-01-12 02:57:21 +00:00; committed via GitHub.
24 changed files with 2274 additions and 212 deletions

.gitignore (vendored) — 2 changed lines

@@ -171,3 +171,5 @@ cython_debug/
**/*.xcodeproj/*
.aider*
exo/tinychat/images/*.png


@@ -12,11 +12,17 @@ import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import mlx.core as mx
import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -185,6 +191,7 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -195,10 +202,12 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
@@ -457,6 +466,85 @@ class ChatGPTAPI:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_post_image_generations(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
model = data.get("model", "")
prompt = data.get("prompt", "")
image_url = data.get("image_url", "")
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
try:
if image_url != "" and image_url != None:
img = self.base64_decode(image_url)
else:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step) / total_steps
# Calculate the number of hashes to display
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
spaces = ' ' * (bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder / image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
@@ -598,3 +686,19 @@ class ChatGPTAPI:
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
img = img[None]
return img
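
For context, a minimal client-side sketch of the new endpoint. The host and port (localhost:52415) and the availability of stable-diffusion-2-1-base on the node are assumptions; the response is a stream of newline-delimited JSON events carrying either a progress string or the final image URL, as produced by stream_image above.

import json
import requests  # assumed to be installed; any HTTP client that can stream works

resp = requests.post(
    "http://localhost:52415/v1/image/generations",
    json={"model": "stable-diffusion-2-1-base", "prompt": "a watercolor fox"},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    event = json.loads(line)
    if "progress" in event:
        print(event["progress"])                      # e.g. Progress: [----->   ] 10% (5/50)
    elif "images" in event:
        print("image ready:", event["images"][0]["url"])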


@@ -303,6 +303,10 @@ async def download_repo_files(
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
if model_index_exists:
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)


@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
print(f"No snapshot directory found for {self.current_repo_id}")
return None
if not await aios.path.exists(snapshot_dir/"model_index.json"):
# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
else:
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
# Check download status for all relevant files
status = {}


@@ -325,4 +325,23 @@ async def shutdown(signal, loop, server):
def is_frozen():
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
def get_exo_home() -> Path:
if os.name == "nt": # Check if the OS is Windows
docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else:
docs_folder = Path.home() / "Documents"
exo_folder = docs_folder / "Exo"
if not exo_folder.exists():
exo_folder.mkdir()
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
if not images_dir.exists():
images_dir.mkdir()
return images_dir
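
A quick sanity check of where generated images end up (a sketch; the exact path depends on the user and OS):

from exo.helpers import get_exo_images_dir

print(get_exo_images_dir())
# e.g. /Users/alice/Documents/Exo/Images on macOS/Linux,
#      C:\Users\alice\Documents\Exo\Images on Windows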


@@ -39,11 +39,15 @@ class InferenceEngine(ABC):
async def clear_session(self):
self.session.empty()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
tokens = await self.encode(shard, prompt)
x = tokens.reshape(1, -1)
output_data = await self.infer_tensor(request_id, shard, x)
return output_data
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
else:
x = tokens
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",


@@ -0,0 +1,307 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
import time
from typing import Optional, Tuple
import inspect
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from tqdm import tqdm
from .sd_models.vae import ModelArgs as VAEArgs
from .sd_models.vae import Autoencoder
from .sd_models.tokenizer import load_tokenizer
from .sd_models.clip import CLIPTextModel
from .sd_models.clip import ModelArgs as CLIPArgs
from .sd_models.unet import UNetConfig, UNetModel
from dataclasses import dataclass, field
from exo.inference.shard import Shard
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
#Sampler
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def current_timestep(self, step, total_steps, start_time=None):
if step < total_steps:
steps = self.timesteps(total_steps, start_time)
return steps[step]
else:
return mx.array(0),mx.array(0)
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
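
For orientation, a minimal sketch of how this sampler drives a denoising loop; eps_model below is a stand-in for any noise predictor (in the real pipeline that role is played by the sharded UNet) and is not part of this file.

import mlx.core as mx

def eps_model(x, t):
    # hypothetical noise predictor, used only to make the sketch runnable
    return mx.zeros(x.shape, dtype=x.dtype)

sampler = SimpleEulerSampler(DiffusionConfig())               # scaled_linear schedule, 1000 train steps
x_t = sampler.sample_prior((1, 64, 64, 4), dtype=mx.float32)  # latent-space noise
for t, t_prev in sampler.timesteps(num_steps=50):
    x_t = sampler.step(eps_model(x_t, t), x_t, t, t_prev)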
@dataclass
class ShardConfig:
model_id:str
start_layer:int
end_layer:int
n_layers:int
@dataclass
class StableDiffusionConfig:
model_type:str
vae:VAEArgs
text_encoder:CLIPArgs
scheduler:DiffusionConfig
unet:UNetConfig
shard:ShardConfig
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(StableDiffusionConfig):
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.model_type = config.model_type
self.config = config
self.model_path = config.vae['path'].split('/vae')[0]
self.shard = config.shard
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
if self.shard_clip.start_layer != -1:
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
else:
self.text_encoder = nn.Identity()
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
self.sampler = SimpleEulerSampler(self.diffusion_config)
if self.shard_unet.start_layer!=-1:
self.config_unet = UNetConfig.from_dict(config.unet['config'])
self.unet = UNetModel(self.config_unet, self.shard_unet)
else:
self.unet = nn.Identity()
self.config_vae=VAEArgs.from_dict(config.vae['config'])
if self.shard_encoder.start_layer != -1:
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
else:
self.encoder = nn.Identity()
if self.shard_decoder.start_layer != -1:
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
else:
self.decoder = nn.Identity()
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
is_finished = False
is_step_finished = False
if t.item()==1000:
if self.shard_clip.start_layer == 0:
conditioning = x
if self.shard_clip.start_layer != -1:
conditioning, mask= self.text_encoder(conditioning,mask)
seed = int(time.time())
mx.random.seed(seed)
if image is None:
if self.shard_encoder.is_last_layer():
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
x_t_prev=x
start_step = self.sampler.max_time
else:
if self.shard_encoder.start_layer != -1:
image= self.encoder.encode(image)
if self.shard_encoder.is_last_layer():
start_step = self.sampler.max_time*strength
total_steps = int(total_steps*strength)
image = mx.broadcast_to(image, (1,) + image.shape[1:])
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
image = None
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
# Perform the denoising loop
if self.shard_unet.start_layer != -1:
with tqdm(total=total_steps,initial=step+1) as pbar:
if step<total_steps:
x = x_t_prev
if self.shard_unet.is_first_layer():
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
else:
x_t_unet = x
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
if self.shard_unet.is_last_layer():
if cfg_weight > 1:
eps_text, eps_neg = x.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
x_t_prev=x
mx.eval(x)
if self.shard_decoder.is_last_layer():
is_step_finished=True
if self.shard_decoder.start_layer != -1:
x=self.decoder.decode(x)
if self.shard_decoder.is_last_layer():
x = mx.clip(x / 2 + 0.5, 0, 1)
B, H, W, C = x.shape
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(1 * H, B // 1 * W, C)
x = (x * 255).astype(mx.uint8)
if t_prev.item() ==0:
is_finished=True
mx.eval(x)
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
def load(self):
if self.shard_encoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.encoder.sanitize(vae_weights)
self.encoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_decoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.decoder.sanitize(vae_weights)
self.decoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_clip.start_layer != -1:
clip_weights = mx.load(self.config_clip.weight_files[0])
clip_weights = self.text_encoder.sanitize(clip_weights)
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
if self.shard_unet.start_layer !=-1:
unet_weights = mx.load(self.config_unet.weight_files[0])
unet_weights = self.unet.sanitize(unet_weights)
self.unet.load_weights(list(unet_weights.items()), strict=True)
def model_shards(shard:ShardConfig):
def create_shard(shard, model_ranges):
start_layer = shard.start_layer
end_layer = shard.end_layer
shards = {}
for model_name, (range_start, range_end) in model_ranges.items():
if start_layer < range_end and end_layer >= range_start:
# Calculate the overlap with the model range
overlap_start = max(start_layer, range_start)
overlap_end = min(end_layer, range_end - 1)
# Adjust the layers relative to the model's range
relative_start = overlap_start - range_start
relative_end = overlap_end - range_start
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
else:
# If no overlap, create a zero-layer shard
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
return shards
# Define the ranges for different models
model_ranges = {
'clip': (0, 12),
'vae_encoder':(12,17),
'unet':(17,26),
'vae_decoder': (26, 31)
}
# Call the function and get the shards for all models
shards = create_shard(shard, model_ranges)
# Access individual shards
shard_clip = shards['clip']
shard_encoder = shards['vae_encoder']
shard_unet = shards['unet']
shard_decoder = shards['vae_decoder']
return shard_clip, shard_encoder, shard_unet, shard_decoder
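
To make those ranges concrete, a sketch of how one node's overall shard maps onto the four sub-models, assuming the 31-layer pipeline is split across two nodes and this node holds layers 0-15:

first_half = Shard("stable-diffusion-2-1-base", 0, 15, 31)
shard_clip, shard_encoder, shard_unet, shard_decoder = model_shards(first_half)
# shard_clip    -> Shard('clip', 0, 11, 12)          (all 12 CLIP layers)
# shard_encoder -> Shard('vae_encoder', 0, 3, 5)     (first 4 of 5 encoder layers)
# shard_unet    -> Shard('unet', -1, -1, 9)          (not on this node)
# shard_decoder -> Shard('vae_decoder', -1, -1, 5)   (not on this node)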


@@ -0,0 +1,191 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
import math
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import field, dataclass
from exo.inference.shard import Shard
from exo.inference.mlx.models.base import IdentityBlock
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return ModelArgs(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
weight_files=config.get("weight_files", [])
)
@dataclass
class ModelArgs(CLIPTextModelConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
weight_files: List[str] = field(default_factory=lambda: [])
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
@dataclass
class CLIPOutput:
pooled_output: Optional[mx.array] = None
last_hidden_state: Optional[mx.array] = None
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
super().__init__()
self.shard = shard
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
if self.shard.is_first_layer():
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = []
for i in range(math.ceil(config.num_layers/2)):
if 2*i in self.layers_range:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
else:
self.layers.append(IdentityBlock())
if self.shard.is_last_layer():
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
self.text_projection = nn.Linear(
config.model_dims, config.projection_dim, bias=False
)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def __call__(self, x, mask=None):
# Extract some shapes
if self.shard.is_first_layer():
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
if self.shard.is_last_layer():
x = self.final_layer_norm(x)
return x, mask
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
if "position_ids" in key:
continue
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
if key.startswith("layers."):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if not self.shard.is_first_layer() and "embedding" in key:
continue
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
continue
if not self.shard.is_last_layer() and key.startswith("text_projection"):
continue
sanitized_weights[key] = value
return sanitized_weights


@@ -0,0 +1,131 @@
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
import regex
import json
import glob
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens
def encode(self, prompt):
tokens = [self.tokenize(prompt)]
negative_text = ""
if negative_text is not None:
tokens += [self.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
return tokens
def load_tokenizer(
model_path: str,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)
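
Illustrative usage of the loader (a sketch; the snapshot path is an assumption and must contain tokenizer/vocab.json and tokenizer/merges.txt, matching how the pipeline's Model class calls it):

from pathlib import Path

tok = load_tokenizer(Path("/path/to/stable-diffusion-2-1-base"), "vocab.json", "merges.txt")
tokens = tok.encode("a photo of an astronaut riding a horse")
# encode() returns two padded token lists: [prompt_tokens, negative_prompt_tokens]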


@@ -0,0 +1,629 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
from exo.inference.shard import Shard
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls,config):
n_blocks = len(config['block_out_channels'])
return UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
else config["attention_head_dim"]
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
weight_files=config.get("weight_files", [])
)
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig, shard: Shard):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
if shard.is_first_layer():
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
self.add_time_proj = nn.SinusoidalPositionalEncoding(
config.addition_time_embed_dim,
max_freq=1,
min_freq=math.exp(
-math.log(10000)
+ 2 * math.log(10000) / config.addition_time_embed_dim
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.add_embedding = TimestepEmbedding(
config.projection_class_embeddings_input_dim,
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = []
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
if i in self.layers_range:
self.down_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
)
else:
self.down_blocks.append(nn.Identity())
# Make the middle block
if 4 in self.layers_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
total_items = len(block_channels) - 3
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
self.up_blocks = []
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
i = total_items - rev_i
if rev_i+5 in self.layers_range:
self.up_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
)
else:
self.up_blocks.append(nn.Identity())
if shard.is_last_layer():
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
residuals=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
text_emb, time_ids = text_time
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
emb = mx.concatenate([text_emb, emb], axis=-1)
emb = self.add_embedding(emb)
temb = temb + emb
if self.shard.is_first_layer():
# Preprocess the input
x = self.conv_in(x)
residuals = [x]
# Run the downsampling part of the unet
for i in range(len(self.down_blocks)):
if i in self.layers_range:
x, res = self.down_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
else:
x= self.down_blocks[i](x)
if 4 in self.layers_range:
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for i in range(len(self.up_blocks)):
if i+5 in self.layers_range:
x, _ = self.up_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
else:
x= self.up_blocks[i](x)
# Postprocess the output
if self.shard.is_last_layer():
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)
return x, residuals
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
k1=""
k2=""
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = mx.split(value, 2)
if "conv_shortcut.weight" in key:
value = value.squeeze()
# Transform the weights from 1x1 convs to linear
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if key.startswith("conv_in") :
if 0 not in self.layers_range:
continue
if key.startswith("down_blocks"):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if key.startswith("mid_block"):
if 4 not in self.layers_range:
continue
if key.startswith("up_blocks"):
layer_num = int(key.split(".")[1])
if (layer_num+5) not in self.layers_range:
continue
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
if 8 not in self.layers_range:
continue
if len(k1)>0:
sanitized_weights[k1] = v1
sanitized_weights[k2] = v2
else:
sanitized_weights[key] = value
return sanitized_weights
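
For orientation (an inference from the code above, not a documented layout): the UNet shard uses a 9-slot local layer index where slots 0-3 are the down blocks, 4 is the mid block, and 5-8 are the up blocks, with conv_in tied to slot 0 and conv_norm_out/conv_out to slot 8. A full, unsharded UNet under that convention would be constructed as:

from exo.inference.shard import Shard

full_unet = UNetModel(UNetConfig(), Shard("unet", 0, 8, 9))  # every block on one node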


@@ -0,0 +1,429 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .unet import ResnetBlock2D, upsample_nearest
from dataclasses import dataclass, field
from exo.inference.shard import Shard
from typing import Tuple
import inspect
from ..base import IdentityBlock
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(AutoencoderConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
latent_channels_out: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
layers_range: List[int] = [],
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
):
super().__init__()
self.layers_range = layers_range
self.shard = shard
if self.shard.is_first_layer():
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in self.layers_range:
self.down_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
)
else:
self.down_blocks.append(IdentityBlock())
current_layer += 1
if self.shard.is_last_layer():
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
def __call__(self, x):
if self.shard.is_first_layer():
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
if self.shard.is_last_layer():
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
shard: Shard,
layer_range: List[int],
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.out_channels = out_channels
self.layers_range = layer_range
if 0 in layer_range:
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
if 0 in layer_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in layer_range:
self.up_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
)
else:
self.up_blocks.append(IdentityBlock())
current_layer += 1
if 4 in layer_range:
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
def __call__(self, x):
if 0 in self.layers_range:
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
if 4 in self.layers_range:
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.model_shard = model_shard
if self.model_shard == "vae_encoder":
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
layers_range=self.layers_range,
shard=shard
)
if self.shard.is_last_layer():
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
if self.model_shard == "vae_decoder":
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
shard,
self.layers_range,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
if self.shard.is_first_layer():
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
if self.shard.is_first_layer():
z = z / self.scaling_factor
z=self.post_quant_proj(z)
return self.decoder(z)
def encode(self, x):
x = self.encoder(x)
if self.shard.is_last_layer():
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
x = mean
return x
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
def sanitize(self, weights):
shard = self.shard
layers = self.layers_range
sanitized_weights = {}
for key, value in weights.items():
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "key" in key:
key = key.replace("key", "key_proj")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if "post_quant_conv" in key :
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
if 'decoder' in key and self.model_shard == "vae_decoder":
if key.startswith("decoder.mid_blocks."):
if 0 in layers:
sanitized_weights[key] = value
if "conv_in" in key and 0 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.up_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_norm_out") and 4 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_out") and 4 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_decoder":
if key.startswith("post_quant_proj") and 0 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_encoder":
if key.startswith("encoder."):
if "conv_in" in key and shard.is_first_layer():
sanitized_weights[key] = value
if key.startswith("encoder.down_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_norm_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if key.startswith("quant_proj") and shard.is_last_layer():
sanitized_weights[key] = value
return sanitized_weights


@@ -77,13 +77,17 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id)
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
return output_data
if self.model.model_type != 'StableDiffusionPipeline':
output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
else:
output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
output_data = np.array(output_data)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)


@@ -62,8 +62,16 @@ def _get_classes(config: dict):
def load_config(model_path: Path) -> dict:
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
return config
model_index_path = model_path / "model_index.json"
if model_index_path.exists():
config = load_model_index(model_path, model_index_path)
return config
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
@@ -110,6 +118,24 @@ def load_model_shard(
# Try weight for back-compat
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if config.get("model_index", False):
model.load()
return model
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
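When the config carries "model_index", the pipeline object is expected to fetch its own component weights via model.load(), so the usual safetensors discovery is skipped. A speculative sketch only, assuming the per-component layout produced by load_model_index and a hypothetical load_component hook:

import mlx.core as mx

class PipelineLoadMixin:
    def load(self):
        # self.config["models"] maps component name -> {"config": {...}, "path": ...}
        for name, sub in self.config["models"].items():
            weights = {}
            for wf in sub["config"].get("weight_files", []):
                weights.update(mx.load(wf))
            self.load_component(name, weights)  # hypothetical per-component binding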
@@ -186,6 +200,9 @@ async def load_shard(
processor.eos_token_id = processor.tokenizer.eos_token_id
processor.encode = processor.tokenizer.encode
return model, processor
elif hasattr(model, "tokenizer"):
tokenizer = model.tokenizer
return model, tokenizer
else:
tokenizer = await resolve_tokenizer(model_path)
return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
# loading a combined config for all models in the index
def load_model_index(model_path: Path, model_index_path: Path):
models_config = {}
with open(model_index_path, "r") as f:
model_index = json.load(f)
models_config["model_index"] = True
models_config["model_type"] = model_index["_class_name"]
models_config["models"] = {}
for model in model_index.keys():
model_config_path = glob.glob(str(model_path / model / "*config.json"))
if len(model_config_path)>0:
with open(model_config_path[0], "r") as f:
model_config = { }
model_config["model_type"] = model
model_config["config"] = json.load(f)
model_config["path"] = model_path / model
if model_config["path"]/"*model.safetensors":
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
model_config["path"] = str(model_path / model)
m = {}
m[model] = model_config
models_config.update(m)
return models_config
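The combined config returned above has roughly this shape for a Stable Diffusion snapshot (component names come from model_index.json; file names and trimmed values are illustrative):

# {
#   "model_index": True,
#   "model_type": "StableDiffusionPipeline",
#   "models": {
#     "unet": {"model_type": "unet",
#              "config": {..., "weight_files": [".../unet/diffusion_pytorch_model.safetensors"]},
#              "path": ".../unet"},
#     "vae":  {"model_type": "vae", "config": {...}, "path": ".../vae"},
#     "text_encoder": {"model_type": "text_encoder", "config": {...}, "path": ".../text_encoder"},
#   },
# }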

View File

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
from .losses import length_masked_ce_loss
from collections import OrderedDict
import asyncio
from typing import Optional
Tensor.no_grad = True
# default settings
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
safe_save(state_dict, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
await self.ensure_shard(shard)
def wrap_infer():
x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
self.states[request_id].start += x.shape[1]
return out.realize()
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
return output_data.numpy()
return output_data.numpy(), inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
def step(x, y, l):

View File

@@ -151,7 +151,7 @@ api = ChatGPTAPI(
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
)
def preemptively_start_download(request_id: str, opaque_status: str):

View File

@@ -111,6 +111,8 @@ model_cards = {
# gemma
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# stable diffusion
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
# phi
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
@@ -156,6 +158,7 @@ pretty_name = {
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
}
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:

View File

@@ -11,7 +11,8 @@ from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import mlx.core as mx
class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -71,7 +72,7 @@ class GRPCPeerHandle(PeerHandle):
traceback.print_exc()
return False
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -81,6 +82,7 @@ class GRPCPeerHandle(PeerHandle):
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)
@@ -89,7 +91,7 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -99,6 +101,7 @@ class GRPCPeerHandle(PeerHandle):
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
)
response = await self.stub.SendTensor(request)
@@ -175,9 +178,43 @@ class GRPCPeerHandle(PeerHandle):
return topology
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
result = []
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
await self.stub.SendResult(request)
async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state
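A hedged example of what ends up where during serialization (the key names are assumptions; the method never touches self, so any GRPCPeerHandle instance works):

import mlx.core as mx

state = {
    "x_t": mx.zeros((1, 64, 64, 4)),                 # mx.array -> tensor_data map
    "noise_list": [mx.zeros((4,)), mx.ones((4,))],   # list of arrays -> tensor_list_data map
    "step": 3,
    "total_steps": 50,                               # plain values -> other_data_json
}
proto = peer.serialize_inference_state(state)        # peer: some connected GRPCPeerHandle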

View File

@@ -8,6 +8,8 @@ from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
import mlx.core as mx
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -50,7 +52,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
prompt = request.prompt
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, request_id)
inference_state = self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -65,7 +68,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id
result = await self.node.process_tensor(shard, tensor, request_id)
inference_state = self.deserialize_inference_state(request.inference_state)
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -122,7 +127,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
request_id = request.request_id
result = request.result
is_finished = request.is_finished
img = request.tensor
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
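The client-side counterpart (see send_result in the peer handle above) builds the same message: image pixels travel in the optional tensor field while the repeated int32 result list stays empty. A small sketch with placeholder pixel data:

import numpy as np
from exo.networking.grpc import node_service_pb2

img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image
request = node_service_pb2.SendResultRequest(
    request_id="example-request",              # illustrative id
    result=[],
    tensor=node_service_pb2.Tensor(tensor_data=img.tobytes(), shape=img.shape, dtype=str(img.dtype)),
    is_finished=True,
)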
@@ -135,3 +144,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state
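The inverse path in miniature: raw bytes plus a dtype string and a shape are enough to rebuild an mx.array (values below are arbitrary).

import numpy as np
import mlx.core as mx

raw = np.arange(12, dtype=np.float32)
buf, shape, dtype = raw.tobytes(), (3, 4), "float32"
restored = mx.array(np.frombuffer(buf, dtype=np.dtype(dtype)).reshape(shape))
assert restored.size == 12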

View File

@@ -24,12 +24,14 @@ message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message ExampleRequest {
@@ -61,6 +63,16 @@ message Tensor {
string dtype = 3;
}
message TensorList {
repeated Tensor tensors = 1;
}
message InferenceState {
map<string, Tensor> tensor_data = 1;
map<string, TensorList> tensor_list_data = 2;
string other_data_json = 3;
}
message CollectTopologyRequest {
repeated string visited = 1;
int32 max_depth = 2;
@@ -96,7 +108,8 @@ message DeviceCapabilities {
message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
bool is_finished = 3;
optional Tensor tensor = 3;
bool is_finished = 4;
}
message SendOpaqueStatusRequest {

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from . import node_service_pb2 as node__service__pb2
from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
GRPC_GENERATED_VERSION = '1.68.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=node__service__pb2.Loss.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.HealthCheck = channel.unary_unary(
'/node_service.NodeService/HealthCheck',
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
_registered_method=True)
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
request_deserializer=node__service__pb2.ExampleRequest.FromString,
response_serializer=node__service__pb2.Loss.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
),
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +196,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -223,8 +223,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -250,8 +250,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendExample',
node__service__pb2.ExampleRequest.SerializeToString,
node__service__pb2.Loss.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
options,
channel_credentials,
insecure,
@@ -277,8 +277,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
@@ -304,8 +304,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/CollectTopology',
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
@@ -331,8 +331,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -358,8 +358,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -385,8 +385,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/HealthCheck',
node__service__pb2.HealthCheckRequest.SerializeToString,
node__service__pb2.HealthCheckResponse.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,

View File

@@ -112,37 +112,49 @@ class Node:
shard,
result: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
if shard.model_id != 'stable-diffusion-2-1-base':
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
intermediate_result = self.buffered_token_output[request_id][0]
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
if shard.model_id != 'stable-diffusion-2-1-base':
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
return np.array(self.buffered_token_output[request_id][0])
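For orientation, an illustrative snapshot of inference_state for an in-flight diffusion request; 'step', 'total_steps', 'is_step_finished' and 'is_finished' appear in the code above, while the latents entry is an assumed key name.

inference_state = {
    "step": 12,
    "total_steps": 50,
    "is_step_finished": True,
    "is_finished": False,
    # "x_t": mx.array(...),  # current latents carried between shards and steps
}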
async def process_prompt(
self,
base_shard: Shard,
prompt: str,
request_id: Optional[str] = None,
inference_state: Optional[dict] = {},
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
@@ -160,7 +172,7 @@ class Node:
)
)
start_time = time.perf_counter_ns()
resp = await self._process_prompt(base_shard, prompt, request_id)
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -181,7 +193,7 @@ class Node:
)
return resp
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
@@ -190,12 +202,12 @@ class Node:
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
resp = await self.forward_prompt(shard, prompt, request_id, 0)
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
return None
else:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
ret = await self.process_inference_result(shard, result, request_id)
result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return result
async def enqueue_example(
@@ -340,6 +352,7 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
@@ -358,7 +371,7 @@ class Node:
)
)
start_time = time.perf_counter_ns()
resp = await self._process_tensor(shard, tensor, request_id)
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -383,6 +396,7 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
@@ -391,8 +405,8 @@ class Node:
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
ret = await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
@@ -427,19 +441,20 @@ class Node:
prompt: str,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
if target_id == self.id:
await self.process_prompt(next_shard, prompt, request_id)
await self.process_prompt(next_shard, prompt, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
async def forward_tensor(
self,
@@ -447,19 +462,20 @@ class Node:
tensor: np.ndarray,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
if target_id == self.id:
await self.process_tensor(next_shard, tensor, request_id)
await self.process_tensor(next_shard, tensor, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
@@ -632,3 +648,12 @@ class Node:
@property
def current_topology(self) -> Topology:
return self.topology
def handle_stable_diffusion(self, inference_state, result):
if inference_state['is_step_finished']:
inference_state['step'] += 1
progress = [inference_state['step'], inference_state['total_steps']]
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state

View File

@@ -182,7 +182,25 @@
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
if (content.includes('![Generated Image]')) {
const imageUrl = content.match(/\((.*?)\)/)[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
}
} catch (e) {
console.log(content);
console.error(e);
@@ -266,7 +284,7 @@
</span>
</div>
<div class="input">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
<i class="fas fa-image"></i>
</button>
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
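The script below mirrors the browser-side exchange from Python: a hedged client sketch for the streaming image endpoint (host, port and the one-JSON-object-per-chunk framing are assumptions carried over from the reader above).

import json
import requests

resp = requests.post(
    "http://localhost:52415/v1/image/generations",   # assumed local API address
    json={"model": "stable-diffusion-2-1-base", "prompt": "a watercolor fox", "image_url": ""},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    if not chunk:
        continue
    data = json.loads(chunk)
    if "progress" in data:
        print("progress:", data["progress"])
    elif "images" in data:
        print("image url:", data["images"][0]["url"])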

View File

@@ -228,53 +228,110 @@ document.addEventListener("alpine:init", () => {
};
}
});
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
// Send a request to the image generation endpoint
console.log(apiMessages[apiMessages.length - 1].content)
console.log(this.cstate.selectedModel)
console.log(this.endpoint)
const response = await fetch(`${this.endpoint}/image/generations`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
"model": 'stable-diffusion-2-1-base',
"prompt": apiMessages[apiMessages.length - 1].content,
"image_url": this.imageUrl
}),
});
}
// start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
if (!response.ok) {
throw new Error("Failed to fetch");
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
const reader = response.body.getReader();
let done = false;
let gottenFirstChunk = false;
while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
const decoder = new TextDecoder();
if (value) {
// Assume non-binary data (text) comes first
const chunk = decoder.decode(value, { stream: true });
const parsed = JSON.parse(chunk);
console.log(parsed)
if (parsed.progress) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
}
else if (parsed.images) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
const imageUrl = parsed.images[0].url;
console.log(imageUrl)
this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
}
}
}
}
else{
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
});
}
console.log(apiMessages)
//start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
}
}
}
}
// Clean the cstate before adding it to histories
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
cleanedCstate.messages = cleanedCstate.messages.map(msg => {