Make sure MLX work runs on a separate thread so that it is non-blocking

This commit is contained in:
Alex Cheema
2025-01-28 18:56:00 +00:00
parent 6662d5668c
commit 4a5b80a958

View File

@@ -93,7 +93,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
)
output_data, inference_state = result
output_data = np.array(output_data, copy=False)
await self._eval_mlx(output_data)
output_data = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: np.array(output_data, copy=False)
)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):