Merge pull request #627 from exo-explore/deepseek

DeepSeek models, tinychat model grouping, LaTeX formatting, thinking boxes
Alex Cheema
2025-01-24 18:14:57 +00:00
committed by GitHub
7 changed files with 534 additions and 110 deletions


@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
ModelArgs as V3ModelArgs,
DeepseekV3DecoderLayer,
)
from .base import IdentityBlock
from exo.inference.shard import Shard


@dataclass
class ModelArgs(V3ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__()
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)


class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.num_hidden_layers = config.num_hidden_layers
self.vocab_size = config.vocab_size
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = []
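    # Layers outside this node's shard become no-op IdentityBlocks, so layer
    # indices stay aligned across nodes and the cache list keeps one slot per layer.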
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(DeepseekV3DecoderLayer(config, i))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
    cache: Optional[list[KVCache]] = None,
) -> mx.array:
if self.args.shard.is_first_layer():
h = self.embed_tokens(x)
else:
h = x
mask = None
T = h.shape[1]
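    # Only multi-token (prefill) steps need an explicit causal mask; a
    # single-token decode step attends over the whole cache unmasked.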
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None]*len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h


class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
    cache: Optional[list[KVCache]] = None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
return self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
shard_state_dict[key] = value
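    # Checkpoints store each routed expert separately; mlx-lm's switch_mlp
    # expects them stacked along a leading axis, hence the mx.stack below.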
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
if expert_key in shard_state_dict:
to_join = [
shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
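For orientation, here is a minimal sketch of how the shard fields drive this file when a 61-layer DeepSeek V3 checkpoint is split across two nodes. The Shard field order follows exo.inference.shard; the config dict and the split points are illustrative, not from this diff.

from exo.inference.shard import Shard

# Two illustrative shards covering all 61 layers of deepseek-v3.
first_half = Shard("deepseek-v3", 0, 30, 61)    # gets embed_tokens + layers 0-30
second_half = Shard("deepseek-v3", 31, 60, 61)  # gets layers 31-60 + norm + lm_head

# The remaining V3ModelArgs fields would come from the checkpoint's
# config.json; `config` here is a placeholder dict.
model = Model(ModelArgs(shard=second_half, **config))
# On this node shard.is_first_layer() is False, so `inputs` must be the
# previous node's hidden states rather than token ids.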


@@ -88,6 +88,38 @@ model_cards = {
### deepseek
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
"deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
"deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
### deepseek distills
"deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
"deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
"deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
"deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
"deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
"deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
"deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
"deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
"deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
"deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
"deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
"deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
"deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
"deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
"deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
"deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
"deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
"deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
"deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
"deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
"deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
"deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
"deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
"deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
"deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
"deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
"deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
### llava
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
@@ -140,6 +172,8 @@ pretty_name = {
"mistral-large": "Mistral Large",
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
"deepseek-v3": "Deepseek V3",
"deepseek-r1": "Deepseek R1",
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
@@ -159,6 +193,38 @@ pretty_name = {
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
"deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
"deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
"deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
"deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
"deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
"deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
"deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
"deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
"deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
"deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
"deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
"deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
"deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
"deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
"deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
"deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
"deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
"deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
"deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
"deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
"deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
"deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
"deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
"deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
"deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
"deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
}
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
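get_repo's body is outside this hunk, but the tables above imply a plain nested lookup; a hedged sketch of that behavior (the committed implementation may differ in details):

from typing import Optional

def get_repo_sketch(model_id: str, inference_engine_classname: str) -> Optional[str]:
  # Look up the model card, then the repo for the given inference engine.
  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname)

assert get_repo_sketch("deepseek-r1", "MLXDynamicShardInferenceEngine") == "mlx-community/DeepSeek-R1-4bit"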


@@ -742,4 +742,91 @@ main {
.peer-connection i {
font-size: 0.8em;
color: #666;
}

.thinking-block {
  background-color: rgba(255, 255, 255, 0.05);
  border-radius: 8px;
  margin: 8px 0;
  overflow: hidden;
}

.thinking-header {
  background-color: rgba(255, 255, 255, 0.1);
  padding: 8px 12px;
  font-size: 0.9em;
  color: #a0a0a0;
  display: flex;
  align-items: center;
  gap: 8px;
}

.thinking-content {
  padding: 12px;
  white-space: pre-wrap;
}

@keyframes thinking-spin {
  to { transform: rotate(360deg); }
}

.thinking-header.thinking::before {
  content: '';
  width: 12px;
  height: 12px;
  border: 2px solid #a0a0a0;
  border-top-color: transparent;
  border-radius: 50%;
  animation: thinking-spin 1s linear infinite;
}

.model-group {
  margin-bottom: 12px;
}

.model-group-header,
.model-subgroup-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 8px 12px;
  background-color: var(--primary-bg-color);
  border-radius: 6px;
  cursor: pointer;
  transition: all 0.2s ease;
  margin-bottom: 8px;
}

.model-group-header:hover,
.model-subgroup-header:hover {
  background-color: var(--secondary-color-transparent);
}

.model-group-content {
  padding-left: 12px;
}

.model-subgroup {
  margin-bottom: 8px;
}

.model-subgroup-header {
  font-size: 0.9em;
  background-color: rgba(255, 255, 255, 0.05);
}

.model-subgroup-content {
  padding-left: 12px;
}

.group-header-content {
  display: flex;
  align-items: center;
  gap: 8px;
}

.model-count {
  font-size: 0.8em;
  color: var(--secondary-color-transparent);
  font-family: monospace;
}


@@ -22,6 +22,7 @@
<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
<link href="/index.css" rel="stylesheet"/>
<link href="/common.css" rel="stylesheet"/>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</head>
<body>
<main x-data="state" x-init="console.log(endpoint)">
@@ -49,50 +50,78 @@
<span>Loading models...</span>
</div>
<template x-for="(model, key) in models" :key="key">
<div class="model-option"
:class="{ 'selected': cstate.selectedModel === key }"
@click="cstate.selectedModel = key">
<div class="model-header">
<div class="model-name" x-text="model.name"></div>
<button
@click.stop="deleteModel(key, model)"
class="model-delete-button"
x-show="model.download_percentage > 0">
<i class="fas fa-trash"></i>
</button>
</div>
<div class="model-info">
<div class="model-progress">
<template x-if="model.loading">
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
</template>
<div class="model-progress-info">
<template x-if="!model.loading && model.download_percentage != null">
<span>
<!-- Check if there's an active download for this model -->
<template x-if="downloadProgress?.some(p =>
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
)">
<i class="fas fa-circle-notch fa-spin"></i>
</template>
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
</span>
</template>
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
<button
@click.stop="handleDownload(key)"
class="model-download-button">
<i class="fas fa-download"></i>
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
</button>
</template>
</div>
<!-- Group models by prefix -->
<template x-for="[mainPrefix, subGroups] in Object.entries(groupModelsByPrefix(models))" :key="mainPrefix">
<div class="model-group">
<div class="model-group-header" @click="toggleGroup(mainPrefix)">
<div class="group-header-content">
<span x-text="mainPrefix"></span>
<span class="model-count" x-text="getGroupCounts(Object.values(subGroups).flatMap(group => Object.values(group)))"></span>
</div>
<template x-if="model.total_size">
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
formatBytes(model.total_size)">
<i class="fas" :class="isGroupExpanded(mainPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
</div>
<div class="model-group-content" x-show="isGroupExpanded(mainPrefix)" x-transition>
<template x-for="[subPrefix, groupModels] in Object.entries(subGroups)" :key="subPrefix">
<div class="model-subgroup">
<div class="model-subgroup-header" @click.stop="toggleGroup(mainPrefix, subPrefix)">
<div class="group-header-content">
<span x-text="subPrefix"></span>
<span class="model-count" x-text="getGroupCounts(groupModels)"></span>
</div>
<i class="fas" :class="isGroupExpanded(mainPrefix, subPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
</div>
<div class="model-subgroup-content" x-show="isGroupExpanded(mainPrefix, subPrefix)" x-transition>
<template x-for="(model, key) in groupModels" :key="key">
<div class="model-option"
:class="{ 'selected': cstate.selectedModel === key }"
@click="cstate.selectedModel = key">
<div class="model-header">
<div class="model-name" x-text="model.name"></div>
<button
@click.stop="deleteModel(key, model)"
class="model-delete-button"
x-show="model.download_percentage > 0">
<i class="fas fa-trash"></i>
</button>
</div>
<div class="model-info">
<div class="model-progress">
<template x-if="model.loading">
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
</template>
<div class="model-progress-info">
<template x-if="!model.loading && model.download_percentage != null">
<span>
<template x-if="downloadProgress?.some(p =>
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
)">
<i class="fas fa-circle-notch fa-spin"></i>
</template>
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
</span>
</template>
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
<button
@click.stop="handleDownload(key)"
class="model-download-button">
<i class="fas fa-download"></i>
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
</button>
</template>
</div>
</div>
<template x-if="model.total_size">
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
formatBytes(model.total_size)">
</div>
</template>
</div>
</div>
</template>
</div>
</div>
</template>
</div>
@@ -177,6 +206,7 @@
</template>
</div>
</div>
</div>
<button
@click="
home = 0;
@@ -190,67 +220,87 @@
<i class="fas fa-arrow-left"></i>
Back to Chats
</button>
<div class="messages" x-init="
$watch('cstate', value =&gt; {
$el.innerHTML = '';
value.messages.forEach(({ role, content }) =&gt; {
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
if (content.includes('![Generated Image]')) {
const imageUrl = content.match(/\((.*?)\)/)[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
}
} catch (e) {
console.log(content);
console.error(e);
<div class="messages"
x-init="
$watch('cstate', (value) => {
$el.innerHTML = '';
value.messages.forEach((msg) => {
const div = document.createElement('div');
div.className = `message message-role-${msg.role}`;
try {
// If there's an embedded generated image
if (msg.content.includes('![Generated Image]')) {
const imageUrlMatch = msg.content.match(/\((.*?)\)/);
if (imageUrlMatch) {
const imageUrl = imageUrlMatch[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
// fallback if markdown is malformed
div.textContent = msg.content;
}
} else {
// Otherwise, transform message text (including streamed think blocks).
div.innerHTML = transformMessageContent(msg);
// Render math after content is inserted
MathJax.typesetPromise([div]);
}
} catch (e) {
console.error('Error rendering message:', e);
div.textContent = msg.content; // fallback
}
// add a clipboard button to all code blocks
const codeBlocks = div.querySelectorAll('.hljs');
codeBlocks.forEach(codeBlock =&gt; {
const button = document.createElement('button');
button.className = 'clipboard-button';
button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
button.onclick = () =&gt; {
// navigator.clipboard.writeText(codeBlock.textContent);
const range = document.createRange();
range.setStartBefore(codeBlock);
range.setEndAfter(codeBlock);
window.getSelection()?.removeAllRanges();
window.getSelection()?.addRange(range);
document.execCommand('copy');
window.getSelection()?.removeAllRanges();
// Add a clipboard button to code blocks
const codeBlocks = div.querySelectorAll('.hljs');
codeBlocks.forEach((codeBlock) => {
const button = document.createElement('button');
button.className = 'clipboard-button';
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
};
codeBlock.appendChild(button);
});
button.onclick = () => {
const range = document.createRange();
range.setStartBefore(codeBlock);
range.setEndAfter(codeBlock);
window.getSelection()?.removeAllRanges();
window.getSelection()?.addRange(range);
document.execCommand('copy');
window.getSelection()?.removeAllRanges();
$el.appendChild(div);
button.innerHTML = '<i class=\'fas fa-check\'></i>';
setTimeout(() => {
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
}, 1000);
};
codeBlock.appendChild(button);
});
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
$el.appendChild(div);
});
" x-intersect="
// Scroll to bottom after rendering
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
" x-ref="messages" x-show="home === 2" x-transition="">
});
"
x-ref="messages"
x-show="home === 2"
x-transition=""
>
</div>
<!-- Download Progress Section -->
@@ -353,4 +403,42 @@
</div>
</div>
</main>
<script>
/**
* Transform a single message's content into HTML, preserving <think> blocks.
* Ensure LaTeX expressions are properly delimited for MathJax.
*/
function transformMessageContent(message) {
let text = message.content;
console.log('Processing message content:', text);
// First replace think blocks
text = text.replace(
/<think>([\s\S]*?)(?:<\/think>|$)/g,
(match, body) => {
console.log('Found think block with content:', body);
const isComplete = match.includes('</think>');
const spinnerClass = isComplete ? '' : ' thinking';
const parsedBody = DOMPurify.sanitize(marked.parse(body));
return `
<div class='thinking-block'>
<div class='thinking-header${spinnerClass}'>Thinking...</div>
<div class='thinking-content'>${parsedBody}</div>
</div>`;
}
);
// Add backslashes to parentheses and brackets for LaTeX
text = text
.replace(/\((?=\s*[\d\\])/g, '\\(') // Add backslash before opening parentheses
.replace(/\)(?!\w)/g, '\\)') // Add backslash before closing parentheses
.replace(/\[(?=\s*[\d\\])/g, '\\[') // Add backslash before opening brackets
.replace(/\](?!\w)/g, '\\]') // Add backslash before closing brackets
.replace(/\[[\s\n]*\\boxed/g, '\\[\\boxed') // Ensure boxed expressions are properly delimited
.replace(/\\!/g, '\\\\!'); // Preserve LaTeX spacing commands
return DOMPurify.sanitize(marked.parse(text));
}
</script>
</body>
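The `<think>` replacement above leans on a regex that also matches an unterminated block, so the spinner can show while tokens are still streaming. A standalone Python sketch of the same pattern (function and names here are illustrative, not part of the commit):

import re

# Same idea as the JS above: the group body ends at </think> or, while
# the response is still streaming, at end-of-string.
THINK_RE = re.compile(r"<think>([\s\S]*?)(?:</think>|$)")

def wrap_think_blocks(text: str) -> str:
  def repl(m: re.Match) -> str:
    is_complete = m.group(0).endswith("</think>")
    spinner = "" if is_complete else " thinking"
    return (
      f"<div class='thinking-block'>"
      f"<div class='thinking-header{spinner}'>Thinking...</div>"
      f"<div class='thinking-content'>{m.group(1)}</div>"
      f"</div>"
    )
  return THINK_RE.sub(repl, text)

print(wrap_think_blocks("<think>partial chain of thought"))   # no closing tag yet: spinner stays
print(wrap_think_blocks("<think>done reasoning</think> Hi"))  # complete: no spinner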


@@ -42,6 +42,9 @@ document.addEventListener("alpine:init", () => {
topology: null,
topologyInterval: null,
// Add these new properties
expandedGroups: {},
init() {
// Clean up any pending messages
localStorage.removeItem("pendingMessage");
@@ -393,8 +396,6 @@ document.addEventListener("alpine:init", () => {
},
async *openaiChatCompletion(model, messages) {
// stream response
console.log("model", model)
const response = await fetch(`${this.endpoint}/chat/completions`, {
method: "POST",
headers: {
@@ -417,19 +418,17 @@ document.addEventListener("alpine:init", () => {
const reader = response.body.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream()).getReader();
while (true) {
const { done, value } = await reader.read();
if (done) break;
if (value.type === "event") {
const json = JSON.parse(value.data);
if (json.choices) {
const choice = json.choices[0];
if (choice.finish_reason === "stop") break;
if (choice.delta.content) yield choice.delta.content;
}
}
}
@@ -668,7 +667,55 @@ document.addEventListener("alpine:init", () => {
`;
vizElement.appendChild(nodeElement);
});
},
// Add these helper methods
countDownloadedModels(models) {
return Object.values(models).filter(model => model.downloaded).length;
},
getGroupCounts(groupModels) {
const total = Object.keys(groupModels).length;
const downloaded = this.countDownloadedModels(groupModels);
return `[${downloaded}/${total}]`;
},
// Update the existing groupModelsByPrefix method to include counts
groupModelsByPrefix(models) {
const groups = {};
Object.entries(models).forEach(([key, model]) => {
const parts = key.split('-');
const mainPrefix = parts[0].toUpperCase();
const subPrefix = parts.length > 1 ? parts[1].toUpperCase() : 'OTHER';
if (!groups[mainPrefix]) {
groups[mainPrefix] = {};
}
if (!groups[mainPrefix][subPrefix]) {
groups[mainPrefix][subPrefix] = {};
}
groups[mainPrefix][subPrefix][key] = model;
});
return groups;
},
toggleGroup(prefix, subPrefix = null) {
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
this.expandedGroups[key] = !this.expandedGroups[key];
},
isGroupExpanded(prefix, subPrefix = null) {
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
return this.expandedGroups[key] || false;
},
}));
});
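groupModelsByPrefix derives both group levels from the first two dash-separated segments of the model id, so every DeepSeek R1 variant lands under DEEPSEEK → R1. A quick Python sketch of the key derivation (illustrative only, mirroring the JS above):

def group_key(model_id: str) -> tuple[str, str]:
  parts = model_id.split("-")
  main_prefix = parts[0].upper()
  sub_prefix = parts[1].upper() if len(parts) > 1 else "OTHER"
  return main_prefix, sub_prefix

assert group_key("deepseek-r1-distill-qwen-1.5b") == ("DEEPSEEK", "R1")
assert group_key("llama-3-70b") == ("LLAMA", "3")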


@@ -35,8 +35,8 @@ install_requires = [
extras_require = {
"formatting": ["yapf==0.40.2",],
"apple_silicon": [
"mlx==0.21.1",
"mlx-lm==0.20.4",
"mlx==0.22.0",
"mlx-lm==0.21.1",
],
"windows": ["pywin32==308",],
"nvidia-gpu": ["nvidia-ml-py==12.560.30",],


@@ -37,5 +37,6 @@ verbose = os.environ.get("VERBOSE", "0").lower() == "1"
for m in models:
# TODO: figure out why use_fast=False gives inconsistent behaviour (no spaces when decoding individual tokens) for Mistral-Large-Instruct-2407-4bit
# test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
  if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-V3-4bit"]:
    test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose)
  test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=True), verbose)