mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
handle inference_state properly
This commit is contained in:
@@ -23,7 +23,7 @@ class InferenceEngine(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -77,7 +77,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
await self.ensure_shard(shard)
|
||||
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
loop = asyncio.get_running_loop()
|
||||
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
||||
|
||||
Reference in New Issue
Block a user