This commit is contained in:
Alex Cheema
2025-01-12 03:40:59 +00:00
parent e7b98f5ae5
commit fcc699a55f
3 changed files with 5 additions and 5 deletions

View File

@@ -25,7 +25,7 @@ class DummyInferenceEngine(InferenceEngine):
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
return self.tokenizer.decode(tokens)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
return input_data + 1 if self.shard.is_last_layer() else input_data, None

View File

@@ -82,7 +82,7 @@ class GRPCPeerHandle(PeerHandle):
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)
@@ -101,7 +101,7 @@ class GRPCPeerHandle(PeerHandle):
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendTensor(request)

View File

@@ -52,7 +52,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
prompt = request.prompt
request_id = request.request_id
inference_state = self.deserialize_inference_state(request.inference_state)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
@@ -68,7 +68,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id
inference_state = self.deserialize_inference_state(request.inference_state)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")