mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
load mlx model shard on mlx thread so it doesnt block
This commit is contained in:
@@ -4,7 +4,7 @@ import mlx.nn as nn
|
||||
from mlx_lm.sample_utils import top_p_sampling, make_sampler
|
||||
import mlx.optimizers as optim
|
||||
from ..inference_engine import InferenceEngine
|
||||
from .sharded_utils import load_shard, get_image_from_str
|
||||
from .sharded_utils import load_shard, load_model_shard, resolve_tokenizer
|
||||
from .losses import loss_fns
|
||||
from ..shard import Shard
|
||||
from typing import Dict, Optional, Tuple
|
||||
@@ -157,7 +157,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
return
|
||||
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
||||
if self.shard != shard:
|
||||
model_shard, self.tokenizer = await load_shard(model_path, shard)
|
||||
model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
|
||||
if hasattr(model_shard, "tokenizer"):
|
||||
self.tokenizer = model_shard.tokenizer
|
||||
else:
|
||||
self.tokenizer = await resolve_tokenizer(model_path)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
self.caches = OrderedDict()
|
||||
|
||||
Reference in New Issue
Block a user