mirror of https://github.com/exo-explore/exo.git

add support for multiple concurrent requests with request ids
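In broad strokes, the change threads an optional request_id through the prompt and tensor RPCs and adds a GetInferenceResult RPC for polling per-request output, so several generations can be in flight at once. A minimal sketch of the resulting client flow, based on the example script changed below (the peer objects, shard, and request id are illustrative, not part of the commit):

import asyncio

async def generate(peer_a, peer_b, shard, prompt, request_id="example-request"):
    # Start generation on one peer, tagged with a caller-chosen request id.
    await peer_a.connect()
    await peer_a.send_prompt(shard, prompt, request_id)

    # Poll any peer holding the buffered output for that request id.
    while True:
        tokens, is_finished = await peer_b.get_inference_result(request_id)
        if is_finished:
            return tokens
        await asyncio.sleep(0.1)  # ~10 polls per second, as in the example below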
@@ -16,11 +16,16 @@ model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)

peer = GRPCPeerHandle(
peer1 = GRPCPeerHandle(
"node1",
"localhost:8080",
DeviceCapabilities(model="test1", chip="test1", memory=10000)
)
peer2 = GRPCPeerHandle(
"node2",
"localhost:8081",
DeviceCapabilities(model="test1", chip="test1", memory=10000)
)
shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)

async def run_prompt(prompt: str):
@@ -35,11 +40,30 @@ async def run_prompt(prompt: str):
messages, tokenize=False, add_generation_prompt=True
)

await peer.connect()
await peer.reset_shard(shard)
for peer in [peer1, peer2]:
await peer.connect()
await peer.reset_shard(shard)

result = await peer.send_prompt(shard, prompt)
print(tokenizer.decode(result))
try:
await peer1.send_prompt(shard, prompt, "request-id-1")
except Exception as e:
print(e)

import sys
# poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
previous_length = 0
while True:
result, is_finished = await peer2.get_inference_result("request-id-1")
await asyncio.sleep(0.1)

# Print the updated string in place
updated_string = tokenizer.decode(result)
print(updated_string[previous_length:], end='', flush=True)
previous_length = len(updated_string)

if is_finished:
print("\nDone")
break

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run prompt")
@@ -1,16 +1,17 @@
import numpy as np
import mlx.nn as nn

from typing import Tuple
from abc import ABC, abstractmethod
from .shard import Shard

class InferenceEngine(ABC):
@abstractmethod
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
pass

@abstractmethod
async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
pass

@abstractmethod
@@ -99,7 +99,8 @@ class GRPCDiscovery(Discovery):
peer_host = addr[0]
peer_port = message['grpc_port']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
if peer_id not in self.known_peers:
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.peer_last_seen[peer_id] = time.time()
except Exception as e:
print(f"Error in peer discovery: {e}")
@@ -1,6 +1,6 @@
import grpc
import numpy as np
from typing import Optional
from typing import Optional, Tuple

# These would be generated from the .proto file
from . import node_service_pb2
@@ -16,6 +16,8 @@ class GRPCPeerHandle(PeerHandle):
self._id = id
self.address = address
self._device_capabilities = device_capabilities
self.channel = None
self.stub = None

def id(self) -> str:
return self._id
@@ -24,23 +26,30 @@ class GRPCPeerHandle(PeerHandle):
return self._device_capabilities

async def connect(self):
self.channel = grpc.aio.insecure_channel(self.address)
self.channel = grpc.aio.insecure_channel(self.address, options=[
('grpc.max_metadata_size', 32*1024*1024)
])
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)

async def disconnect(self):
await self.channel.close()
async def is_connected(self) -> bool:
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
async def disconnect(self):
if self.channel:
await self.channel.close()
self.channel = None
self.stub = None

async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id)
response = await self.stub.SendPrompt(request)
print(f"Sent prompt to {self.address}: {prompt}")

if not response.tensor_data or not response.shape or not response.dtype:
return None

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
tensor = node_service_pb2.Tensor(
@@ -48,6 +57,7 @@ class GRPCPeerHandle(PeerHandle):
shape=tensor.shape,
dtype=str(tensor.dtype)
),
request_id=request_id
)
response = await self.stub.SendTensor(request)

@@ -56,10 +66,16 @@ class GRPCPeerHandle(PeerHandle):

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
response = await self.stub.GetInferenceResult(request)
if response.tensor is None:
return None, response.is_finished
return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished

async def reset_shard(self, shard: Shard) -> None:
request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
await self.stub.ResetShard(request)
print(f"Reset shard {shard} on {self.address}")

async def collect_topology(self, max_depth: int) -> Topology:
request = node_service_pb2.CollectTopologyRequest(max_depth=max_depth)
@@ -8,6 +8,8 @@ from inference.shard import Shard

from orchestration import Node

import uuid

class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
def __init__(self, node: Node, host: str, port: int):
self.node = node
@@ -17,7 +19,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):

async def start(self) -> None:
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
('grpc.max_metadata_size', 128*1024)
('grpc.max_metadata_size', 32*1024*1024)
])
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
listen_addr = f'{self.host}:{self.port}'
@@ -27,23 +29,33 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):

async def stop(self) -> None:
if self.server:
await self.server.stop(5) # 5 seconds grace period
print("Server stopped")
await self.server.stop(grace=5)
await self.server.wait_for_termination()
print("Server stopped and all connections are closed")

async def SendPrompt(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
prompt = request.prompt
result = await self.node.process_prompt(shard, prompt)
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, request_id)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

async def SendTensor(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
result = await self.node.process_tensor(shard, tensor)
request_id = request.request_id

result = await self.node.process_tensor(shard, tensor, request_id)
print("SendTensor tensor result", result)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

async def GetInferenceResult(self, request, context):
request_id = request.request_id
result = await self.node.get_inference_result(request_id)
tensor_data = result[0].tobytes() if result[0] is not None else None
return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype))) if result[0] is not None else node_service_pb2.InferenceResult()

async def ResetShard(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
@@ -6,6 +6,7 @@ service NodeService {
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc ResetShard (ResetShardRequest) returns (Empty) {}
rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
}

@@ -19,11 +20,22 @@ message Shard {
message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
}

message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
}

message GetInferenceResultRequest {
string request_id = 1;
}

message InferenceResult {
optional Tensor tensor = 1;
bool is_finished = 2;
}

message Tensor {
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"+\n\x16\x43ollectTopologyRequest\x12\x11\n\tmax_depth\x18\x01 \x01(\x05\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"A\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\"\x07\n\x05\x45mpty2\xac\x02\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"+\n\x16\x43ollectTopologyRequest\x12\x11\n\tmax_depth\x18\x01 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"A\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\"\x07\n\x05\x45mpty2\x8c\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -28,27 +28,31 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=121
_globals['_PROMPTREQUEST']._serialized_end=188
_globals['_TENSORREQUEST']._serialized_start=190
_globals['_TENSORREQUEST']._serialized_end=279
_globals['_TENSOR']._serialized_start=281
_globals['_TENSOR']._serialized_end=340
_globals['_RESETSHARDREQUEST']._serialized_start=342
_globals['_RESETSHARDREQUEST']._serialized_end=397
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=399
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=442
_globals['_PEERS']._serialized_start=444
_globals['_PEERS']._serialized_end=469
_globals['_TOPOLOGY']._serialized_start=472
_globals['_TOPOLOGY']._serialized_end=742
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=593
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=671
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=673
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=742
_globals['_DEVICECAPABILITIES']._serialized_start=744
_globals['_DEVICECAPABILITIES']._serialized_end=809
_globals['_EMPTY']._serialized_start=811
_globals['_EMPTY']._serialized_end=818
_globals['_NODESERVICE']._serialized_start=821
_globals['_NODESERVICE']._serialized_end=1121
_globals['_PROMPTREQUEST']._serialized_end=228
_globals['_TENSORREQUEST']._serialized_start=231
_globals['_TENSORREQUEST']._serialized_end=360
_globals['_GETINFERENCERESULTREQUEST']._serialized_start=362
_globals['_GETINFERENCERESULTREQUEST']._serialized_end=409
_globals['_INFERENCERESULT']._serialized_start=411
_globals['_INFERENCERESULT']._serialized_end=503
_globals['_TENSOR']._serialized_start=505
_globals['_TENSOR']._serialized_end=564
_globals['_RESETSHARDREQUEST']._serialized_start=566
_globals['_RESETSHARDREQUEST']._serialized_end=621
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=623
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=666
_globals['_TOPOLOGY']._serialized_start=669
_globals['_TOPOLOGY']._serialized_end=939
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=790
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=868
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=870
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=939
_globals['_PEERS']._serialized_start=941
_globals['_PEERS']._serialized_end=966
_globals['_DEVICECAPABILITIES']._serialized_start=968
_globals['_DEVICECAPABILITIES']._serialized_end=1033
_globals['_EMPTY']._serialized_start=1035
_globals['_EMPTY']._serialized_end=1042
_globals['_NODESERVICE']._serialized_start=1045
_globals['_NODESERVICE']._serialized_end=1441
# @@protoc_insertion_point(module_scope)
@@ -54,6 +54,11 @@ class NodeServiceStub(object):
request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
@@ -82,6 +87,12 @@ class NodeServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetInferenceResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def CollectTopology(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -106,6 +117,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
request_deserializer=node__service__pb2.ResetShardRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
@@ -203,6 +219,33 @@ class NodeService(object):
metadata,
_registered_method=True)

@staticmethod
def GetInferenceResult(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def CollectTopology(request,
target,
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Tuple
import numpy as np
from inference.shard import Shard
from topology.device_capabilities import DeviceCapabilities
@@ -18,16 +18,24 @@ class PeerHandle(ABC):
async def connect(self) -> None:
pass

@abstractmethod
async def is_connected(self) -> bool:
pass

@abstractmethod
async def disconnect(self) -> None:
pass

@abstractmethod
async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
pass

@abstractmethod
async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
pass

@abstractmethod
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
pass

@abstractmethod
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple
import numpy as np
from abc import ABC, abstractmethod
from inference.shard import Shard
@@ -25,5 +25,10 @@ class Node(ABC):
async def reset_shard(self, shard: Shard) -> None:
pass

@abstractmethod
async def collect_topology(self, max_depth: int = 2) -> Topology:
pass

@abstractmethod
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
pass
@@ -1,4 +1,4 @@
from typing import List, Optional, Callable
from typing import List, Dict, Optional, Callable, Tuple
import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
@@ -7,6 +7,8 @@ from topology.topology import Topology
from topology.device_capabilities import device_capabilities
from topology.partitioning_strategy import PartitioningStrategy
from topology.partitioning_strategy import Partition
import asyncio
import uuid

class StandardNode(Node):
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
@@ -18,54 +20,70 @@ class StandardNode(Node):
self.peers: List[PeerHandle] = {}
self.topology: Topology = Topology()
self.device_capabilities = device_capabilities()
self.buffered_token_output: List[int] = []
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
self.on_token = on_token
self.max_generate_tokens = max_generate_tokens

async def start(self, wait_for_peers: int = 0) -> None:
await self.server.start()
await self.discovery.start()
self.peers = await self.discovery.discover_peers(wait_for_peers)
print(f"Starting with the following peers: {self.peers}")
print("Connecting to peers...")
for peer in self.peers:
await peer.connect()
print(f"Connected to {peer.id()}")
await self.update_peers(wait_for_peers)
await self.collect_topology()
print(f"Collected topology: {self.topology}")
asyncio.create_task(self.periodic_topology_collection(5))

async def stop(self) -> None:
await self.discovery.stop()
await self.server.stop()

async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
print("process prompt", shard, prompt)
async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)

print(f"[{request_id}] process prompt: {shard}, {prompt}")
result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)

print(f"result size: {result.size}, is finished: {is_finished}")
if result.size == 1:
self.buffered_token_output.append(result.item())
self.on_token(self.buffered_token_output)
self.buffered_token_output[request_id][0].append(result.item())
self.on_token(self.buffered_token_output[request_id][0])

if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
await self.forward_tensor_to_next_shard(shard, result)
print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

return np.array(self.buffered_token_output) if self.buffered_token_output else None
if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))

async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None

print(f"result size: {result.size}, is finished: {is_finished}")
if result.size == 1:
self.buffered_token_output.append(result.item())
self.on_token(self.buffered_token_output)
async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)

if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
await self.forward_tensor_to_next_shard(shard, result)
try:
print(f"[{request_id}] process_tensor: {shard}, {tensor}")
result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)

return np.array(self.buffered_token_output) if self.buffered_token_output else None
if result.size == 1: # we got a new token out
self.buffered_token_output[request_id][0].append(result.item())
self.on_token(self.buffered_token_output[request_id][0])
print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))

return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
except Exception as e:
import traceback
print(f"Error processing tensor for shard {shard}: {e}")
traceback.print_exc()
return None

async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
if not self.partitioning_strategy:
print("No partitioning strategy found. Skipping forward.")
return
@@ -80,7 +98,7 @@ class StandardNode(Node):

if next_partition:
if next_partition.node_id == self.id:
await self.process_tensor(shard, tensor)
await self.process_tensor(shard, tensor, request_id)
return

target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -91,9 +109,9 @@ class StandardNode(Node):
end_layer = int(next_partition.end * shard.n_layers) - 1
next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)

print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")

await target_peer.send_tensor(next_shard, tensor)
await target_peer.send_tensor(next_shard, tensor, request_id)

def get_current_shard(self, shard: Shard) -> Shard:
partitions = self.partitioning_strategy.partition(self.topology)
@@ -110,9 +128,20 @@ class StandardNode(Node):
async def reset_shard(self, shard: Shard) -> None:
# Implement shard reset logic
print(f"Resetting shard: {shard}")
self.buffered_token_output = []
self.buffered_token_output = {}
await self.inference_engine.reset_shard(self.get_current_shard(shard))

async def update_peers(self, wait_for_peers: int = 0) -> None:
self.peers = await self.discovery.discover_peers(wait_for_peers)
print(f"Starting with the following peers: {self.peers}")
print("Connecting to new peers...")
for peer in self.peers:
is_connected = await peer.is_connected()
print(f"Connected to {peer.id()}: {is_connected}")
if not is_connected:
await peer.connect()
print(f"Connected to peer {peer.id()}")

async def collect_topology(self, max_depth: int = 4) -> Topology:
self.topology.update_node(self.id, self.device_capabilities)

@@ -121,8 +150,28 @@ class StandardNode(Node):
self.topology.add_edge(self.id, peer.id())

if max_depth > 0:
other_topology = await peer.collect_topology(max_depth = max_depth - 1)
print(f"Collected topology from: {peer.id()}: {other_topology}")
self.topology.merge(other_topology)
try:
other_topology = await peer.collect_topology(max_depth = max_depth - 1)
print(f"Collected topology from: {peer.id()}: {other_topology}")
self.topology.merge(other_topology)
except Exception as e:
print(f"Error collecting topology from {peer.id()}: {e}")

return self.topology

async def periodic_topology_collection(self, interval: int):
while True:
await asyncio.sleep(interval)
try:
await self.update_peers()
await self.collect_topology()
except Exception as e:
print(f"Error collecting topology: {e}")

print("Topology collection task executed.")
print(f"Current topology: {self.topology}")

async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
if request_id not in self.buffered_token_output:
return None, False
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
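For reference, the state behind GetInferenceResult is a plain dict keyed by request id (the buffered_token_output introduced above); a simplified standalone sketch of that shape, not the actual class:

from typing import Dict, List, Optional, Tuple
import numpy as np

# request_id -> (tokens generated so far, is_finished)
buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}

def get_inference_result(request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    # Unknown request ids report "no tokens yet, not finished".
    if request_id not in buffered_token_output:
        return None, False
    tokens, is_finished = buffered_token_output[request_id]
    return np.array(tokens), is_finished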