mirror of https://github.com/exo-explore/exo.git
fix unit tests
@@ -77,15 +77,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
     else:
-      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
     output_data = np.array(output_data)
     return output_data, inference_state
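The signature change above fixes Python's mutable-default-argument pitfall: `inference_state: dict = {}` creates one dict at function-definition time and shares it across every call that omits the argument, so state could leak between requests. A minimal standalone sketch (not exo code) of the failure mode and of the `Optional[...] = None` idiom the commit adopts:

    from typing import Optional

    def record(key, value, state: dict = {}):
        # BUG: `state` defaults to a single dict created once, at definition time
        state[key] = value
        return state

    a = record("x", 1)
    b = record("y", 2)
    assert a is b and a == {"x": 1, "y": 2}  # state leaked across calls

    def record_fixed(key, value, state: Optional[dict] = None):
        state = state or {}  # fresh dict per call, same guard as `inference_state or {}`
        state[key] = value
        return state

    assert record_fixed("x", 1) == {"x": 1}
    assert record_fixed("y", 2) == {"y": 2}  # independent calls stay independent

Note that `state or {}` also swaps a caller-supplied empty dict for a fresh one; that is harmless here because the dict is only splatted into keyword arguments, never mutated in place.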
@@ -14,7 +14,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
   token_full = token_full.reshape(1, -1)
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=token_full,
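The test change mirrors the engine's return type: `infer_tensor` now returns a `(np.ndarray, Optional[dict])` pair, so the result must be unpacked before it can be used as an array. A hypothetical stand-in (`fake_infer_tensor` is illustrative, not part of the test suite):

    import numpy as np

    def fake_infer_tensor(x: np.ndarray):
        # mimics the engine's (output, inference_state) return shape
        return x + 1, {"step": 1}

    result = fake_infer_tensor(np.zeros(3))        # a tuple, unusable as an array
    next_resp, _ = fake_infer_tensor(np.zeros(3))  # unpacked, as the test now does
    assert isinstance(result, tuple)
    assert isinstance(next_resp, np.ndarray)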
@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
   strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
   assert text == strip_tokens(decoded) == strip_tokens(reconstructed)

-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards:
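Since entries like "mlx-community/Qwen*" are shell-style globs, the test compiles the whole ignore list into one prefix-anchored regex. A small sketch of that mechanic (the model names here are examples only):

    import re

    ignore = ["mlx-community/Qwen*", "dummy"]
    # `*` becomes `.*`; `^(...)` anchors at the start, so matching is prefix-based
    ignore_pattern = re.compile(r"^(" + "|".join(m.replace("*", ".*") for m in ignore) + r")")

    assert ignore_pattern.match("mlx-community/Qwen2-7B-4bit")
    assert not ignore_pattern.match("mlx-community/Llama-3.2-1B-Instruct-4bit")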