handle inference_state properly: add inference_state parameter to infer_tensor signatures

This commit is contained in:
Alex Cheema
2025-01-12 03:13:17 +00:00
parent 2af5ee02e4
commit 2aed3f3518
2 changed files with 2 additions and 2 deletions

View File

@@ -23,7 +23,7 @@ class InferenceEngine(ABC):
pass
@abstractmethod
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
pass
@abstractmethod

View File

@@ -77,7 +77,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}