sharded inference

Author: Varshith
Date: 2024-07-28 00:30:34 +05:30
parent 7cbf6a35bd
commit 9d2616b9cf
6 changed files with 181 additions and 208 deletions

.gitignore (vendored): 1 addition
View File

@@ -83,6 +83,7 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
Untitled.ipynb
# IPython
profile_default/

View File

@@ -1,18 +1,15 @@
# Copyright © 2024 Apple Inc.
import math
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Union, Tuple
from dataclasses import dataclass, field
from typing import Optional, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.models.base import BaseModelArgs, KVCache
from exo.inference.shard import Shard
import numpy as np
from huggingface_hub import snapshot_download
@dataclass
@@ -42,15 +39,15 @@ class VisionConfig:
class VisionAttention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__()
@@ -206,7 +203,7 @@ class VisionModel(nn.Module):
self.vision_model = ClipVisionModel(config)
def __call__(
self, x: mx.array, output_hidden_states: Optional[bool] = None
self, x: mx.array, output_hidden_states: Optional[bool] = None
) -> mx.array:
return self.vision_model(x, output_hidden_states)
@@ -228,6 +225,7 @@ class VisionModel(nn.Module):
return sanitized_weights
@dataclass
class TextConfig:
model_type: str
@@ -235,10 +233,10 @@ class TextConfig:
num_hidden_layers: int = 32
intermediate_size: int = 11008
num_attention_heads: int = 32
head_dim: int = None
rms_norm_eps: float = 1e-6
vocab_size: int = 32000
n_kv_heads: int = None
head_dim: Optional[int] = None
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -254,12 +252,15 @@ class TextConfig:
)
def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.num_attention_heads
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.head_dim is None:
self.head_dim = self.hidden_size // self.num_attention_heads
if self.model_type is None:
self.model_type = "llama"
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
@@ -275,12 +276,12 @@ class TextAttention(nn.Module):
dim = config.hidden_size
self.n_heads = n_heads = config.num_attention_heads
self.n_kv_heads = n_kv_heads = config.n_kv_heads
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = config.hidden_size // n_heads
self.scale = head_dim**-0.5
self.scale = head_dim ** -0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
@@ -290,7 +291,7 @@ class TextAttention(nn.Module):
rope_scale = (
1 / config.rope_scaling["factor"]
if config.rope_scaling is not None
and config.rope_scaling["type"] == "linear"
and config.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
@@ -301,10 +302,10 @@ class TextAttention(nn.Module):
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@@ -355,10 +356,10 @@ class TransformerBlock(nn.Module):
self.config = config
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -368,12 +369,15 @@ class TransformerBlock(nn.Module):
class Llama(nn.Module):
def __init__(self, config: TextConfig):
def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
super().__init__()
self.config = config
self.is_first_layer = is_first_layer
self.is_last_layer = is_last_layer
self.vocab_size = config.vocab_size
self.model_type = config.model_type
self.num_hidden_layers = config.num_hidden_layers
self.n_kv_heads = config.n_kv_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.head_dim
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -383,14 +387,17 @@ class Llama(nn.Module):
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
):
# for passing merged input embeddings
if inputs_embeds is None:
h = self.embed_tokens(inputs)
if self.is_first_layer:
h = self.embed_tokens(inputs)
else:
h = inputs
else:
h = inputs_embeds
@@ -406,18 +413,20 @@ class Llama(nn.Module):
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
if self.is_last_layer:
h = self.norm(h)
return h
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
super().__init__()
self.model_type = config.model_type
if self.model_type != "llama":
raise ValueError(
f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
)
self.model = Llama(config)
self.is_last_layer = is_last_layer
self.model = Llama(config, is_first_layer, is_last_layer)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
@@ -427,7 +436,9 @@ class LanguageModel(nn.Module):
inputs_embeds=None,
):
out = self.model(inputs, cache, inputs_embeds)
return self.lm_head(out)
if self.is_last_layer:
out = self.lm_head(out)
return out
@staticmethod
def sanitize(weights):
@@ -436,11 +447,10 @@ class LanguageModel(nn.Module):
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@dataclass
class LlaVAConfig:
class LlaVAConfig(BaseModelArgs):
text_config: TextConfig
vision_config: VisionConfig
vision_config: VisionConfig = None
model_type: str = "llava"
ignore_index: int = -100
image_token_index: int = 32000
@@ -450,13 +460,32 @@ class LlaVAConfig:
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
updated_params = {}
class_params = inspect.signature(cls).parameters
for k, v in params.items():
if k in class_params:
if k in ["text_config", "vision_config"]:
v = class_params[k].annotation.from_dict(v)
updated_params.update({k: v})
return cls(**updated_params)
@dataclass
class ModelArgs(LlaVAConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
self.text_config.num_hidden_layers = self.shard.get_layer_count()
class LlavaMultiModalProjector(nn.Module):
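The ModelArgs dataclass added above is what ties the checkpoint config to a Shard: a plain dict is coerced into a Shard, a shard that is not the first drops the vision config entirely, and the text config's num_hidden_layers is overwritten with the shard's own layer count. A minimal sketch of what that means for the second shard used in the test script later in this commit (keyword names follow the Shard fields seen elsewhere in this diff; the sketch is illustrative, not part of the committed code):

from exo.inference.shard import Shard

shard2 = Shard(model_id="llava", start_layer=13, end_layer=31, n_layers=32)
# Inside ModelArgs.__post_init__ this shard would:
#   - be kept as-is (a dict {"model_id": "llava", ...} would be coerced to a Shard first),
#   - clear vision_config, because shard2.is_first_layer() is False,
#   - set text_config.num_hidden_layers = shard2.get_layer_count()  # 31 - 13 + 1 = 19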
@@ -477,19 +506,22 @@ class LlavaMultiModalProjector(nn.Module):
return x
class LlavaModel(nn.Module):
def __init__(self, config: LlaVAConfig):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.vision_tower = VisionModel(config.vision_config)
self.language_model = LanguageModel(config.text_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
self.model_type = config.model_type
if config.vision_config:
self.vision_tower = VisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
@@ -525,7 +557,7 @@ class LlavaModel(nn.Module):
return final_inputs_embeds
def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
@@ -554,49 +586,32 @@ class LlavaModel(nn.Module):
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
input_embddings = None
if pixel_values is not None:
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
input_ids, cache=cache, inputs_embeds=input_embddings
)
return logits
@staticmethod
def from_pretrained(path_or_hf_repo: str):
path = Path(path_or_hf_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
],
)
)
def sanitize(self, weights):
if self.config.vision_config:
weights = self.vision_tower.sanitize(weights)
weights = self.language_model.sanitize(weights)
with open(path / "config.json", "r") as f:
model_config = json.load(f)
return weights
model_config = LlaVAConfig.from_dict(model_config)
@property
def layers(self):
return self.language_model.model.layers
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.text_config = TextConfig.from_dict(model_config.text_config)
@property
def head_dim(self):
return (
self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
)
model = LlavaModel(model_config)
weight_files = glob.glob(str(path / "*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = VisionModel.sanitize(weights)
weights = LanguageModel.sanitize(weights)
model.load_weights(list(weights.items()))
return model
@property
def n_kv_heads(self):
return self.language_model.model.num_key_value_heads
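The file above (apparently exo/inference/mlx/models/sharded_llava.py, judging by the import the loader drops later in this commit) is where the model itself becomes shard-aware: only the first shard keeps the vision tower and projector and runs the token embedding, only the last shard applies the final norm and lm_head, and from_pretrained is replaced by sanitize plus the layers / head_dim / n_kv_heads properties that StatefulShardedModel's cache setup relies on. A condensed sketch of the gating pattern, stripped of the LLaVA specifics (class and field names here are invented for illustration):

import mlx.nn as nn

class ShardedDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_layers, is_first_layer, is_last_layer):
        super().__init__()
        self.is_first_layer = is_first_layer
        self.is_last_layer = is_last_layer
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # only the transformer layers owned by this shard exist on this node
        self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(n_layers)]
        self.norm = nn.RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def __call__(self, x):
        # the first shard consumes token ids, later shards consume hidden states
        h = self.embed_tokens(x) if self.is_first_layer else x
        for layer in self.layers:
            h = layer(h)
        # only the last shard applies the final norm and projects back to logits
        return self.lm_head(self.norm(h)) if self.is_last_layer else h

Chaining two such instances (the first with is_first_layer=True, the second with is_last_layer=True) reproduces the forward pass of the unsharded model, which is what the test script at the end of this commit asserts for the real LLaVA weights.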

View File

@@ -15,7 +15,8 @@ class StatefulShardedModel:
def step(
self,
x,
y,
pixel_values=None,
temp: float = 0.0,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
@@ -36,9 +37,11 @@ class StatefulShardedModel:
return token
y = x
output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
# TODO : revert hacky fix
if pixel_values is None:
output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
else:
output = self.model(y, pixel_values=pixel_values, cache=self.cache)
if self.shard.is_last_layer():
logits = output[:, -1, :]
@@ -57,14 +60,9 @@ class StatefulShardedModel:
return self.step(x, temp, top_p, logit_bias)
def reset(self):
if hasattr(self.model.config, "vision_config"):
model = self.model.language_model.model
else:
model = self.model
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
self.cache = [KVCache(model.head_dim, n) for n in kv_heads]
self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
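This file (apparently exo/inference/mlx/sharded_model.py, home of StatefulShardedModel) gains a pixel_values argument on step() for the multimodal prefill, and reset() no longer special-cases vision models: because the sharded Model now forwards layers, head_dim and n_kv_heads through properties, the cache can be built the same way for every model type. A duck-typed sketch of that contract (the helper name is invented; the body mirrors the simplified reset() above):

from mlx_lm.models.base import KVCache

def make_cache(model):
    # works for any model exposing .layers, .n_kv_heads and .head_dim,
    # which is what the new properties on Model guarantee
    kv_heads = (
        [model.n_kv_heads] * len(model.layers)
        if isinstance(model.n_kv_heads, int)
        else model.n_kv_heads
    )
    return [KVCache(model.head_dim, n) for n in kv_heads]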

View File

@@ -19,7 +19,6 @@ from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers
from ..shard import Shard
from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -29,6 +28,7 @@ class ModelNotFoundError(Exception):
MODEL_REMAPPING = {
"sharded_mistral": "sharded_llama", # mistral is compatible with llama
"sharded_phi-msft": "sharded_phixtral",
"sharded_llava": "sharded_llava"
}
def _get_classes(config: dict):
@@ -113,6 +113,7 @@ def load_model_shard(
for wf in weight_files:
weights_dict = mx.load(wf)
all_weights_keys.update(weights_dict.keys())
weights.update({k: v for k, v in weights_dict.items() if not k.startswith("language_model.model.layers.") or shard.start_layer <= int(k.split('.')[3]) <= shard.end_layer})
weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
model_class, model_args_class = _get_classes(config=config)
@@ -137,6 +138,11 @@ def load_model_shard(
if shard.start_layer <= layer_num <= shard.end_layer:
new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
filtered_weights[new_key] = v
elif k.startswith("language_model.model.layers."):
layer_num = int(k.split('.')[3])
if shard.start_layer <= layer_num <= shard.end_layer:
new_key = f"language_model.model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[4:])
filtered_weights[new_key] = v
else:
filtered_weights[k] = v
weights = filtered_weights
@@ -228,62 +234,11 @@ async def load_shard(
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
async def load_shard_llava(
path_or_hf_repo: str,
shard: Shard,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = await get_model_path(path_or_hf_repo)
processor = AutoProcessor.from_pretrained(model_path)
with open(model_path / "config.json", "r") as f:
model_config = json.load(f)
model_config = LlaVAConfig.from_dict(model_config)
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.text_config = TextConfig.from_dict(model_config.text_config)
model = LlavaModel(model_config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = VisionModel.sanitize(weights)
weights = LanguageModel.sanitize(weights)
model.load_weights(list(weights.items()))
return model, processor
# TODO: figure out a better way
if "llama" in str(model_path):
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
elif "llava" in str(model_path):
processor = AutoProcessor.from_pretrained(model_path)
return model, processor
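The loader changes (apparently exo/inference/mlx/sharded_utils.py) fold the old load_shard_llava path into the generic load_shard: weight keys are filtered to the shard's layer range and renumbered for both the plain "model.layers.N." prefix and LLaVA's "language_model.model.layers.N." prefix, and the function returns an AutoProcessor instead of a tokenizer when the repo path contains "llava". A standalone sketch of the key filtering and renumbering (it assumes only those two layer prefixes, as the code above does; the helper itself is illustrative, not part of the commit):

def shard_layer_key(key, start_layer, end_layer):
    # returns the renumbered key if it belongs to this shard, None if the
    # layer lives on another shard, and the key unchanged for non-layer weights
    for prefix, idx in (("model.layers.", 2), ("language_model.model.layers.", 3)):
        if key.startswith(prefix):
            parts = key.split(".")
            layer = int(parts[idx])
            if start_layer <= layer <= end_layer:
                parts[idx] = str(layer - start_layer)
                return ".".join(parts)
            return None
    return key  # embeddings, norms, lm_head, vision tower, projector, ...

# e.g. shard_layer_key("language_model.model.layers.14.mlp.up_proj.weight", 13, 31)
# -> "language_model.model.layers.1.mlp.up_proj.weight"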

View File

@@ -9,42 +9,20 @@ import mlx.core as mx
from mlx_lm.models.base import KVCache
from exo.inference.mlx.sharded_model import StatefulShardedModel
from exo.inference.mlx.sharded_utils import load_shard_llava
from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.shard import Shard
def sample(logits, temperature=0.0):
if temperature == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temperature))
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
kv_heads = (
[model.language_model.model.n_kv_heads] * len(model.language_model.model.layers)
if isinstance(model.language_model.model.n_kv_heads, int)
else model.language_model.model.n_kv_heads
)
cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads]
logits = model(input_ids, pixel_values, cache=cache)
logits = logits[:, -1, :]
y = sample(logits, temperature=temperature)
tokens = [y.item()]
for n in range(max_tokens - 1):
logits = model.language_model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits, temperature)
token = y.item()
if token == processor.tokenizer.eos_token_id:
break
tokens.append(token)
return processor.tokenizer.decode(tokens)
shard_full = Shard("llava", 0, 31, 32)
shard1 = Shard("llava", 0, 12, 32)
shard2 = Shard("llava", 13, 31, 32)
full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full))
full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full))
model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1))
model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2))
full = StatefulShardedModel(shard_full, full_model_shard)
m1 = StatefulShardedModel(shard1, model_shard1)
m2 = StatefulShardedModel(shard2, model_shard2)
PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -56,7 +34,30 @@ pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
print(prompt)
generated_text = generate_text(
input_ids, pixel_values, full_model_shard, full_processor, 10, 0
)
print(generated_text)
y = full.step(input_ids, pixel_values, temp=0)
full_generated_tokens = [y.item()]
for _ in range(13):
y = full.step(y, temp=0)
full_generated_tokens.append(y.item())
full_response = full_processor.tokenizer.decode(full_generated_tokens)
print("full response:", full_response)
inputs = processor1(prompt, img, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
y = m1.step(input_ids, pixel_values, temp=0)
y = m2.step(y, temp=0)
full_generated_tokens = [y.item()]
for _ in range(13):
y = m1.step(y, temp=0)
y = m2.step(y, temp=0)
full_generated_tokens.append(y.item())
sharded_response = processor2.tokenizer.decode(full_generated_tokens)
print("sharded response:", sharded_response)
assert full_response == sharded_response
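The test script (which replaces the old single-device generate_text helper) now drives generation through StatefulShardedModel: one prefill step with input_ids and pixel_values, then 13 greedy decode steps, once through the full-model shard and once through the two-shard pipeline, asserting that both decodes match. One thing the removed helper did that the new loop does not is stop at the EOS token; a hedged sketch of an EOS-aware variant of the sharded loop (reusing the names from the test above; max_tokens is an assumed local):

max_tokens = 14  # 1 prefill token + 13 decode steps, as in the loop above
y = m1.step(input_ids, pixel_values, temp=0)
y = m2.step(y, temp=0)
tokens = [y.item()]
while len(tokens) < max_tokens and tokens[-1] != processor2.tokenizer.eos_token_id:
    y = m2.step(m1.step(y, temp=0), temp=0)
    tokens.append(y.item())
print("sharded response:", processor2.tokenizer.decode(tokens))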

View File

@@ -13,6 +13,9 @@ class Shard:
def is_last_layer(self) -> bool:
return self.end_layer == self.n_layers - 1
def get_layer_count(self) -> int:
return self.end_layer - self.start_layer + 1
def to_dict(self) -> dict:
return {
"model_id": self.model_id,