fix treating token as a list

This commit is contained in:
Alex Cheema
2025-01-22 22:13:13 +00:00
parent 09e12d8673
commit 9954ce8e4d
5 changed files with 38 additions and 38 deletions

View File

@@ -408,16 +408,16 @@ class ChatGPTAPI:
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
token, is_finished = await asyncio.wait_for(
tokens, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
if token == eos_token_id:
if tokens[-1] == eos_token_id:
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
@@ -428,7 +428,7 @@ class ChatGPTAPI:
tokenizer,
prompt,
request_id,
[token],
tokens,
stream,
finish_reason,
"chat.completion",

View File

@@ -123,7 +123,7 @@ class GRPCPeerHandle(PeerHandle):
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response =await self.stub.SendTensor(request)
response = await self.stub.SendTensor(request)
if not response.tensor_data or not response.shape or not response.dtype:
return None

View File

@@ -3,11 +3,11 @@ syntax = "proto3";
package node_service;
service NodeService {
rpc SendPrompt (PromptRequest) returns (Empty) {}
rpc SendTensor (TensorRequest) returns (Empty) {}
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
rpc SendResult (SendResultRequest) returns (Empty) {}
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}
@@ -95,7 +95,7 @@ message DeviceCapabilities {
DeviceFlops flops = 4;
}
message SendNewTokenRequest {
message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
optional Tensor tensor = 3;

View File

@@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x84\x01\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -75,16 +75,16 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_DEVICEFLOPS']._serialized_end=1822
_globals['_DEVICECAPABILITIES']._serialized_start=1824
_globals['_DEVICECAPABILITIES']._serialized_end=1931
_globals['_SENDNEWTOKENREQUEST']._serialized_start=1934
_globals['_SENDNEWTOKENREQUEST']._serialized_end=2066
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2068
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2129
_globals['_HEALTHCHECKREQUEST']._serialized_start=2131
_globals['_HEALTHCHECKREQUEST']._serialized_end=2151
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2153
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2194
_globals['_EMPTY']._serialized_start=2196
_globals['_EMPTY']._serialized_end=2203
_globals['_NODESERVICE']._serialized_start=2206
_globals['_NODESERVICE']._serialized_end=2743
_globals['_SENDRESULTREQUEST']._serialized_start=1934
_globals['_SENDRESULTREQUEST']._serialized_end=2064
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2066
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2127
_globals['_HEALTHCHECKREQUEST']._serialized_start=2129
_globals['_HEALTHCHECKREQUEST']._serialized_end=2149
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2151
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2192
_globals['_EMPTY']._serialized_start=2194
_globals['_EMPTY']._serialized_end=2201
_globals['_NODESERVICE']._serialized_start=2204
_globals['_NODESERVICE']._serialized_end=2739
# @@protoc_insertion_point(module_scope)

View File

@@ -37,12 +37,12 @@ class NodeServiceStub(object):
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
@@ -54,9 +54,9 @@ class NodeServiceStub(object):
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendNewToken = channel.unary_unary(
'/node_service.NodeService/SendNewToken',
request_serializer=node__service__pb2.SendNewTokenRequest.SerializeToString,
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
@@ -98,7 +98,7 @@ class NodeServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendNewToken(self, request, context):
def SendResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
@@ -122,12 +122,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
@@ -139,9 +139,9 @@ def add_NodeServiceServicer_to_server(servicer, server):
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
),
'SendNewToken': grpc.unary_unary_rpc_method_handler(
servicer.SendNewToken,
request_deserializer=node__service__pb2.SendNewTokenRequest.FromString,
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
@@ -181,7 +181,7 @@ class NodeService(object):
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Empty.FromString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -208,7 +208,7 @@ class NodeService(object):
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Empty.FromString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -274,7 +274,7 @@ class NodeService(object):
_registered_method=True)
@staticmethod
def SendNewToken(request,
def SendResult(request,
target,
options=(),
channel_credentials=None,
@@ -287,8 +287,8 @@ class NodeService(object):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendNewToken',
node__service__pb2.SendNewTokenRequest.SerializeToString,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,