Merge pull request #596 from exo-explore/phi4

add phi 3.5, phi 4
This commit is contained in:
Alex Cheema
2025-01-08 16:39:32 +00:00
committed by GitHub
4 changed files with 127 additions and 4 deletions

View File

@@ -0,0 +1,117 @@
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__()
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
if self.args.shard.is_first_layer():
h = self.embed_tokens(inputs)
else:
h = inputs
mask = None
if h.shape[1] > 1:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Phi3Model(args)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
out = self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if "self_attn.rope.inv_freq" in key:
continue
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
shard_state_dict[key] = value
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__() # Ensure parent initializations are respected
super().__post_init__()
if isinstance(self.shard, Shard):
return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
self.shard = Shard(**self.shard)
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

View File

@@ -111,6 +111,9 @@ model_cards = {
# gemma
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# phi
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
# dummy
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
@@ -149,6 +152,8 @@ pretty_name = {
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-72b": "Qwen 2.5 72B",
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
"phi-3.5-mini": "Phi-3.5 Mini",
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
}

View File

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-Mini-Instruct-4bit", "mlx-community/Phi-4-4bit"]
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
models = []
for model_id in model_cards: