load mlx model shard on mlx thread so it doesnt block

This commit is contained in:
Alex Cheema
2025-01-28 18:49:19 +00:00
parent 7c649085a1
commit 6662d5668c

View File

@@ -4,7 +4,7 @@ import mlx.nn as nn
from mlx_lm.sample_utils import top_p_sampling, make_sampler
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .sharded_utils import load_shard, load_model_shard, resolve_tokenizer
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
@@ -157,7 +157,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
return
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
if self.shard != shard:
model_shard, self.tokenizer = await load_shard(model_path, shard)
model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
if hasattr(model_shard, "tokenizer"):
self.tokenizer = model_shard.tokenizer
else:
self.tokenizer = await resolve_tokenizer(model_path)
self.shard = shard
self.model = model_shard
self.caches = OrderedDict()