mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #627 from exo-explore/deepseek
Deepseek, tinychat group models, latex formatting, thinking boxes
This commit is contained in:
135
exo/inference/mlx/models/deepseek_v3.py
Normal file
135
exo/inference/mlx/models/deepseek_v3.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
ModelArgs as V3ModelArgs,
|
||||
DeepseekV3DecoderLayer,
|
||||
)
|
||||
from .base import IdentityBlock
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
@dataclass
class ModelArgs(V3ModelArgs):
  """DeepSeek-V3 arguments extended with an exo shard descriptor."""
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    super().__post_init__()
    # Accept either a ready-made Shard or a plain dict (e.g. parsed config);
    # anything else is a caller error.
    if isinstance(self.shard, dict):
      self.shard = Shard(**self.shard)
    elif not isinstance(self.shard, Shard):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
||||
class DeepseekV3Model(nn.Module):
  """Sharded DeepSeek-V3 decoder stack.

  Only layers inside this node's shard are real DeepseekV3DecoderLayer
  instances; all other positions are IdentityBlock placeholders so that
  layer indices (and weight keys) stay aligned with the full model.
  """
  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
    self.num_hidden_layers = config.num_hidden_layers
    self.vocab_size = config.vocab_size
    shard = config.shard
    # Embeddings only exist on the shard that receives raw token ids.
    if shard.is_first_layer():
      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

    # Real decoder layers for this shard, identity pass-through elsewhere.
    self.layers = [
      DeepseekV3DecoderLayer(config, idx) if shard.start_layer <= idx <= shard.end_layer else IdentityBlock()
      for idx in range(self.num_hidden_layers)
    ]

    # Final norm only exists on the shard that produces the last hidden state.
    if shard.is_last_layer():
      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

  def __call__(
    self,
    x: mx.array,
    cache: Optional[KVCache] = None,
  ) -> mx.array:
    # First shard embeds token ids; later shards receive hidden states as-is.
    h = self.embed_tokens(x) if self.args.shard.is_first_layer() else x

    # A causal mask is only needed for multi-token (prefill) passes;
    # single-token decode steps attend over the cache unmasked.
    mask = None
    seq_len = h.shape[1]
    if seq_len > 1:
      mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len).astype(h.dtype)

    if cache is None:
      cache = [None]*len(self.layers)

    for layer, layer_cache in zip(self.layers, cache):
      h = layer(h, mask, layer_cache)

    if self.args.shard.is_last_layer():
      h = self.norm(h)
    return h
|
||||
|
||||
class Model(nn.Module):
  """Top-level sharded DeepSeek-V3 model with an optional LM head."""
  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
    self.model_type = config.model_type
    self.model = DeepseekV3Model(config)
    # The LM head only exists on the shard that produces logits.
    if config.shard.is_last_layer():
      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache: Optional[KVCache] = None,
  ):
    hidden = self.model(inputs, cache)
    if not self.args.shard.is_last_layer():
      # Intermediate shards forward raw hidden states to the next node.
      return hidden
    return self.lm_head(hidden)

  def sanitize(self, weights):
    """Filter a full checkpoint down to this shard and restack MoE experts.

    Keeps only decoder-layer weights inside the shard (plus embeddings on
    the first shard and norm/lm_head on the last), then merges per-expert
    projection tensors into the stacked ``switch_mlp`` layout that the
    mlx_lm DeepSeek-V3 implementation expects.
    """
    shard = self.args.shard
    filtered = {}

    for key, value in weights.items():
      if key.startswith('model.layers.'):
        layer_idx = int(key.split('.')[2])
        if shard.start_layer <= layer_idx <= shard.end_layer:
          filtered[key] = value
      elif shard.is_first_layer() and key.startswith('model.embed_tokens'):
        filtered[key] = value
      elif shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
        filtered[key] = value

    # Stack per-expert weights: experts.{e}.{proj}.{k} -> switch_mlp.{proj}.{k}.
    # Expert 0's key is used as the sentinel for whether this layer was kept.
    for layer_idx in range(self.args.num_hidden_layers):
      prefix = f"model.layers.{layer_idx}"
      for _, proj in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
        for part in ["weight", "scales", "biases"]:
          if f"{prefix}.mlp.experts.0.{proj}.{part}" in filtered:
            stacked = [
              filtered.pop(f"{prefix}.mlp.experts.{e}.{proj}.{part}")
              for e in range(self.args.n_routed_experts)
            ]
            filtered[f"{prefix}.mlp.switch_mlp.{proj}.{part}"] = mx.stack(stacked)

    return filtered

  @property
  def layers(self):
    return self.model.layers

  @property
  def head_dim(self):
    # (query/key head dim, value head dim) — MLA splits rope/nope key dims.
    return (
      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
      self.args.v_head_dim,
    )

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
||||
@@ -88,6 +88,38 @@ model_cards = {
|
||||
### deepseek
|
||||
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
|
||||
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
|
||||
"deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
|
||||
"deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
|
||||
### deepseek distills
|
||||
"deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
|
||||
### llava
|
||||
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
|
||||
### qwen
|
||||
@@ -140,6 +172,8 @@ pretty_name = {
|
||||
"mistral-large": "Mistral Large",
|
||||
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
||||
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
||||
"deepseek-v3": "Deepseek V3",
|
||||
"deepseek-r1": "Deepseek R1",
|
||||
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
||||
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
|
||||
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
||||
@@ -159,6 +193,38 @@ pretty_name = {
|
||||
"llama-3-8b": "Llama 3 8B",
|
||||
"llama-3-70b": "Llama 3 70B",
|
||||
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
|
||||
"deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
|
||||
"deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
|
||||
"deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
|
||||
"deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
|
||||
"deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
|
||||
"deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
|
||||
"deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
|
||||
"deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
|
||||
"deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
|
||||
"deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
|
||||
"deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
|
||||
}
|
||||
|
||||
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
||||
|
||||
@@ -742,4 +742,91 @@ main {
|
||||
.peer-connection i {
|
||||
font-size: 0.8em;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
/* Collapsible "thinking" box rendered for <think> blocks in model output. */
.thinking-block {
  background-color: rgba(255, 255, 255, 0.05);
  border-radius: 8px;
  margin: 8px 0;
  overflow: hidden;
}

/* Header row of a thinking box; gains a spinner while still streaming. */
.thinking-header {
  background-color: rgba(255, 255, 255, 0.1);
  padding: 8px 12px;
  font-size: 0.9em;
  color: #a0a0a0;
  display: flex;
  align-items: center;
  gap: 8px;
}

.thinking-content {
  padding: 12px;
  white-space: pre-wrap;
}

@keyframes thinking-spin {
  to { transform: rotate(360deg); }
}

/* Circular spinner shown while the think block is incomplete (.thinking). */
.thinking-header.thinking::before {
  content: '';
  width: 12px;
  height: 12px;
  border: 2px solid #a0a0a0;
  border-top-color: transparent;
  border-radius: 50%;
  animation: thinking-spin 1s linear infinite;
}

/* Collapsible model-picker groups (main prefix) and subgroups. */
.model-group {
  margin-bottom: 12px;
}

.model-group-header,
.model-subgroup-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 8px 12px;
  background-color: var(--primary-bg-color);
  border-radius: 6px;
  cursor: pointer;
  transition: all 0.2s ease;
  margin-bottom: 8px;
}

.model-group-header:hover,
.model-subgroup-header:hover {
  background-color: var(--secondary-color-transparent);
}

.model-group-content {
  padding-left: 12px;
}

.model-subgroup {
  margin-bottom: 8px;
}

/* Subgroup headers are visually lighter than top-level group headers. */
.model-subgroup-header {
  font-size: 0.9em;
  background-color: rgba(255, 255, 255, 0.05);
}

.model-subgroup-content {
  padding-left: 12px;
}

.group-header-content {
  display: flex;
  align-items: center;
  gap: 8px;
}

/* "[downloaded/total]" badge shown next to a group name. */
.model-count {
  font-size: 0.8em;
  color: var(--secondary-color-transparent);
  font-family: monospace;
}
||||
@@ -22,6 +22,7 @@
|
||||
<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
|
||||
<link href="/index.css" rel="stylesheet"/>
|
||||
<link href="/common.css" rel="stylesheet"/>
|
||||
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<main x-data="state" x-init="console.log(endpoint)">
|
||||
@@ -49,50 +50,78 @@
|
||||
<span>Loading models...</span>
|
||||
</div>
|
||||
|
||||
<template x-for="(model, key) in models" :key="key">
|
||||
<div class="model-option"
|
||||
:class="{ 'selected': cstate.selectedModel === key }"
|
||||
@click="cstate.selectedModel = key">
|
||||
<div class="model-header">
|
||||
<div class="model-name" x-text="model.name"></div>
|
||||
<button
|
||||
@click.stop="deleteModel(key, model)"
|
||||
class="model-delete-button"
|
||||
x-show="model.download_percentage > 0">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="model-info">
|
||||
<div class="model-progress">
|
||||
<template x-if="model.loading">
|
||||
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
|
||||
</template>
|
||||
<div class="model-progress-info">
|
||||
<template x-if="!model.loading && model.download_percentage != null">
|
||||
<span>
|
||||
<!-- Check if there's an active download for this model -->
|
||||
<template x-if="downloadProgress?.some(p =>
|
||||
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
|
||||
)">
|
||||
<i class="fas fa-circle-notch fa-spin"></i>
|
||||
</template>
|
||||
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
|
||||
<button
|
||||
@click.stop="handleDownload(key)"
|
||||
class="model-download-button">
|
||||
<i class="fas fa-download"></i>
|
||||
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
<!-- Group models by prefix -->
|
||||
<template x-for="[mainPrefix, subGroups] in Object.entries(groupModelsByPrefix(models))" :key="mainPrefix">
|
||||
<div class="model-group">
|
||||
<div class="model-group-header" @click="toggleGroup(mainPrefix)">
|
||||
<div class="group-header-content">
|
||||
<span x-text="mainPrefix"></span>
|
||||
<span class="model-count" x-text="getGroupCounts(Object.values(subGroups).flatMap(group => Object.values(group)))"></span>
|
||||
</div>
|
||||
<template x-if="model.total_size">
|
||||
<div class="model-size" x-text="model.total_downloaded ?
|
||||
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
|
||||
formatBytes(model.total_size)">
|
||||
<i class="fas" :class="isGroupExpanded(mainPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
|
||||
</div>
|
||||
|
||||
<div class="model-group-content" x-show="isGroupExpanded(mainPrefix)" x-transition>
|
||||
<template x-for="[subPrefix, groupModels] in Object.entries(subGroups)" :key="subPrefix">
|
||||
<div class="model-subgroup">
|
||||
<div class="model-subgroup-header" @click.stop="toggleGroup(mainPrefix, subPrefix)">
|
||||
<div class="group-header-content">
|
||||
<span x-text="subPrefix"></span>
|
||||
<span class="model-count" x-text="getGroupCounts(groupModels)"></span>
|
||||
</div>
|
||||
<i class="fas" :class="isGroupExpanded(mainPrefix, subPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
|
||||
</div>
|
||||
|
||||
<div class="model-subgroup-content" x-show="isGroupExpanded(mainPrefix, subPrefix)" x-transition>
|
||||
<template x-for="(model, key) in groupModels" :key="key">
|
||||
<div class="model-option"
|
||||
:class="{ 'selected': cstate.selectedModel === key }"
|
||||
@click="cstate.selectedModel = key">
|
||||
<div class="model-header">
|
||||
<div class="model-name" x-text="model.name"></div>
|
||||
<button
|
||||
@click.stop="deleteModel(key, model)"
|
||||
class="model-delete-button"
|
||||
x-show="model.download_percentage > 0">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="model-info">
|
||||
<div class="model-progress">
|
||||
<template x-if="model.loading">
|
||||
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
|
||||
</template>
|
||||
<div class="model-progress-info">
|
||||
<template x-if="!model.loading && model.download_percentage != null">
|
||||
<span>
|
||||
<template x-if="downloadProgress?.some(p =>
|
||||
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
|
||||
)">
|
||||
<i class="fas fa-circle-notch fa-spin"></i>
|
||||
</template>
|
||||
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
|
||||
<button
|
||||
@click.stop="handleDownload(key)"
|
||||
class="model-download-button">
|
||||
<i class="fas fa-download"></i>
|
||||
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<template x-if="model.total_size">
|
||||
<div class="model-size" x-text="model.total_downloaded ?
|
||||
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
|
||||
formatBytes(model.total_size)">
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
@@ -177,6 +206,7 @@
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
@click="
|
||||
home = 0;
|
||||
@@ -190,67 +220,87 @@
|
||||
<i class="fas fa-arrow-left"></i>
|
||||
Back to Chats
|
||||
</button>
|
||||
<div class="messages" x-init="
|
||||
$watch('cstate', value => {
|
||||
$el.innerHTML = '';
|
||||
value.messages.forEach(({ role, content }) => {
|
||||
const div = document.createElement('div');
|
||||
div.className = `message message-role-${role}`;
|
||||
try {
|
||||
if (content.includes('![Generated Image]')) {
|
||||
const imageUrl = content.match(/\((.*?)\)/)[1];
|
||||
const img = document.createElement('img');
|
||||
img.src = imageUrl;
|
||||
img.alt = 'Generated Image';
|
||||
img.onclick = async () => {
|
||||
try {
|
||||
const response = await fetch(img.src);
|
||||
const blob = await response.blob();
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
handleImageUpload({ target: { files: [file] } });
|
||||
} catch (error) {
|
||||
console.error('Error fetching image:', error);
|
||||
}
|
||||
};
|
||||
div.appendChild(img);
|
||||
} else {
|
||||
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
|
||||
}
|
||||
} catch (e) {
|
||||
console.log(content);
|
||||
console.error(e);
|
||||
<div class="messages"
|
||||
x-init="
|
||||
$watch('cstate', (value) => {
|
||||
$el.innerHTML = '';
|
||||
|
||||
value.messages.forEach((msg) => {
|
||||
const div = document.createElement('div');
|
||||
div.className = `message message-role-${msg.role}`;
|
||||
|
||||
try {
|
||||
// If there's an embedded generated image
|
||||
if (msg.content.includes('![Generated Image]')) {
|
||||
const imageUrlMatch = msg.content.match(/\((.*?)\)/);
|
||||
if (imageUrlMatch) {
|
||||
const imageUrl = imageUrlMatch[1];
|
||||
const img = document.createElement('img');
|
||||
img.src = imageUrl;
|
||||
img.alt = 'Generated Image';
|
||||
|
||||
img.onclick = async () => {
|
||||
try {
|
||||
const response = await fetch(img.src);
|
||||
const blob = await response.blob();
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
handleImageUpload({ target: { files: [file] } });
|
||||
} catch (error) {
|
||||
console.error('Error fetching image:', error);
|
||||
}
|
||||
};
|
||||
div.appendChild(img);
|
||||
} else {
|
||||
// fallback if markdown is malformed
|
||||
div.textContent = msg.content;
|
||||
}
|
||||
} else {
|
||||
// Otherwise, transform message text (including streamed think blocks).
|
||||
div.innerHTML = transformMessageContent(msg);
|
||||
// Render math after content is inserted
|
||||
MathJax.typesetPromise([div]);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error rendering message:', e);
|
||||
div.textContent = msg.content; // fallback
|
||||
}
|
||||
|
||||
// add a clipboard button to all code blocks
|
||||
const codeBlocks = div.querySelectorAll('.hljs');
|
||||
codeBlocks.forEach(codeBlock => {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'clipboard-button';
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
button.onclick = () => {
|
||||
// navigator.clipboard.writeText(codeBlock.textContent);
|
||||
const range = document.createRange();
|
||||
range.setStartBefore(codeBlock);
|
||||
range.setEndAfter(codeBlock);
|
||||
window.getSelection()?.removeAllRanges();
|
||||
window.getSelection()?.addRange(range);
|
||||
document.execCommand('copy');
|
||||
window.getSelection()?.removeAllRanges();
|
||||
// Add a clipboard button to code blocks
|
||||
const codeBlocks = div.querySelectorAll('.hljs');
|
||||
codeBlocks.forEach((codeBlock) => {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'clipboard-button';
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
|
||||
button.innerHTML = '<i class=\'fas fa-check\'></i>';
|
||||
setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
|
||||
};
|
||||
codeBlock.appendChild(button);
|
||||
});
|
||||
button.onclick = () => {
|
||||
const range = document.createRange();
|
||||
range.setStartBefore(codeBlock);
|
||||
range.setEndAfter(codeBlock);
|
||||
window.getSelection()?.removeAllRanges();
|
||||
window.getSelection()?.addRange(range);
|
||||
document.execCommand('copy');
|
||||
window.getSelection()?.removeAllRanges();
|
||||
|
||||
$el.appendChild(div);
|
||||
button.innerHTML = '<i class=\'fas fa-check\'></i>';
|
||||
setTimeout(() => {
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
}, 1000);
|
||||
};
|
||||
|
||||
codeBlock.appendChild(button);
|
||||
});
|
||||
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
$el.appendChild(div);
|
||||
});
|
||||
" x-intersect="
|
||||
|
||||
// Scroll to bottom after rendering
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
" x-ref="messages" x-show="home === 2" x-transition="">
|
||||
});
|
||||
"
|
||||
x-ref="messages"
|
||||
x-show="home === 2"
|
||||
x-transition=""
|
||||
>
|
||||
</div>
|
||||
|
||||
<!-- Download Progress Section -->
|
||||
@@ -353,4 +403,42 @@
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
<script>
|
||||
/**
|
||||
* Transform a single message's content into HTML, preserving <think> blocks.
|
||||
* Ensure LaTeX expressions are properly delimited for MathJax.
|
||||
*/
|
||||
/**
 * Transform a single message's content into HTML, preserving <think> blocks.
 * Ensures LaTeX expressions are properly delimited for MathJax.
 *
 * Fix: removed the debug console.log calls that ran on every message
 * re-render (one per message plus one per think block), which flooded the
 * console during streaming.
 */
function transformMessageContent(message) {
  let text = message.content;

  // Render <think>...</think> spans as collapsible "thinking" boxes.
  // A block with no closing tag is still streaming and gets a spinner.
  text = text.replace(
    /<think>([\s\S]*?)(?:<\/think>|$)/g,
    (match, body) => {
      const isComplete = match.includes('</think>');
      const spinnerClass = isComplete ? '' : ' thinking';
      const parsedBody = DOMPurify.sanitize(marked.parse(body));
      return `
        <div class='thinking-block'>
          <div class='thinking-header${spinnerClass}'>Thinking...</div>
          <div class='thinking-content'>${parsedBody}</div>
        </div>`;
    }
  );

  // Add backslashes so bare (...)/ [...] math becomes \(...\)/\[...\] for MathJax.
  // NOTE(review): the closing-delimiter patterns are unguarded and can also
  // touch ordinary prose parentheses/brackets — consider tightening them.
  text = text
    .replace(/\((?=\s*[\d\\])/g, '\\(') // Add backslash before opening parentheses
    .replace(/\)(?!\w)/g, '\\)') // Add backslash before closing parentheses
    .replace(/\[(?=\s*[\d\\])/g, '\\[') // Add backslash before opening brackets
    .replace(/\](?!\w)/g, '\\]') // Add backslash before closing brackets
    .replace(/\[[\s\n]*\\boxed/g, '\\[\\boxed') // Ensure boxed expressions are properly delimited
    .replace(/\\!/g, '\\\\!'); // Preserve LaTeX spacing commands

  return DOMPurify.sanitize(marked.parse(text));
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
@@ -42,6 +42,9 @@ document.addEventListener("alpine:init", () => {
|
||||
topology: null,
|
||||
topologyInterval: null,
|
||||
|
||||
// Add these new properties
|
||||
expandedGroups: {},
|
||||
|
||||
init() {
|
||||
// Clean up any pending messages
|
||||
localStorage.removeItem("pendingMessage");
|
||||
@@ -393,8 +396,6 @@ document.addEventListener("alpine:init", () => {
|
||||
},
|
||||
|
||||
async *openaiChatCompletion(model, messages) {
|
||||
// stream response
|
||||
console.log("model", model)
|
||||
const response = await fetch(`${this.endpoint}/chat/completions`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -417,19 +418,17 @@ document.addEventListener("alpine:init", () => {
|
||||
|
||||
const reader = response.body.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream()).getReader();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
if (done) break;
|
||||
|
||||
if (value.type === "event") {
|
||||
const json = JSON.parse(value.data);
|
||||
if (json.choices) {
|
||||
const choice = json.choices[0];
|
||||
if (choice.finish_reason === "stop") {
|
||||
break;
|
||||
}
|
||||
yield choice.delta.content;
|
||||
if (choice.finish_reason === "stop") break;
|
||||
if (choice.delta.content) yield choice.delta.content;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -668,7 +667,55 @@ document.addEventListener("alpine:init", () => {
|
||||
`;
|
||||
vizElement.appendChild(nodeElement);
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
// Add these helper methods
|
||||
countDownloadedModels(models) {
|
||||
return Object.values(models).filter(model => model.downloaded).length;
|
||||
},
|
||||
|
||||
getGroupCounts(groupModels) {
|
||||
const total = Object.keys(groupModels).length;
|
||||
const downloaded = this.countDownloadedModels(groupModels);
|
||||
return `[${downloaded}/${total}]`;
|
||||
},
|
||||
|
||||
// Update the existing groupModelsByPrefix method to include counts
|
||||
groupModelsByPrefix(models) {
|
||||
const groups = {};
|
||||
Object.entries(models).forEach(([key, model]) => {
|
||||
const parts = key.split('-');
|
||||
const mainPrefix = parts[0].toUpperCase();
|
||||
|
||||
let subPrefix;
|
||||
if (parts.length === 2) {
|
||||
subPrefix = parts[1].toUpperCase();
|
||||
} else if (parts.length > 2) {
|
||||
subPrefix = parts[1].toUpperCase();
|
||||
} else {
|
||||
subPrefix = 'OTHER';
|
||||
}
|
||||
|
||||
if (!groups[mainPrefix]) {
|
||||
groups[mainPrefix] = {};
|
||||
}
|
||||
if (!groups[mainPrefix][subPrefix]) {
|
||||
groups[mainPrefix][subPrefix] = {};
|
||||
}
|
||||
groups[mainPrefix][subPrefix][key] = model;
|
||||
});
|
||||
return groups;
|
||||
},
|
||||
|
||||
toggleGroup(prefix, subPrefix = null) {
|
||||
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
|
||||
this.expandedGroups[key] = !this.expandedGroups[key];
|
||||
},
|
||||
|
||||
isGroupExpanded(prefix, subPrefix = null) {
|
||||
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
|
||||
return this.expandedGroups[key] || false;
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
|
||||
4
setup.py
4
setup.py
@@ -35,8 +35,8 @@ install_requires = [
|
||||
extras_require = {
|
||||
"formatting": ["yapf==0.40.2",],
|
||||
"apple_silicon": [
|
||||
"mlx==0.21.1",
|
||||
"mlx-lm==0.20.4",
|
||||
"mlx==0.22.0",
|
||||
"mlx-lm==0.21.1",
|
||||
],
|
||||
"windows": ["pywin32==308",],
|
||||
"nvidia-gpu": ["nvidia-ml-py==12.560.30",],
|
||||
|
||||
@@ -37,5 +37,6 @@ verbose = os.environ.get("VERBOSE", "0").lower() == "1"
|
||||
for m in models:
|
||||
# TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit
|
||||
# test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
|
||||
test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
|
||||
test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
|
||||
if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-V3-4bit"]:
|
||||
test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose)
|
||||
test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=True), verbose)
|
||||
|
||||
Reference in New Issue
Block a user