mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
fix
This commit is contained in:
@@ -25,7 +25,7 @@ class DummyInferenceEngine(InferenceEngine):
|
||||
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
return input_data + 1 if self.shard.is_last_layer() else input_data, None
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
request_id=request_id,
|
||||
inference_state=self.serialize_inference_state(inference_state)
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
response = await self.stub.SendPrompt(request)
|
||||
|
||||
@@ -101,7 +101,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
),
|
||||
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
||||
request_id=request_id,
|
||||
inference_state=self.serialize_inference_state(inference_state)
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
response = await self.stub.SendTensor(request)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
)
|
||||
prompt = request.prompt
|
||||
request_id = request.request_id
|
||||
inference_state = self.deserialize_inference_state(request.inference_state)
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
@@ -68,7 +68,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
|
||||
request_id = request.request_id
|
||||
|
||||
inference_state = self.deserialize_inference_state(request.inference_state)
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
|
||||
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
|
||||
|
||||
Reference in New Issue
Block a user