Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)
Image to image generation
@@ -20,6 +20,9 @@ from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get
 from typing import Callable, Optional
+from PIL import Image
+import numpy as np
 import base64
 from io import BytesIO
+import mlx.core as mx

 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -383,6 +386,7 @@ class ChatGPTAPI:
     stream = data.get("stream", False)
     model = data.get("model", "")
     prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
     print(f"model: {model}, prompt: {prompt}, stream: {stream}")
     shard = build_base_shard(model, self.inference_engine_classname)
     print(f"shard: {shard}")
@@ -393,7 +397,11 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
     try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
+      if image_url:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)

     response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream', "Cache-Control": "no-cache"})
@@ -454,3 +462,19 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    # decode the data URL and resize the image to dimensions divisible by 64
+    if base64_string.startswith('data:image'):
+      base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+      print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1  # scale RGB to [-1, 1]
+    img = img[None]  # add a batch dimension
+    return img
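A client-side sketch of exercising this path end to end. The endpoint path and port below are assumptions (they are not shown in this diff); the JSON fields are the ones the handler reads.

import base64, requests

# Encode a local file as the data URL that base64_decode above expects.
with open("input.png", "rb") as f:
  data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

# Hypothetical endpoint and port; adjust to wherever ChatGPTAPI is served.
resp = requests.post(
  "http://localhost:52415/v1/image/generations",
  json={"model": "stable-diffusion-2-1-base", "prompt": "a watercolor cat", "image_url": data_url},
  stream=True,  # the handler streams the result back as application/octet-stream
)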
@@ -157,7 +157,7 @@ class Model(nn.Module):
     self.config = config
     self.model_path = config.vae['path'].split('/vae')[0]
     self.shard = config.shard
-    self.shard_clip, self.shard_unet, self.shard_vae = model_shards(config.shard)
+    self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
     self.config_clip = CLIPArgs.from_dict(config.text_encoder['config'])
     if self.shard_clip.start_layer != -1:
       self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
@@ -172,26 +172,41 @@ class Model(nn.Module):
     else:
       self.unet = nn.Identity()
     self.config_vae = VAEArgs.from_dict(config.vae['config'])
-    if self.shard_vae.start_layer != -1:
-      self.first_stage_model = Autoencoder(self.config_vae, self.shard_vae)
+    if self.shard_encoder.start_layer != -1:
+      self.encoder = Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
     else:
-      self.first_stage_model = nn.Identity()
+      self.encoder = nn.Identity()
+    if self.shard_decoder.start_layer != -1:
+      self.decoder = Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
+    else:
+      self.decoder = nn.Identity()

-  def __call__(self, x, step=0, cfg_weight: float = 7.5, total_steps=50, conditioning=None, mask=None, residual=None, x_t_prev=None, is_finished=False, is_step_finished=False):
-    t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps)
+  def __call__(self, x, step=0, cfg_weight: float = 7.5, total_steps=50, conditioning=None, mask=None, residual=None, x_t_prev=None, is_finished=False, is_step_finished=False, image=None, strength=0.7, start_step=None):
+    t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
     is_finished = False
     is_step_finished = False
     if t.item() == 1000:
       if self.shard_clip.start_layer == 0:
         conditioning = x
       if self.shard_clip.start_layer != -1:
         conditioning, mask = self.text_encoder(conditioning, mask)
       seed = int(time.time())
       mx.random.seed(seed)
-      if self.shard_unet.is_first_layer():
-        x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
-        x_t_prev = x
+      if image is None:
+        if self.shard_encoder.is_last_layer():
+          x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
+          x_t_prev = x
+          start_step = self.sampler.max_time
+      else:
+        if self.shard_encoder.start_layer != -1:
+          image = self.encoder.encode(image)
+          if self.shard_encoder.is_last_layer():
+            start_step = self.sampler.max_time*strength
+            total_steps = int(total_steps*strength)
+            image = mx.broadcast_to(image, (1,) + image.shape[1:])
+            x_t_prev = self.sampler.add_noise(image, mx.array(start_step))
+            image = None
+            t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
     # Perform the denoising loop
     if self.shard_unet.start_layer != -1:
       with tqdm(total=total_steps, initial=step+1) as pbar:
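For reference, strength trades input-image fidelity against prompt adherence: denoising starts at start_step = max_time * strength instead of max_time, and only int(total_steps * strength) steps actually run. A standalone sketch of that scheduling; the linear noise schedule below is an illustrative assumption, not necessarily the sampler's exact formula.

import mlx.core as mx

max_time, strength, total_steps = 1000, 0.7, 50
start_step = max_time * strength            # 700: skip the first 30% of the schedule
steps_to_run = int(total_steps * strength)  # 35 of the 50 denoising steps actually run

def add_noise_toy(latent, t, max_time=1000):
  # Toy DDPM-style noising for illustration only.
  alpha = 1 - t / max_time
  return mx.sqrt(mx.array(alpha)) * latent + mx.sqrt(mx.array(1 - alpha)) * mx.random.normal(latent.shape)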
@@ -211,28 +226,32 @@ class Model(nn.Module):
           x_t_prev = x
           mx.eval(x)

-    if self.shard_vae.is_last_layer():
+    if self.shard_decoder.is_last_layer():
       is_step_finished = True
-    if self.shard_vae.start_layer != -1:
-      x = self.first_stage_model.decode(x)
-    if self.shard_vae.is_last_layer():
+    if self.shard_decoder.start_layer != -1:
+      x = self.decoder.decode(x)
+    if self.shard_decoder.is_last_layer():
       x = mx.clip(x / 2 + 0.5, 0, 1)
       x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
       B, H, W, C = x.shape
       x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
       x = x.reshape(1 * H, B // 1 * W, C)
       x = (x * 255).astype(mx.uint8)
       if t_prev.item() == 0:
         is_finished = True
     mx.eval(x)

-    return x, {'conditioning': conditioning, 'mask': mask, 'residual': residual, 'x_t_prev': x_t_prev, 'is_finished': is_finished, 'is_step_finished': is_step_finished, 'step': step, 'total_steps': total_steps}
+    return x, {'conditioning': conditioning, 'mask': mask, 'residual': residual, 'x_t_prev': x_t_prev, 'is_finished': is_finished, 'is_step_finished': is_step_finished, 'step': step, 'total_steps': total_steps, 'start_step': start_step, 'image': image}
   def load(self):
-    if self.shard_vae.start_layer != -1:
+    if self.shard_encoder.start_layer != -1:
       vae_weights = mx.load(self.config_vae.weight_files[0])
-      vae_weights = self.first_stage_model.sanitize(vae_weights)
-      self.first_stage_model.load_weights(list(vae_weights.items()), strict=True)
+      vae_weights = self.encoder.sanitize(vae_weights)
+      self.encoder.load_weights(list(vae_weights.items()), strict=True)
+    if self.shard_decoder.start_layer != -1:
+      vae_weights = mx.load(self.config_vae.weight_files[0])
+      vae_weights = self.decoder.sanitize(vae_weights)
+      self.decoder.load_weights(list(vae_weights.items()), strict=True)
     if self.shard_clip.start_layer != -1:
       clip_weights = mx.load(self.config_clip.weight_files[0])
       clip_weights = self.text_encoder.sanitize(clip_weights)
@@ -242,7 +261,6 @@ class Model(nn.Module):
       unet_weights = self.unet.sanitize(unet_weights)
       self.unet.load_weights(list(unet_weights.items()), strict=True)

 def model_shards(shard: ShardConfig):
   def create_shard(shard, model_ranges):
     start_layer = shard.start_layer
@@ -268,9 +286,10 @@ def model_shards(shard: ShardConfig):

   # Define the ranges for different models
   model_ranges = {
-    'clip': (0, 23),
-    'unet': (23, 32),
-    'vae': (32, 37)  # example range for the VAE
+    'clip': (0, 12),
+    'vae_encoder': (12, 17),
+    'unet': (17, 26),
+    'vae_decoder': (26, 31)  # example range for the VAE decoder
   }

   # Call the function and get the shards for all models
@@ -278,10 +297,11 @@ def model_shards(shard: ShardConfig):

   # Access individual shards
   shard_clip = shards['clip']
+  shard_encoder = shards['vae_encoder']
   shard_unet = shards['unet']
-  shard_vae = shards['vae']
+  shard_decoder = shards['vae_decoder']

-  return shard_clip, shard_unet, shard_vae
+  return shard_clip, shard_encoder, shard_unet, shard_decoder
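The body of create_shard is truncated in this diff; it intersects the node's global layer range with each model range above. A hedged sketch of that logic, consistent with the ranges and shard accesses shown (names are illustrative, not the repo's exact code):

MODEL_RANGES = {'clip': (0, 12), 'vae_encoder': (12, 17), 'unet': (17, 26), 'vae_decoder': (26, 31)}

def split_shard(start_layer: int, end_layer: int) -> dict:
  shards = {}
  for name, (lo, hi) in MODEL_RANGES.items():
    s, e = max(start_layer, lo), min(end_layer, hi - 1)
    if s > e:
      shards[name] = (-1, -1)          # this node holds no layers of this sub-model
    else:
      shards[name] = (s - lo, e - lo)  # re-base to the sub-model's local numbering
  return shards

print(split_shard(0, 30))   # a single node holding all 31 layers
print(split_shard(14, 20))  # a node straddling the encoder/unet boundary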
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py

+import math
 from dataclasses import dataclass
 from typing import List, Optional

@@ -99,13 +100,15 @@ class CLIPTextModel(nn.Module):
     super().__init__()

     self.shard = shard
+    self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
     if self.shard.is_first_layer():
       self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
       self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
     self.layers = []
-    for i in range(config.num_layers):
-      if self.shard.start_layer <= i <= self.shard.end_layer:
+    for i in range(math.ceil(config.num_layers/2)):
+      if 2*i in self.layers_range:
         self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+      if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
+        self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
       else:
         self.layers.append(IdentityBlock())
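Since the model now advertises 31 shard layers while the CLIP text encoder has 23 transformer layers, each shard layer maps to two CLIP layers. A quick illustration of the indexing above:

# A shard covering layers 3..5 owns CLIP encoder layers 6..11 (clipped to num_layers).
start_layer, end_layer, num_layers = 3, 5, 23
layers_range = range(start_layer * 2, end_layer * 2 + 2)
print([i for i in layers_range if i < num_layers])  # [6, 7, 8, 9, 10, 11]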
@@ -136,22 +139,18 @@ class CLIPTextModel(nn.Module):
     # Compute the features from the transformer
     mask = self._get_mask(N, x.dtype)

-    hidden_states = []
     for l in self.layers:
       x = l(x, mask)
-      hidden_states.append(x)

     # Apply the final layernorm and return
     if self.shard.is_last_layer():
       x = self.final_layer_norm(x)
-      last_hidden_state = x

     return x, mask

   def sanitize(self, weights):
     sanitized_weights = {}
     for key, value in weights.items():
       if "position_ids" in key:
         continue
@@ -180,13 +179,13 @@ class CLIPTextModel(nn.Module):

       if key.startswith("layers."):
         layer_num = int(key.split(".")[1])
-        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+        if layer_num not in self.layers_range:
           continue
-      if not self.shard.start_layer == 0 and "embedding" in key:
+      if not self.shard.is_first_layer() and "embedding" in key:
         continue
-      if not self.shard.end_layer == 22 and key.startswith("final_layer_norm"):
+      if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
         continue
-      if not self.shard.end_layer == 22 and key.startswith("text_projection"):
+      if not self.shard.is_last_layer() and key.startswith("text_projection"):
         continue
       sanitized_weights[key] = value
     return sanitized_weights
@@ -128,62 +128,75 @@ class Encoder(nn.Module):
   def __init__(
     self,
     in_channels: int,
-    out_channels: int,
+    latent_channels_out: int,
     block_out_channels: List[int] = [64],
     layers_per_block: int = 2,
     resnet_groups: int = 32,
+    layers_range: List[int] = [],
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
   ):
     super().__init__()

-    self.conv_in = nn.Conv2d(
-      in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
-    )
+    self.layers_range = layers_range
+    self.shard = shard
+    if self.shard.is_first_layer():
+      self.conv_in = nn.Conv2d(
+        in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+      )

     channels = [block_out_channels[0]] + list(block_out_channels)
-    self.down_blocks = [
-      EncoderDecoderBlock2D(
-        in_channels,
-        out_channels,
-        num_layers=layers_per_block,
-        resnet_groups=resnet_groups,
-        add_downsample=i < len(block_out_channels) - 1,
-        add_upsample=False,
-      )
-      for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
-    ]
+    self.down_blocks = []
+    current_layer = 1
+    for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+      if current_layer in self.layers_range:
+        self.down_blocks.append(
+          EncoderDecoderBlock2D(
+            in_channels,
+            out_channels,
+            num_layers=layers_per_block,
+            resnet_groups=resnet_groups,
+            add_downsample=i < len(block_out_channels) - 1,
+            add_upsample=False,
+          )
+        )
+      else:
+        self.down_blocks.append(IdentityBlock())
+      current_layer += 1

-    self.mid_blocks = [
-      ResnetBlock2D(
-        in_channels=block_out_channels[-1],
-        out_channels=block_out_channels[-1],
-        groups=resnet_groups,
-      ),
-      Attention(block_out_channels[-1], resnet_groups),
-      ResnetBlock2D(
-        in_channels=block_out_channels[-1],
-        out_channels=block_out_channels[-1],
-        groups=resnet_groups,
-      ),
-    ]
+    if self.shard.is_last_layer():
+      self.mid_blocks = [
+        ResnetBlock2D(
+          in_channels=block_out_channels[-1],
+          out_channels=block_out_channels[-1],
+          groups=resnet_groups,
+        ),
+        Attention(block_out_channels[-1], resnet_groups),
+        ResnetBlock2D(
+          in_channels=block_out_channels[-1],
+          out_channels=block_out_channels[-1],
+          groups=resnet_groups,
+        ),
+      ]

-    self.conv_norm_out = nn.GroupNorm(
-      resnet_groups, block_out_channels[-1], pytorch_compatible=True
-    )
-    self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
+      self.conv_norm_out = nn.GroupNorm(
+        resnet_groups, block_out_channels[-1], pytorch_compatible=True
+      )
+      self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)

   def __call__(self, x):
-    x = self.conv_in(x)
+    if self.shard.is_first_layer():
+      x = self.conv_in(x)

     for l in self.down_blocks:
       x = l(x)

-    x = self.mid_blocks[0](x)
-    x = self.mid_blocks[1](x)
-    x = self.mid_blocks[2](x)
+    if self.shard.is_last_layer():
+      x = self.mid_blocks[0](x)
+      x = self.mid_blocks[1](x)
+      x = self.mid_blocks[2](x)

-    x = self.conv_norm_out(x)
-    x = nn.silu(x)
-    x = self.conv_out(x)
+      x = self.conv_norm_out(x)
+      x = nn.silu(x)
+      x = self.conv_out(x)

     return x
@@ -271,7 +284,7 @@ class Decoder(nn.Module):
 class Autoencoder(nn.Module):
   """The autoencoder that allows us to perform diffusion in the latent space."""

-  def __init__(self, config: AutoencoderConfig, shard: Shard):
+  def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
     super().__init__()
     self.shard = shard
     self.start_layer = shard.start_layer
@@ -279,46 +292,51 @@ class Autoencoder(nn.Module):
     self.layers_range = list(range(self.start_layer, self.end_layer+1))
     self.latent_channels = config.latent_channels_in
     self.scaling_factor = config.scaling_factor
-    self.decoder_only = True  # stable diffusion text to image only uses the decoder from the autoencoder
-    if not self.decoder_only:
+    self.model_shard = model_shard
+    if self.model_shard == "vae_encoder":
       self.encoder = Encoder(
         config.in_channels,
         config.latent_channels_out,
         config.block_out_channels,
         config.layers_per_block,
         resnet_groups=config.norm_num_groups,
+        layers_range=self.layers_range,
+        shard=shard
       )
-      self.quant_proj = nn.Linear(
-        config.latent_channels_out, config.latent_channels_out
-      )
-    self.decoder = Decoder(
-      config.latent_channels_in,
-      config.out_channels,
-      shard,
-      self.layers_range,
-      config.block_out_channels,
-      config.layers_per_block + 1,
-      resnet_groups=config.norm_num_groups,
-    )
-    if 0 in self.layers_range:
-      self.post_quant_proj = nn.Linear(
-        config.latent_channels_in, config.latent_channels_in
-      )
+      if self.shard.is_last_layer():
+        self.quant_proj = nn.Linear(
+          config.latent_channels_out, config.latent_channels_out
+        )
+    if self.model_shard == "vae_decoder":
+      self.decoder = Decoder(
+        config.latent_channels_in,
+        config.out_channels,
+        shard,
+        self.layers_range,
+        config.block_out_channels,
+        config.layers_per_block + 1,
+        resnet_groups=config.norm_num_groups,
+      )
+      if self.shard.is_first_layer():
+        self.post_quant_proj = nn.Linear(
+          config.latent_channels_in, config.latent_channels_in
+        )

   def decode(self, z):
-    if 0 in self.layers_range:
+    if self.shard.is_first_layer():
       z = z / self.scaling_factor
       z = self.post_quant_proj(z)
     return self.decoder(z)

   def encode(self, x):
     x = self.encoder(x)
-    x = self.quant_proj(x)
-    mean, logvar = x.split(2, axis=-1)
-    mean = mean * self.scaling_factor
-    logvar = logvar + 2 * math.log(self.scaling_factor)
-
-    return mean, logvar
+    if self.shard.is_last_layer():
+      x = self.quant_proj(x)
+      mean, logvar = x.split(2, axis=-1)
+      mean = mean * self.scaling_factor
+      logvar = logvar + 2 * math.log(self.scaling_factor)
+      x = mean
+    return x

   def __call__(self, x, key=None):
     mean, logvar = self.encode(x)
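Note that the sharded encode() now returns a deterministic latent (the scaled posterior mean) instead of the (mean, logvar) pair. An equivalent non-sharded sketch of that entry point, with names mirroring the code above:

def encode_deterministic(quant_proj, h, scaling_factor):
  x = quant_proj(h)                    # project features to 2 * latent_channels
  mean, logvar = x.split(2, axis=-1)   # diagonal-Gaussian parameters
  return mean * scaling_factor         # take the mean; no reparameterized sampling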
@@ -328,46 +346,53 @@ class Autoencoder(nn.Module):
     return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)

   def sanitize(self, weights):
+    shard = self.shard
+    layers = self.layers_range
     sanitized_weights = {}
     for key, value in weights.items():
-      if 'decoder' in key and self.decoder_only:
-        if "downsamplers" in key:
-          key = key.replace("downsamplers.0.conv", "downsample")
-        if "upsamplers" in key:
-          key = key.replace("upsamplers.0.conv", "upsample")
-
-        # Map attention layers
-        if "key" in key:
-          key = key.replace("key", "key_proj")
-        if "proj_attn" in key:
-          key = key.replace("proj_attn", "out_proj")
-        if "query" in key:
-          key = key.replace("query", "query_proj")
-        if "value" in key:
-          key = key.replace("value", "value_proj")
-
-        # Map the mid block
-        if "mid_block.resnets.0" in key:
-          key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-        if "mid_block.attentions.0" in key:
-          key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-        if "mid_block.resnets.1" in key:
-          key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-        # Map the quant/post_quant layers
-        if "quant_conv" in key:
-          key = key.replace("quant_conv", "quant_proj")
-          value = value.squeeze()
-
-        # Map the conv_shortcut to linear
-        if "conv_shortcut.weight" in key:
-          value = value.squeeze()
-
-        if len(value.shape) == 4:
-          value = value.transpose(0, 2, 3, 1)
-          value = value.reshape(-1).reshape(value.shape)
+      if "downsamplers" in key:
+        key = key.replace("downsamplers.0.conv", "downsample")
+      if "upsamplers" in key:
+        key = key.replace("upsamplers.0.conv", "upsample")
+
+      # Map attention layers
+      if "key" in key:
+        key = key.replace("key", "key_proj")
+      if "proj_attn" in key:
+        key = key.replace("proj_attn", "out_proj")
+      if "query" in key:
+        key = key.replace("query", "query_proj")
+      if "value" in key:
+        key = key.replace("value", "value_proj")
+
+      # Map the mid block
+      if "mid_block.resnets.0" in key:
+        key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+      if "mid_block.attentions.0" in key:
+        key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+      if "mid_block.resnets.1" in key:
+        key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+
+      # Map the quant/post_quant layers
+      if "quant_conv" in key:
+        key = key.replace("quant_conv", "quant_proj")
+        value = value.squeeze()
+
+      # Map the conv_shortcut to linear
+      if "conv_shortcut.weight" in key:
+        value = value.squeeze()
+
+      if len(value.shape) == 4:
+        value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
+
+      if "post_quant_conv" in key:
+        key = key.replace("quant_conv", "quant_proj")
+        value = value.squeeze()
+
+      if 'decoder' in key and self.model_shard == "vae_decoder":
+        if key.startswith("decoder.mid_blocks."):
+          if 0 in layers:
+            sanitized_weights[key] = value
@@ -381,10 +406,24 @@ class Autoencoder(nn.Module):
           sanitized_weights[key] = value
         if key.startswith("decoder.conv_out") and 4 in layers:
           sanitized_weights[key] = value

-      if "post_quant_conv" in key and 0 in layers:
-        key = key.replace("quant_conv", "quant_proj")
-        value = value.squeeze()
-        sanitized_weights[key] = value
+      if self.model_shard == "vae_decoder":
+        if key.startswith("post_quant_proj") and 0 in layers:
+          sanitized_weights[key] = value
+      if self.model_shard == "vae_encoder":
+        if key.startswith("encoder."):
+          if "conv_in" in key and shard.is_first_layer():
+            sanitized_weights[key] = value
+          if key.startswith("encoder.down_blocks."):
+            layer_num = int(key.split(".")[2]) + 1
+            if layer_num in layers:
+              sanitized_weights[key] = value
+          if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
+            sanitized_weights[key] = value
+          if "conv_norm_out" in key and shard.is_last_layer():
+            sanitized_weights[key] = value
+          if "conv_out" in key and shard.is_last_layer():
+            sanitized_weights[key] = value
+        if key.startswith("quant_proj") and shard.is_last_layer():
+          sanitized_weights[key] = value
     return sanitized_weights
@@ -81,7 +81,7 @@ model_cards = {
   "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # stable diffusion
-  "stable-diffusion-2-1-base": { "layers": 37, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
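The drop from 37 to 31 layers matches the new model_ranges: 12 CLIP + 5 VAE-encoder + 9 UNet + 5 VAE-decoder. A one-line check:

ranges = {'clip': (0, 12), 'vae_encoder': (12, 17), 'unet': (17, 26), 'vae_decoder': (26, 31)}
assert sum(hi - lo for lo, hi in ranges.values()) == 31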
@@ -68,7 +68,7 @@ class GRPCPeerHandle(PeerHandle):
       traceback.print_exc()
       return False

-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -78,6 +78,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendPrompt(request)
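serialize_inference_state is called above but its body is not shown in this diff. A minimal sketch consistent with the InferenceState message (tensor_data, tensor_list_data, other_data_json); the function body is an assumption, not the repo's code:

import json
import numpy as np
from exo.networking.grpc import node_service_pb2

def serialize_inference_state(state: dict) -> node_service_pb2.InferenceState:
  proto = node_service_pb2.InferenceState()
  other = {}
  for k, v in (state or {}).items():
    if isinstance(v, np.ndarray):
      # Arrays travel as Tensor messages in the tensor_data map.
      proto.tensor_data[k].CopyFrom(node_service_pb2.Tensor(
        tensor_data=v.tobytes(), shape=list(v.shape), dtype=str(v.dtype)))
    else:
      other[k] = v  # plain values (step counters, flags) ride along as JSON
  proto.other_data_json = json.dumps(other)
  return proto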
@@ -52,7 +52,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
+    inference_state = self.deserialize_inference_state(request.inference_state)
+    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
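The matching deserialize_inference_state is likewise not shown; a hedged inverse of the sketch above:

import json
import numpy as np

def deserialize_inference_state(proto) -> dict:
  state = {}
  for k, t in proto.tensor_data.items():
    # Rebuild each array from raw bytes, the dtype string, and the shape.
    state[k] = np.frombuffer(t.tensor_data, dtype=np.dtype(t.dtype)).reshape(tuple(t.shape))
  if proto.other_data_json:
    state.update(json.loads(proto.other_data_json))
  return state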
@@ -23,6 +23,7 @@ message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }

 message TensorRequest {
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
-# source: exo/networking/grpc/node_service.proto
+# source: node_service.proto
 # Protobuf Python Version: 5.26.1
 """Generated protocol buffer code."""
 from google.protobuf import descriptor as _descriptor
@@ -14,11 +14,11 @@ _sym_db = _symbol_database.Default()

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&exo/networking/grpc/node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 
\x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 
\x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'exo.networking.grpc.node_service_pb2', _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
   _globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
@@ -29,50 +29,50 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
   _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
   _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
-  _globals['_SHARD']._serialized_start=56
-  _globals['_SHARD']._serialized_end=139
-  _globals['_PROMPTREQUEST']._serialized_start=141
-  _globals['_PROMPTREQUEST']._serialized_end=248
-  _globals['_TENSORREQUEST']._serialized_start=251
-  _globals['_TENSORREQUEST']._serialized_end=460
-  _globals['_GETINFERENCERESULTREQUEST']._serialized_start=462
-  _globals['_GETINFERENCERESULTREQUEST']._serialized_end=509
-  _globals['_INFERENCERESULT']._serialized_start=511
-  _globals['_INFERENCERESULT']._serialized_end=603
-  _globals['_TENSOR']._serialized_start=605
-  _globals['_TENSOR']._serialized_end=664
-  _globals['_TENSORLIST']._serialized_start=666
-  _globals['_TENSORLIST']._serialized_end=717
-  _globals['_INFERENCESTATE']._serialized_start=720
-  _globals['_INFERENCESTATE']._serialized_end=1058
-  _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=906
-  _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=977
-  _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=979
-  _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1058
-  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1060
-  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1120
-  _globals['_TOPOLOGY']._serialized_start=1123
-  _globals['_TOPOLOGY']._serialized_end=1393
-  _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1244
-  _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1322
-  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1324
-  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1393
-  _globals['_PEERS']._serialized_start=1395
-  _globals['_PEERS']._serialized_end=1420
-  _globals['_DEVICEFLOPS']._serialized_start=1422
-  _globals['_DEVICEFLOPS']._serialized_end=1477
-  _globals['_DEVICECAPABILITIES']._serialized_start=1479
-  _globals['_DEVICECAPABILITIES']._serialized_end=1586
-  _globals['_SENDRESULTREQUEST']._serialized_start=1589
-  _globals['_SENDRESULTREQUEST']._serialized_end=1719
-  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1721
-  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1782
-  _globals['_HEALTHCHECKREQUEST']._serialized_start=1784
-  _globals['_HEALTHCHECKREQUEST']._serialized_end=1804
-  _globals['_HEALTHCHECKRESPONSE']._serialized_start=1806
-  _globals['_HEALTHCHECKRESPONSE']._serialized_end=1847
-  _globals['_EMPTY']._serialized_start=1849
-  _globals['_EMPTY']._serialized_end=1856
-  _globals['_NODESERVICE']._serialized_start=1859
-  _globals['_NODESERVICE']._serialized_end=2423
+  _globals['_SHARD']._serialized_start=36
+  _globals['_SHARD']._serialized_end=119
+  _globals['_PROMPTREQUEST']._serialized_start=122
+  _globals['_PROMPTREQUEST']._serialized_end=309
+  _globals['_TENSORREQUEST']._serialized_start=312
+  _globals['_TENSORREQUEST']._serialized_end=521
+  _globals['_GETINFERENCERESULTREQUEST']._serialized_start=523
+  _globals['_GETINFERENCERESULTREQUEST']._serialized_end=570
+  _globals['_INFERENCERESULT']._serialized_start=572
+  _globals['_INFERENCERESULT']._serialized_end=664
+  _globals['_TENSOR']._serialized_start=666
+  _globals['_TENSOR']._serialized_end=725
+  _globals['_TENSORLIST']._serialized_start=727
+  _globals['_TENSORLIST']._serialized_end=778
+  _globals['_INFERENCESTATE']._serialized_start=781
+  _globals['_INFERENCESTATE']._serialized_end=1119
+  _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=967
+  _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1038
+  _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1040
+  _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1119
+  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1121
+  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1181
+  _globals['_TOPOLOGY']._serialized_start=1184
+  _globals['_TOPOLOGY']._serialized_end=1454
+  _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1305
+  _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1383
+  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1385
+  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1454
+  _globals['_PEERS']._serialized_start=1456
+  _globals['_PEERS']._serialized_end=1481
+  _globals['_DEVICEFLOPS']._serialized_start=1483
+  _globals['_DEVICEFLOPS']._serialized_end=1538
+  _globals['_DEVICECAPABILITIES']._serialized_start=1540
+  _globals['_DEVICECAPABILITIES']._serialized_end=1647
+  _globals['_SENDRESULTREQUEST']._serialized_start=1650
+  _globals['_SENDRESULTREQUEST']._serialized_end=1780
+  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1782
+  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1843
+  _globals['_HEALTHCHECKREQUEST']._serialized_start=1845
+  _globals['_HEALTHCHECKREQUEST']._serialized_end=1865
+  _globals['_HEALTHCHECKRESPONSE']._serialized_start=1867
+  _globals['_HEALTHCHECKRESPONSE']._serialized_end=1908
+  _globals['_EMPTY']._serialized_start=1910
+  _globals['_EMPTY']._serialized_end=1917
+  _globals['_NODESERVICE']._serialized_start=1920
+  _globals['_NODESERVICE']._serialized_end=2484
 # @@protoc_insertion_point(module_scope)
@@ -3,7 +3,7 @@
 import grpc
 import warnings

-from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
+from exo.networking.grpc import node_service_pb2 as node__service__pb2

 GRPC_GENERATED_VERSION = '1.64.1'
 GRPC_VERSION = grpc.__version__
@@ -20,7 +20,7 @@ except ImportError:
 if _version_not_supported:
     warnings.warn(
         f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -41,38 +41,38 @@ class NodeServiceStub(object):
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
                 '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.HealthCheck = channel.unary_unary(
                 '/node_service.NodeService/HealthCheck',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
                 _registered_method=True)
@@ -126,38 +126,38 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'SendTensor': grpc.unary_unary_rpc_method_handler(
                    servicer.SendTensor,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                    servicer.GetInferenceResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
            ),
            'CollectTopology': grpc.unary_unary_rpc_method_handler(
                    servicer.CollectTopology,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
            ),
            'SendResult': grpc.unary_unary_rpc_method_handler(
                    servicer.SendResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                    servicer.SendOpaqueStatus,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'HealthCheck': grpc.unary_unary_rpc_method_handler(
                    servicer.HealthCheck,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
@@ -185,8 +185,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/SendPrompt',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
@@ -212,8 +212,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/SendTensor',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
@@ -239,8 +239,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/GetInferenceResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
            options,
            channel_credentials,
            insecure,
@@ -266,8 +266,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/CollectTopology',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
            options,
            channel_credentials,
            insecure,
@@ -293,8 +293,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/SendResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
@@ -320,8 +320,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/SendOpaqueStatus',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
@@ -347,8 +347,8 @@ class NodeService(object):
            request,
            target,
            '/node_service.NodeService/HealthCheck',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
@@ -16,11 +16,11 @@ class Node(ABC):
     pass

   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     pass

   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     pass

   @abstractmethod
@@ -190,7 +190,7 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
       return None
     else:
       result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
@@ -268,6 +268,7 @@ class StandardNode(Node):
     prompt: str,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
@@ -280,7 +281,7 @@ class StandardNode(Node):
     if not target_peer:
       raise ValueError(f"Peer for {target_index} not found")
     if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-    await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+    await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)

   async def forward_tensor(
     self,
Binary file not shown. (Before: 407 KiB)
@@ -1,3 +0,0 @@
-# images dir
-
-Images generated in tinychat are stored and served from here.
@@ -120,6 +120,16 @@
           const img = document.createElement('img');
           img.src = imageUrl;
           img.alt = 'Generated Image';
+          img.onclick = async () => {
+            try {
+              const response = await fetch(img.src);
+              const blob = await response.blob();
+              const file = new File([blob], 'image.png', { type: 'image/png' });
+              handleImageUpload({ target: { files: [file] } });
+            } catch (error) {
+              console.error('Error fetching image:', error);
+            }
+          };
           div.appendChild(img);
         } else {
           div.innerHTML = DOMPurify.sanitize(marked.parse(content));
@@ -207,7 +217,7 @@
         </span>
       </div>
       <div class="input">
-        <button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
+        <button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
          <i class="fas fa-image"></i>
        </button>
        <input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
@@ -243,6 +243,7 @@ document.addEventListener("alpine:init", () => {
         body: JSON.stringify({
           "model": 'stable-diffusion-2-1-base',
           "prompt": apiMessages[apiMessages.length - 1].content,
+          "image_url": this.imageUrl
         }),
       });