# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py

import math
from dataclasses import dataclass, field
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn

from exo.inference.shard import Shard
from exo.inference.mlx.models.base import IdentityBlock

_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}


@dataclass
class CLIPTextModelConfig:
  num_layers: int = 23
  model_dims: int = 1024
  num_heads: int = 16
  max_length: int = 77
  vocab_size: int = 49408
  projection_dim: Optional[int] = None
  hidden_act: str = "quick_gelu"

  @classmethod
  def from_dict(cls, config):
    return ModelArgs(
      num_layers=config["num_hidden_layers"],
      model_dims=config["hidden_size"],
      num_heads=config["num_attention_heads"],
      max_length=config["max_position_embeddings"],
      vocab_size=config["vocab_size"],
      projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
      hidden_act=config.get("hidden_act", "quick_gelu"),
      weight_files=config.get("weight_files", [])
    )
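
  # Illustrative mapping (values hypothetical): a Hugging Face text-encoder config such as
  #   {"num_hidden_layers": 23, "hidden_size": 1024, "num_attention_heads": 16,
  #    "max_position_embeddings": 77, "vocab_size": 49408, "architectures": ["CLIPTextModel"]}
  # becomes ModelArgs(num_layers=23, model_dims=1024, num_heads=16, max_length=77,
  # vocab_size=49408, projection_dim=None). projection_dim is only filled in when the
  # architecture name contains "WithProjection" (e.g. "CLIPTextModelWithProjection").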


@dataclass
class ModelArgs(CLIPTextModelConfig):
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
  weight_files: List[str] = field(default_factory=lambda: [])

  def __post_init__(self):
    if isinstance(self.shard, dict):
      self.shard = Shard(**self.shard)

    if not isinstance(self.shard, Shard):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    if not self.shard.is_first_layer():
      self.vision_config = None


@dataclass
class CLIPOutput:
  pooled_output: Optional[mx.array] = None
  last_hidden_state: Optional[mx.array] = None
  hidden_states: Optional[List[mx.array]] = None


class CLIPEncoderLayer(nn.Module):
  """The transformer encoder layer from CLIP."""

  def __init__(self, model_dims: int, num_heads: int, activation: str):
    super().__init__()

    self.layer_norm1 = nn.LayerNorm(model_dims)
    self.layer_norm2 = nn.LayerNorm(model_dims)

    self.attention = nn.MultiHeadAttention(model_dims, num_heads)
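    # mlx's MultiHeadAttention builds bias-free projections by default, so bias
    # parameters are attached explicitly here to match the CLIP checkpoint layout.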
    self.attention.query_proj.bias = mx.zeros(model_dims)
    self.attention.key_proj.bias = mx.zeros(model_dims)
    self.attention.value_proj.bias = mx.zeros(model_dims)
    self.attention.out_proj.bias = mx.zeros(model_dims)

    self.linear1 = nn.Linear(model_dims, 4 * model_dims)
    self.linear2 = nn.Linear(4 * model_dims, model_dims)

    self.act = _ACTIVATIONS[activation]

  def __call__(self, x, attn_mask=None):
    y = self.layer_norm1(x)
    y = self.attention(y, y, y, attn_mask)
    x = y + x

    y = self.layer_norm2(x)
    y = self.linear1(y)
    y = self.act(y)
    y = self.linear2(y)
    x = y + x

    return x


class CLIPTextModel(nn.Module):
  """Implements the text encoder transformer from CLIP."""

  def __init__(self, config: CLIPTextModelConfig, shard: Shard):
    super().__init__()

    self.shard = shard
    self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
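    # Each shard layer index covers two CLIP encoder layers, e.g. a (hypothetical) shard
    # with start_layer=0 and end_layer=11 owns encoder layers 0..23.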
    if self.shard.is_first_layer():
      self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
      self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
    self.layers = []
    for i in range(math.ceil(config.num_layers/2)):
      if 2*i in self.layers_range:
        self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
        if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
          self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
      else:
        self.layers.append(IdentityBlock())
    if self.shard.is_last_layer():
      self.final_layer_norm = nn.LayerNorm(config.model_dims)

      if config.projection_dim is not None:
        self.text_projection = nn.Linear(
          config.model_dims, config.projection_dim, bias=False
        )

  def _get_mask(self, N, dtype):
    indices = mx.arange(N)
    mask = indices[:, None] < indices[None]
    mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
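    # Future (upper-triangular) positions receive a large negative additive bias;
    # -6e4 stays representable in float16, -1e9 is used for wider dtypes.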
    return mask

  def __call__(self, x, mask=None):
    # Extract some shapes
    if self.shard.is_first_layer():
      B, N = x.shape
      eos_tokens = x.argmax(-1)
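      # argmax locates the EOS position because the EOS token has the highest id in the
      # CLIP vocabulary; kept from the upstream implementation, unused in this sharded forward.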

      # Compute the embeddings
      x = self.token_embedding(x)
      x = x + self.position_embedding.weight[:N]

      # Compute the features from the transformer
      mask = self._get_mask(N, x.dtype)

    for l in self.layers:
      x = l(x, mask)

    # Apply the final layernorm and return
    if self.shard.is_last_layer():
      x = self.final_layer_norm(x)

    return x, mask

  def sanitize(self, weights):
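    """Remap checkpoint weight keys (Hugging Face CLIP naming) to this module's parameter
    names and drop weights that do not belong to this shard."""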
    sanitized_weights = {}
    for key, value in weights.items():
      if "position_ids" in key:
        continue
      if key.startswith("text_model."):
        key = key[11:]
      if key.startswith("embeddings."):
        key = key[11:]
      if key.startswith("encoder."):
        key = key[8:]

      # Map attention layers
      if "self_attn." in key:
        key = key.replace("self_attn.", "attention.")
      if "q_proj." in key:
        key = key.replace("q_proj.", "query_proj.")
      if "k_proj." in key:
        key = key.replace("k_proj.", "key_proj.")
      if "v_proj." in key:
        key = key.replace("v_proj.", "value_proj.")

      # Map ffn layers
      if "mlp.fc1" in key:
        key = key.replace("mlp.fc1", "linear1")
      if "mlp.fc2" in key:
        key = key.replace("mlp.fc2", "linear2")

      if key.startswith("layers."):
        layer_num = int(key.split(".")[1])
        if layer_num not in self.layers_range:
          continue
      if not self.shard.is_first_layer() and "embedding" in key:
        continue
      if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
        continue
      if not self.shard.is_last_layer() and key.startswith("text_projection"):
        continue
      sanitized_weights[key] = value
    return sanitized_weights
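

# Illustrative usage sketch (not part of the module; shard and token values are hypothetical):
#
#   config = CLIPTextModelConfig()                  # defaults correspond to a 23-layer, 1024-dim text encoder
#   shard = Shard("clip_text", 0, 11, 12)           # assumed single shard covering all encoder layers
#   model = CLIPTextModel(config, shard)
#   tokens = mx.array([[49406, 320, 1125, 49407]])  # BOS ... EOS; middle ids are made up
#   hidden, mask = model(tokens)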