mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
sharded inference
.gitignore (vendored): 1 addition
@@ -83,6 +83,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
+Untitled.ipynb

# IPython
profile_default/
@@ -1,18 +1,15 @@
# Copyright © 2024 Apple Inc.

import math
import glob
import inspect
import json
-from dataclasses import dataclass
from pathlib import Path
-from typing import Optional, Dict, Union, Tuple
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Union

import mlx.core as mx
import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.base import BaseModelArgs, KVCache
from exo.inference.shard import Shard
import numpy as np
from huggingface_hub import snapshot_download


@dataclass
@@ -42,15 +39,15 @@ class VisionConfig:

class VisionAttention(nn.Module):
    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()
@@ -206,7 +203,7 @@ class VisionModel(nn.Module):
        self.vision_model = ClipVisionModel(config)

    def __call__(
        self, x: mx.array, output_hidden_states: Optional[bool] = None
    ) -> mx.array:
        return self.vision_model(x, output_hidden_states)
@@ -228,6 +225,7 @@ class VisionModel(nn.Module):

        return sanitized_weights


@dataclass
class TextConfig:
    model_type: str
@@ -235,10 +233,10 @@ class TextConfig:
    num_hidden_layers: int = 32
    intermediate_size: int = 11008
    num_attention_heads: int = 32
-    head_dim: int = None
    rms_norm_eps: float = 1e-6
    vocab_size: int = 32000
-    n_kv_heads: int = None
+    head_dim: Optional[int] = None
+    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -254,12 +252,15 @@ class TextConfig:
        )

    def __post_init__(self):
-        if self.n_kv_heads is None:
-            self.n_kv_heads = self.num_attention_heads
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads

        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_attention_heads

        if self.model_type is None:
            self.model_type = "llama"

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
@@ -275,12 +276,12 @@ class TextAttention(nn.Module):

        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.n_kv_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = config.hidden_size // n_heads
-        self.scale = head_dim**-0.5
+        self.scale = head_dim ** -0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
@@ -290,7 +291,7 @@ class TextAttention(nn.Module):
        rope_scale = (
            1 / config.rope_scaling["factor"]
            if config.rope_scaling is not None
            and config.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
@@ -301,10 +302,10 @@ class TextAttention(nn.Module):
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, D = x.shape
@@ -355,10 +356,10 @@ class TransformerBlock(nn.Module):
        self.config = config

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
@@ -368,12 +369,15 @@ class TransformerBlock(nn.Module):


class Llama(nn.Module):
-    def __init__(self, config: TextConfig):
+    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
        super().__init__()
        self.config = config
+        self.is_first_layer = is_first_layer
+        self.is_last_layer = is_last_layer
        self.vocab_size = config.vocab_size
        self.model_type = config.model_type
        self.num_hidden_layers = config.num_hidden_layers
-        self.n_kv_heads = config.n_kv_heads
+        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -383,14 +387,17 @@ class Llama(nn.Module):
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
    ):
        # for passing merged input embeddings
        if inputs_embeds is None:
-            h = self.embed_tokens(inputs)
+            if self.is_first_layer:
+                h = self.embed_tokens(inputs)
+            else:
+                h = inputs
        else:
            h = inputs_embeds
@@ -406,18 +413,20 @@ class Llama(nn.Module):
        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

-        return self.norm(h)
+        if self.is_last_layer:
+            h = self.norm(h)
+        return h


class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig):
+    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
        super().__init__()
        self.model_type = config.model_type
        if self.model_type != "llama":
            raise ValueError(
                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
            )
-        self.model = Llama(config)
+        self.is_last_layer = is_last_layer
+        self.model = Llama(config, is_first_layer, is_last_layer)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
@@ -427,7 +436,9 @@ class LanguageModel(nn.Module):
        inputs_embeds=None,
    ):
        out = self.model(inputs, cache, inputs_embeds)
-        return self.lm_head(out)
+        if self.is_last_layer:
+            out = self.lm_head(out)
+        return out

    @staticmethod
    def sanitize(weights):
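Taken together, the Llama and LanguageModel changes above mean a shard only runs its own slice of the pipeline: the first shard turns token ids into embeddings, intermediate shards pass hidden states straight through their layers, and only the last shard applies the final RMSNorm and lm_head projection. A minimal sketch of that gating, using stand-in layers rather than the actual exo classes:

# Illustrative sketch of first/last-shard gating; the layer stubs and class
# name are hypothetical, only the gating mirrors the diff above.
import mlx.core as mx
import mlx.nn as nn

class ShardedDecoderSketch(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_layers, is_first_layer, is_last_layer):
        super().__init__()
        self.is_first_layer = is_first_layer
        self.is_last_layer = is_last_layer
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        self.norm = nn.RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def __call__(self, inputs):
        # First shard receives token ids; later shards receive hidden states.
        h = self.embed_tokens(inputs) if self.is_first_layer else inputs
        for layer in self.layers:
            h = layer(h)
        # Only the last shard normalizes and projects to vocabulary logits.
        if self.is_last_layer:
            return self.lm_head(self.norm(h))
        return h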
@@ -436,11 +447,10 @@ class LanguageModel(nn.Module):
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }


@dataclass
-class LlaVAConfig:
+class LlaVAConfig(BaseModelArgs):
    text_config: TextConfig
-    vision_config: VisionConfig
+    vision_config: VisionConfig = None
    model_type: str = "llava"
    ignore_index: int = -100
    image_token_index: int = 32000
@@ -450,13 +460,32 @@ class LlaVAConfig:

    @classmethod
    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+        updated_params = {}
+        class_params = inspect.signature(cls).parameters
+        for k, v in params.items():
+            if k in class_params:
+                if k in ["text_config", "vision_config"]:
+                    v = class_params[k].annotation.from_dict(v)
+                updated_params.update({k: v})
+
+        return cls(**updated_params)
+
+
+@dataclass
+class ModelArgs(LlaVAConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+        self.text_config.num_hidden_layers = self.shard.get_layer_count()


class LlavaMultiModalProjector(nn.Module):
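The new ModelArgs couples the LLaVA config to a Shard that describes which contiguous block of decoder layers this node owns. An illustrative example of the effect __post_init__ is meant to have for a node holding the back half of a 32-layer model (keyword names follow the Shard fields used elsewhere in this diff):

# Illustrative only: shows the intended effect of ModelArgs.__post_init__.
from exo.inference.shard import Shard

shard = Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=16, end_layer=31, n_layers=32)

assert not shard.is_first_layer()     # this node never embeds tokens or runs the vision tower
assert shard.is_last_layer()          # this node produces the final logits
assert shard.get_layer_count() == 16  # it only instantiates 16 decoder layers

# __post_init__ therefore sets vision_config = None for this shard and shrinks
# text_config.num_hidden_layers from 32 down to shard.get_layer_count() == 16.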
@@ -477,19 +506,22 @@ class LlavaMultiModalProjector(nn.Module):
        return x


-class LlavaModel(nn.Module):
-    def __init__(self, config: LlaVAConfig):
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
-        self.vision_tower = VisionModel(config.vision_config)
-        self.language_model = LanguageModel(config.text_config)
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.vision_feature_layer = config.vision_feature_layer
-        self.vision_feature_select_strategy = config.vision_feature_select_strategy
+        self.model_type = config.model_type
+        if config.vision_config:
+            self.vision_tower = VisionModel(config.vision_config)
+            self.multi_modal_projector = LlavaMultiModalProjector(config)
+            self.vision_feature_layer = config.vision_feature_layer
+            self.vision_feature_select_strategy = config.vision_feature_select_strategy
+        self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
    ):
        if pixel_values is None:
            return self.language_model(input_ids)
@@ -525,7 +557,7 @@ class LlavaModel(nn.Module):
        return final_inputs_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        image_token_index = self.config.image_token_index
        num_images, num_image_patches, embed_dim = image_features.shape
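_merge_input_ids_with_image_features splices the projected image patch embeddings into the text embedding sequence at the position of the <image> placeholder token, so the language model sees one long sequence of embeddings. A simplified single-image sketch of that splice (the real method additionally handles batches and multiple images):

# Simplified sketch; shapes and the helper itself are illustrative only.
import mlx.core as mx

def splice_image_features(image_features, inputs_embeds, input_ids, image_token_index):
    # image_features: (num_image_patches, embed_dim) for one image
    # inputs_embeds:  (seq_len, embed_dim) text embeddings
    positions = [i for i, tok in enumerate(input_ids.tolist()) if tok == image_token_index]
    pos = positions[0]  # position of the <image> placeholder token
    return mx.concatenate(
        [inputs_embeds[:pos], image_features, inputs_embeds[pos + 1:]], axis=0
    )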
@@ -554,49 +586,32 @@ class LlavaModel(nn.Module):
        # (1, num_image_patches*num_images + sequence_len, embed_dim)
        return mx.concatenate(final_embeddings, axis=1)

-    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
-        input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+        input_embddings = None
+        if pixel_values is not None:
+            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
        logits = self.language_model(
            input_ids, cache=cache, inputs_embeds=input_embddings
        )
        return logits

-    @staticmethod
-    def from_pretrained(path_or_hf_repo: str):
-        path = Path(path_or_hf_repo)
-        if not path.exists():
-            path = Path(
-                snapshot_download(
-                    repo_id=path_or_hf_repo,
-                    allow_patterns=[
-                        "*.json",
-                        "*.safetensors",
-                        "*.py",
-                        "tokenizer.model",
-                        "*.tiktoken",
-                    ],
-                )
-            )
-
-        with open(path / "config.json", "r") as f:
-            model_config = json.load(f)
-
-        model_config = LlaVAConfig.from_dict(model_config)
-
-        model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
-        model_config.text_config = TextConfig.from_dict(model_config.text_config)
-
-        model = LlavaModel(model_config)
-        weight_files = glob.glob(str(path / "*.safetensors"))
-        if not weight_files:
-            raise FileNotFoundError(f"No safetensors found in {path}")
-
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf))
-
-        weights = VisionModel.sanitize(weights)
-        weights = LanguageModel.sanitize(weights)
-
-        model.load_weights(list(weights.items()))
-        return model
+    def sanitize(self, weights):
+        if self.config.vision_config:
+            weights = self.vision_tower.sanitize(weights)
+        weights = self.language_model.sanitize(weights)
+
+        return weights
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    @property
+    def head_dim(self):
+        return (
+            self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
+        )
+
+    @property
+    def n_kv_heads(self):
+        return self.language_model.model.num_key_value_heads
@@ -15,7 +15,8 @@ class StatefulShardedModel:

    def step(
        self,
-        x,
+        y,
+        pixel_values=None,
        temp: float = 0.0,
        top_p: float = 1.0,
        logit_bias: Optional[Dict[int, float]] = None,
@@ -36,9 +37,11 @@ class StatefulShardedModel:

            return token

-        y = x
-        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+        # TODO : revert hacky fix
+        if pixel_values is None:
+            output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+        else:
+            output = self.model(y, pixel_values=pixel_values, cache=self.cache)

        if self.shard.is_last_layer():
            logits = output[:, -1, :]
@@ -57,14 +60,9 @@ class StatefulShardedModel:
        return self.step(x, temp, top_p, logit_bias)

    def reset(self):
-        if hasattr(self.model.config, "vision_config"):
-            model = self.model.language_model.model
-        else:
-            model = self.model
-
        kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
+            [self.model.n_kv_heads] * len(self.model.layers)
+            if isinstance(self.model.n_kv_heads, int)
+            else self.model.n_kv_heads
        )
-        self.cache = [KVCache(model.head_dim, n) for n in kv_heads]
+        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
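Exposing layers, head_dim, and n_kv_heads as properties on the LLaVA Model (previous hunk) is what lets this generic reset() allocate KV caches without special-casing vision models. A small sketch of the allocation it performs, assuming the KVCache(head_dim, n_kv_heads) constructor from mlx_lm used above:

# Sketch of per-layer KV cache allocation for a shard; mirrors reset() above.
from mlx_lm.models.base import KVCache

def make_caches(model):
    # model.n_kv_heads may be a single int or a per-layer list.
    kv_heads = (
        [model.n_kv_heads] * len(model.layers)
        if isinstance(model.n_kv_heads, int)
        else model.n_kv_heads
    )
    # One cache per *local* layer: a shard holding 16 layers allocates 16 caches.
    return [KVCache(model.head_dim, n) for n in kv_heads]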
@@ -19,7 +19,6 @@ from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from ..shard import Shard
-from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel

class ModelNotFoundError(Exception):
    def __init__(self, message):
@@ -29,6 +28,7 @@ class ModelNotFoundError(Exception):
MODEL_REMAPPING = {
    "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
    "sharded_phi-msft": "sharded_phixtral",
+    "sharded_llava": "sharded_llava"
}

def _get_classes(config: dict):
@@ -113,6 +113,7 @@ def load_model_shard(
    for wf in weight_files:
        weights_dict = mx.load(wf)
        all_weights_keys.update(weights_dict.keys())
+        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("language_model.model.layers.") or shard.start_layer <= int(k.split('.')[3]) <= shard.end_layer})
        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})

    model_class, model_args_class = _get_classes(config=config)
@@ -137,6 +138,11 @@ def load_model_shard(
            if shard.start_layer <= layer_num <= shard.end_layer:
                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
                filtered_weights[new_key] = v
+        elif k.startswith("language_model.model.layers."):
+            layer_num = int(k.split('.')[3])
+            if shard.start_layer <= layer_num <= shard.end_layer:
+                new_key = f"language_model.model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[4:])
+                filtered_weights[new_key] = v
        else:
            filtered_weights[k] = v
    weights = filtered_weights
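The loader keeps only the weight keys whose layer index falls inside the shard's range and renumbers them so the shard's first layer becomes layer 0 of the partial model. A standalone sketch of that renaming (hypothetical helper, same index arithmetic as the hunk above):

# Hypothetical helper mirroring the key filtering/renumbering above.
def remap_key_for_shard(key, start_layer, end_layer, prefix="language_model.model.layers."):
    if not key.startswith(prefix):
        return key  # non-layer weights are always kept
    parts = key.split('.')
    layer_num = int(parts[3])  # e.g. "language_model.model.layers.14...."
    if not (start_layer <= layer_num <= end_layer):
        return None  # weight belongs to another shard and is dropped
    parts[3] = str(layer_num - start_layer)
    return '.'.join(parts)

# For a shard owning layers 13..31:
#   "language_model.model.layers.14.self_attn.q_proj.weight"
#     -> "language_model.model.layers.1.self_attn.q_proj.weight"
#   "language_model.model.layers.2.self_attn.q_proj.weight" -> None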
@@ -228,62 +234,11 @@ async def load_shard(
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
-
-    return model, tokenizer
-
-
-async def load_shard_llava(
-    path_or_hf_repo: str,
-    shard: Shard,
-    tokenizer_config={},
-    model_config={},
-    adapter_path: Optional[str] = None,
-    lazy: bool = False,
-) -> Tuple[nn.Module, TokenizerWrapper]:
-    """
-    Load the model and tokenizer from a given path or a huggingface repository.
-
-    Args:
-        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-            Defaults to an empty dictionary.
-        model_config(dict, optional): Configuration parameters specifically for the model.
-            Defaults to an empty dictionary.
-        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-            to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-    Returns:
-        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-    Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
-    """
-    model_path = await get_model_path(path_or_hf_repo)
-    processor = AutoProcessor.from_pretrained(model_path)
-
-    with open(model_path / "config.json", "r") as f:
-        model_config = json.load(f)
-
-    model_config = LlaVAConfig.from_dict(model_config)
-
-    model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
-    model_config.text_config = TextConfig.from_dict(model_config.text_config)
-
-    model = LlavaModel(model_config)
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
-    if not weight_files:
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf))
-
-    weights = VisionModel.sanitize(weights)
-    weights = LanguageModel.sanitize(weights)
-
-    model.load_weights(list(weights.items()))
-    return model, processor
+    # TODO: figure out a better way
+    if "llama" in str(model_path):
+        tokenizer = load_tokenizer(model_path, tokenizer_config)
+        return model, tokenizer
+    elif "llava" in str(model_path):
+        processor = AutoProcessor.from_pretrained(model_path)
+        return model, processor
@@ -9,42 +9,20 @@ import mlx.core as mx
from mlx_lm.models.base import KVCache

from exo.inference.mlx.sharded_model import StatefulShardedModel
-from exo.inference.mlx.sharded_utils import load_shard_llava
+from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.shard import Shard

def sample(logits, temperature=0.0):
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    else:
        return mx.random.categorical(logits * (1 / temperature))

-def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
-    kv_heads = (
-        [model.language_model.model.n_kv_heads] * len(model.language_model.model.layers)
-        if isinstance(model.language_model.model.n_kv_heads, int)
-        else model.language_model.model.n_kv_heads
-    )
-    cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads]
-    logits = model(input_ids, pixel_values, cache=cache)
-    logits = logits[:, -1, :]
-    y = sample(logits, temperature=temperature)
-    tokens = [y.item()]
-
-    for n in range(max_tokens - 1):
-        logits = model.language_model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-        y = sample(logits, temperature)
-        token = y.item()
-        if token == processor.tokenizer.eos_token_id:
-            break
-        tokens.append(token)
-
-    return processor.tokenizer.decode(tokens)

shard_full = Shard("llava", 0, 31, 32)
+shard1 = Shard("llava", 0, 12, 32)
+shard2 = Shard("llava", 13, 31, 32)

-full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full))
+full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full))
+model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1))
+model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2))

+full = StatefulShardedModel(shard_full, full_model_shard)
+m1 = StatefulShardedModel(shard1, model_shard1)
+m2 = StatefulShardedModel(shard2, model_shard2)

PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -56,7 +34,30 @@ pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])

print(prompt)
-generated_text = generate_text(
-    input_ids, pixel_values, full_model_shard, full_processor, 10, 0
-)
-print(generated_text)
+y = full.step(input_ids, pixel_values, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = full.step(y, temp=0)
+    full_generated_tokens.append(y.item())
+
+full_response = full_processor.tokenizer.decode(full_generated_tokens)
+print("full response:", full_response)
+
+inputs = processor1(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+y = m1.step(input_ids, pixel_values, temp=0)
+y = m2.step(y, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = m1.step(y, temp=0)
+    y = m2.step(y, temp=0)
+    full_generated_tokens.append(y.item())
+
+sharded_response = processor2.tokenizer.decode(full_generated_tokens)
+print("sharded response:", sharded_response)
+
+assert full_response == sharded_response
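The test drives the shards by hand: only the first shard sees token ids and pixel values, its hidden states are fed into the second shard, and the last shard returns a sampled token that is fed back into the first shard on the next step. A generalized sketch of that loop (hypothetical helper around the StatefulShardedModel.step API used above):

# Hypothetical driver over an ordered list of StatefulShardedModel instances.
def generate(shard_models, input_ids, pixel_values=None, max_tokens=13, temp=0):
    # Prime the pipeline: only the first shard receives token ids and pixel values.
    y = shard_models[0].step(input_ids, pixel_values, temp=temp)
    for m in shard_models[1:]:
        y = m.step(y, temp=temp)  # later shards consume hidden states / the sampled token
    tokens = [y.item()]

    for _ in range(max_tokens):
        for m in shard_models:
            y = m.step(y, temp=temp)
        tokens.append(y.item())
    return tokens

# e.g. generate([m1, m2], input_ids, pixel_values) should reproduce the full-model output.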
@@ -13,6 +13,9 @@ class Shard:
    def is_last_layer(self) -> bool:
        return self.end_layer == self.n_layers - 1

+    def get_layer_count(self) -> int:
+        return self.end_layer - self.start_layer + 1
+
    def to_dict(self) -> dict:
        return {
            "model_id": self.model_id,
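Shard layer ranges are inclusive on both ends, so get_layer_count() is end_layer - start_layer + 1, and a set of shards tiles the model exactly when their counts sum to n_layers. A quick check with the two shards used in the test above (assuming is_first_layer() simply checks start_layer == 0, as its name suggests):

# Illustrative check of inclusive layer ranges.
from exo.inference.shard import Shard

shard1 = Shard("llava", 0, 12, 32)   # layers 0..12
shard2 = Shard("llava", 13, 31, 32)  # layers 13..31

assert shard1.get_layer_count() == 13
assert shard2.get_layer_count() == 19
assert shard1.get_layer_count() + shard2.get_layer_count() == 32
assert shard1.is_first_layer() and shard2.is_last_layer()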