mirror of https://github.com/exo-explore/exo.git · synced 2025-10-23 02:57:14 +03:00
add exaone-3.5 LLM Model

exo/inference/mlx/models/exaone.py (new file, 80 lines)

@@ -0,0 +1,80 @@
from dataclasses import dataclass, field

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.exaone import TransformerBlock, ModelArgs

from ...shard import Shard
from .base import IdentityBlock  # imported but not used in this file


@dataclass
class ModelArgs(ModelArgs):
  # Extends the upstream mlx_lm ModelArgs with exo's shard descriptor.
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    # super().__post_init__() # Ensure parent initializations are respected

    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    self.shard = Shard(**self.shard)


class ExaoneModel(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
    self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
    self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    h = self.wte(inputs)
    mask = create_attention_mask(h, cache)

    if cache is None:
      cache = [None]*len(self.h)

    for layer, c in zip(self.h, cache):
      h = layer(h, mask, cache=c)

    return self.ln_f(h)


class Model(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.model_type = args.model_type
    self.transformer = ExaoneModel(args)
    if not args.tie_word_embeddings:
      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    out = self.transformer(inputs, cache)
    if self.args.tie_word_embeddings:
      # Reuse the input embedding weights as the output projection.
      out = self.transformer.wte.as_linear(out)
    else:
      out = self.lm_head(out)
    return out

  @property
  def layers(self):
    return self.transformer.h

  @property
  def head_dim(self):
    return self.args.head_dim

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
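For orientation, a minimal self-contained sketch of the dict-to-Shard coercion that `ModelArgs.__post_init__` performs above. The `Shard` field names used here (`model_id`, `start_layer`, `end_layer`, `n_layers`) are an assumption inferred from the positional default `Shard("", 0, 0, 0)`; the real definition lives in exo's `shard.py`, outside this diff.

from dataclasses import dataclass, field

@dataclass
class Shard:
  # Assumed field layout, mirroring exo/inference/shard.py (not shown in this diff).
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int

@dataclass
class Args:
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    # Accept either a ready-made Shard or a plain dict deserialized from a config.
    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
    self.shard = Shard(**self.shard)

args = Args(shard={"model_id": "exaone-3.5-2.4b", "start_layer": 0, "end_layer": 29, "n_layers": 30})
assert isinstance(args.shard, Shard)  # the dict was coerced at construction time

The other pattern worth calling out is the weight tying in `Model.__call__`: when `tie_word_embeddings` is set, the input embedding matrix doubles as the unembedding matrix via `nn.Embedding.as_linear`, which is part of mlx itself. A toy demonstration:

import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(10, 4)    # toy vocab of 10, hidden size 4
h = emb(mx.array([[1, 2]]))  # (1, 2, 4) hidden states
logits = emb.as_linear(h)    # (1, 2, 10): embedding weights reused as the output projection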
@@ -110,6 +110,8 @@ model_cards = {
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+  "exaone-3.5-7.8b": {"layers": 32, "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/EXAONE-3.5-7.8B-Instruct-4bit"}, },
+  "exaone-3.5-2.4b": {"layers": 30, "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/EXAONE-3.5-2.4B-Instruct-4bit"}, },
 }

 pretty_name = {
@@ -145,6 +147,8 @@ pretty_name = {
   "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
+  "exaone-3.5-2.4b": "EXAONE-3.5 2.4B",
+  "exaone-3.5-7.8b": "EXAONE-3.5 7.8B",
 }

 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
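To see how the new entries are consumed, a hedged sketch of the lookup path: `get_repo` resolves a model id and inference-engine class name to a Hugging Face repo. Its body falls outside this hunk, so the implementation below is an assumption matching its signature, with a trimmed `model_cards` inlined to keep the sketch self-contained.

from typing import Optional

model_cards = {
  "exaone-3.5-2.4b": {"layers": 30, "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/EXAONE-3.5-2.4B-Instruct-4bit"}},
}

def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
  # Assumed behavior; the real body is not shown in this diff.
  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname)

print(get_repo("exaone-3.5-2.4b", "MLXDynamicShardInferenceEngine"))
# mlx-community/EXAONE-3.5-2.4B-Instruct-4bit

The "layers" count in each card (32 for the 7.8B model, 30 for the 2.4B) should match the checkpoint's transformer depth, since exo partitions a model into contiguous layer ranges across nodes based on that number.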