Varshith
2024-07-26 06:00:29 +05:30
parent dd8c5d63a9
commit 803a442141
4 changed files with 655 additions and 1 deletion

.gitignore vendored

@@ -2,6 +2,7 @@ __pycache__/
.venv
test_weights.npz
.exo_used_ports
.idea
# Byte-compiled / optimized / DLL files
__pycache__/


@@ -0,0 +1,595 @@
# Copyright © 2024 Apple Inc.
import math
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Union, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
@dataclass
class VisionConfig:
model_type: str
num_hidden_layers: int = 24
hidden_size: int = 1024
intermediate_size: int = 4096
num_attention_heads: int = 16
image_size: int = 336
patch_size: int = 14
projection_dim: int = 768
vocab_size: int = 32000
num_channels: int = 3
layer_norm_eps: float = 1e-5
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class VisionAttention(nn.Module):
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, queries, keys, values, mask=None):
queries = self.q_proj(queries)
keys = self.k_proj(keys)
values = self.v_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
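# Note on __call__ above: keys are reshaped to (B, num_heads, head_dim, S), so
# (queries * scale) @ keys produces (B, num_heads, L, S) attention scores with
# the usual 1/sqrt(head_dim) scaling applied before the softmax.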
class VisionMLP(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.activation_fn = nn.GELU(approx="fast")
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def __call__(self, x: mx.array) -> mx.array:
x = self.activation_fn(self.fc1(x))
x = self.fc2(x)
return x
class VisionEncoderLayer(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = VisionAttention(
config.hidden_size, config.num_attention_heads, bias=True
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
y = self.layer_norm1(x)
y = self.self_attn(y, y, y, mask)
x = x + y
y = self.layer_norm2(x)
y = self.mlp(y)
return x + y
class VisionEncoder(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
class VisionEmbeddings(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = mx.zeros((config.hidden_size,))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def __call__(self, x: mx.array) -> mx.array:
batch_size = x.shape[0]
patch_embeddings = self.patch_embedding(x)
patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
embed_dim = patch_embeddings.shape[-1]
cls_embeddings = mx.broadcast_to(
self.class_embedding, (batch_size, 1, embed_dim)
)
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
embeddings += self.position_embedding.weight
return embeddings
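# Shape walk-through for the default config (image_size=336, patch_size=14,
# hidden_size=1024): the patch conv yields a (B, 24, 24, 1024) grid, flattened
# to (B, 576, 1024); prepending the class token and adding the (577, 1024)
# position table gives a (B, 577, 1024) output.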
class ClipVisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.embeddings = VisionEmbeddings(config)
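# "pre_layrnorm" (sic) mirrors the misspelled attribute name in the upstream
# CLIP implementation, so checkpoint keys map onto these parameters unchanged.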
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
self.encoder = VisionEncoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size)
def __call__(
self,
x: mx.array,
output_hidden_states: Optional[bool] = None,
) -> mx.array:
x = self.embeddings(x)
x = self.pre_layrnorm(x)
encoder_states = (x,) if output_hidden_states else None
for layer in self.encoder.layers:
x = layer(x, mask=None)
if output_hidden_states:
encoder_states = encoder_states + (x,)
pooler_output = self.post_layernorm(x[:, 0, :])
return pooler_output, x, encoder_states
class VisionModel(nn.Module):
def __init__(self, config: VisionConfig):
super().__init__()
self.model_type = config.model_type
if self.model_type != "clip_vision_model":
raise ValueError(f"Unsupported model type: {self.model_type}")
self.vision_model = ClipVisionModel(config)
def __call__(
self, x: mx.array, output_hidden_states: Optional[bool] = None
) -> mx.array:
return self.vision_model(x, output_hidden_states)
@staticmethod
def sanitize(weights):
sanitized_weights = {}
for k, v in weights.items():
if "position_ids" in k:
# Remove unused position_ids
continue
elif "patch_embedding.weight" in k:
# PyTorch conv2d weight tensors have shape:
# [out_channels, in_channels, kH, kW]
# MLX conv2d expects the weight to be of shape:
# [out_channels, kH, kW, in_channels]
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
else:
sanitized_weights[k] = v
return sanitized_weights
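# With the default VisionConfig this maps the checkpoint's patch embedding
# weight from (1024, 3, 14, 14) to (1024, 14, 14, 3), the channels-last layout
# expected by mlx.nn.Conv2d.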
@dataclass
class TextConfig:
model_type: str
hidden_size: int = 4096
num_hidden_layers: int = 32
intermediate_size: int = 11008
num_attention_heads: int = 32
rms_norm_eps: float = 1e-6
vocab_size: int = 32000
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class TextAttention(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
dim = config.hidden_size
self.n_heads = n_heads = config.num_attention_heads
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = config.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / config.rope_scaling["factor"]
if config.rope_scaling is not None
and config.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
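# Note: self.repeats records the query-to-KV head ratio, but the key/value
# heads are not repeated by hand here; mx.fast.scaled_dot_product_attention
# accepts fewer key/value heads than query heads and performs the
# grouped-query broadcast internally.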
class TextMLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.self_attn = TextAttention(config)
self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.config = config
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class Llama(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
):
# for passing merged input embeddings
if inputs_embeds is None:
h = self.embed_tokens(inputs)
else:
h = inputs_embeds
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
super().__init__()
self.model_type = config.model_type
if self.model_type != "llama":
raise ValueError(
f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
)
self.model = Llama(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
inputs_embeds=None,
):
out, cache = self.model(inputs, cache, inputs_embeds)
return self.lm_head(out), cache
@staticmethod
def sanitize(weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@dataclass
class LlaVAConfig:
text_config: TextConfig
vision_config: VisionConfig
ignore_index: int = -100
image_token_index: int = 32000
vision_feature_select_strategy: str = "default"
vision_feature_layer: int = -2
vocab_size: int = 32000
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlaVAConfig):
super().__init__()
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
)
self.gelu = nn.GELU()
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear_1(x)
x = self.gelu(x)
x = self.linear_2(x)
return x
class LlavaModel(nn.Module):
def __init__(self, config: LlaVAConfig):
super().__init__()
self.config = config
self.vision_tower = VisionModel(config.vision_config)
self.language_model = LanguageModel(config.text_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
# Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
)
# Select the hidden states from the desired layer
selected_image_feature = hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
"Unexpected feature selection strategy: "
f"{self.vision_feature_select_strategy}"
)
# Pass image features through the multi-modal projector
image_features = self.multi_modal_projector(selected_image_feature)
# Insert special image tokens in the input_ids
final_inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
return final_inputs_embeds
def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
if len(image_positions) != num_images:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})."
)
text_segments = []
start_idx = 0
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
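# Worked example (hypothetical token ids): with input_ids [[1, 32000, 9]] and
# one image producing 576 projected patch features, the result is the embedding
# of token 1, then the 576 image embeddings in place of the <image> token
# (id 32000), then the embedding of token 9: shape (1, 578, embed_dim).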
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits, cache = self.language_model(
input_ids, cache=cache, inputs_embeds=input_embeddings
)
return logits, cache
@staticmethod
def from_pretrained(path_or_hf_repo: str):
path = Path(path_or_hf_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
],
)
)
with open(path / "config.json", "r") as f:
model_config = json.load(f)
model_config = LlaVAConfig.from_dict(model_config)
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.text_config = TextConfig.from_dict(model_config.text_config)
model = LlavaModel(model_config)
weight_files = glob.glob(str(path / "*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = VisionModel.sanitize(weights)
weights = LanguageModel.sanitize(weights)
model.load_weights(list(weights.items()))
return model
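
A minimal usage sketch for the module above. The repo id, image path, and prompt template are assumptions for illustration, not values taken from this file:

from PIL import Image
import mlx.core as mx
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaModel.from_pretrained("llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
# The processor returns NCHW pixel_values; the model transposes them to NHWC internally.
inputs = processor(text=prompt, images=Image.open("example.jpg"), return_tensors="np")
logits, cache = model(mx.array(inputs["input_ids"]), mx.array(inputs["pixel_values"]))
next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy choice for the first generated token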


@@ -13,11 +13,13 @@ import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from transformers import AutoProcessor
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers
from ..shard import Shard
from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -228,4 +230,60 @@ async def load_shard(
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
async def load_shard_llava(
path_or_hf_repo: str,
shard: Shard,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the LLaVA model and processor from a given path or a huggingface repository.
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config (dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False, eval the model parameters to make sure they are
loaded in memory before returning; otherwise they will be loaded
when needed. Default: ``False``
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and processor.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = await get_model_path(path_or_hf_repo)
processor = AutoProcessor.from_pretrained(model_path)
with open(model_path / "config.json", "r") as f:
model_config = json.load(f)
model_config = LlaVAConfig.from_dict(model_config)
model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
model_config.text_config = TextConfig.from_dict(model_config.text_config)
model = LlavaModel(model_config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
weights = VisionModel.sanitize(weights)
weights = LanguageModel.sanitize(weights)
model.load_weights(list(weights.items()))
return model, processor
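
A rough driver sketch for load_shard_llava. The repo id and layer split are placeholders, and the Shard field names are assumed from exo's Shard dataclass:

import asyncio
from PIL import Image
import mlx.core as mx
from exo.inference.shard import Shard

shard = Shard(model_id="llava-1.5-7b-hf", start_layer=0, end_layer=31, n_layers=32)  # placeholder split
model, processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard))
prompt = "USER: <image>\nDescribe the photo. ASSISTANT:"
inputs = processor(text=prompt, images=Image.open("example.jpg"), return_tensors="np")
logits, _ = model(mx.array(inputs["input_ids"]), mx.array(inputs["pixel_values"]))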
