Mirror of https://github.com/exo-explore/exo.git, synced 2025-10-23 02:57:14 +03:00
@@ -29,15 +29,16 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
     self.channel_options = [
-      ("grpc.max_metadata_size", 64 * 1024 * 1024),
+      ("grpc.max_metadata_size", 32 * 1024 * 1024),
       ("grpc.max_receive_message_length", 256 * 1024 * 1024),
       ("grpc.max_send_message_length", 256 * 1024 * 1024),
       ("grpc.max_concurrent_streams", 100),
-      ("grpc.keepalive_time_ms", 20000),
-      ("grpc.keepalive_timeout_ms", 10000),
+      ("grpc.http2.min_time_between_pings_ms", 10000),
+      ("grpc.keepalive_time_ms", 10000),
+      ("grpc.keepalive_timeout_ms", 5000),
       ("grpc.keepalive_permit_without_calls", 1),
       ("grpc.http2.max_pings_without_data", 0),
       ("grpc.http2.min_ping_interval_without_data_ms", 5000),
       ("grpc.tcp_nodelay", 1),
       ("grpc.optimization_target", "throughput"),
     ]
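Note on the keepalive changes above: halving grpc.keepalive_time_ms (20 s to 10 s) and grpc.keepalive_timeout_ms (10 s to 5 s) lets a peer detect a dead transport in roughly 15 s instead of 30 s, while the added grpc.http2.min_time_between_pings_ms keeps the more aggressive ping rate within policy. A minimal client-side sketch of how these options are applied (the peer address is hypothetical; the option names are standard gRPC channel arguments):

import grpc

channel = grpc.aio.insecure_channel(
  "10.0.0.2:50051",  # hypothetical peer address
  options=[
    ("grpc.keepalive_time_ms", 10000),           # send a keepalive ping every 10 s
    ("grpc.keepalive_timeout_ms", 5000),         # consider the link dead if no ack within 5 s
    ("grpc.keepalive_permit_without_calls", 1),  # keep pinging even when no RPC is active
  ],
  compression=grpc.Compression.Gzip,
)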
@@ -55,14 +56,13 @@ class GRPCPeerHandle(PeerHandle):
     return self._device_capabilities

   async def connect(self):
-    if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(
-        self.address,
-        options=self.channel_options,
-        compression=grpc.Compression.Gzip
-      )
-      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-    await self.channel.channel_ready()
+    self.channel = grpc.aio.insecure_channel(
+      self.address,
+      options=self.channel_options,
+      compression=grpc.Compression.Gzip
+    )
+    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)

   async def is_connected(self) -> bool:
     return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
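Note: connect() now rebuilds the channel and stub unconditionally, so a channel stuck in TRANSIENT_FAILURE is replaced rather than reused, and the readiness wait is bounded at 10 s instead of potentially blocking forever. One possible refinement (hypothetical, not part of this diff) is to close the previous channel before replacing it so its transport is released promptly:

async def connect(self):
  if self.channel is not None:
    await self.channel.close()  # explicitly release the old transport
  self.channel = grpc.aio.insecure_channel(
    self.address,
    options=self.channel_options,
    compression=grpc.Compression.Gzip
  )
  self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
  await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)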
@@ -74,7 +74,7 @@ class GRPCPeerHandle(PeerHandle):
     self.stub = None

   async def _ensure_connected(self):
-    if not await self.is_connected():
+    if not (await self.is_connected()):
       try:
         await asyncio.wait_for(self.connect(), timeout=10.0)
       except asyncio.TimeoutError:
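Note: the added parentheses in "if not (await self.is_connected()):" only make the precedence explicit; "not await x" already parses as "not (await x)". The hunk truncates the except branch; a plausible reading (an assumption, not shown in the diff) is that the timeout is surfaced to the caller:

async def _ensure_connected(self):
  if not (await self.is_connected()):
    try:
      await asyncio.wait_for(self.connect(), timeout=10.0)
    except asyncio.TimeoutError:
      # Assumed handling: report which peer could not be reached.
      raise TimeoutError(f"Timed out connecting to {self.address}")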
@@ -98,6 +98,7 @@ class GRPCPeerHandle(PeerHandle):
     return False

   async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
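The same one-line guard is prepended to every RPC wrapper in the hunks that follow (send_tensor, send_example, send_loss, collect_topology, send_result, send_opaque_status). A decorator could factor out that repetition (hypothetical refactor, not part of this diff):

import functools

def ensure_connected(method):
  # Wrap an async RPC method so it reconnects before sending.
  @functools.wraps(method)
  async def wrapper(self, *args, **kwargs):
    await self._ensure_connected()
    return await method(self, *args, **kwargs)
  return wrapper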
@@ -112,6 +113,7 @@ class GRPCPeerHandle(PeerHandle):
     await self.stub.SendPrompt(request)

   async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -131,6 +133,7 @@ class GRPCPeerHandle(PeerHandle):
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -153,6 +156,7 @@ class GRPCPeerHandle(PeerHandle):
     return loss

   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -171,6 +175,7 @@ class GRPCPeerHandle(PeerHandle):
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    await self._ensure_connected()
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
@@ -185,6 +190,7 @@ class GRPCPeerHandle(PeerHandle):
     return topology

   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    await self._ensure_connected()
     tensor = None
     if isinstance(result, np.ndarray):
       tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
@@ -193,8 +199,9 @@ class GRPCPeerHandle(PeerHandle):
     await self.stub.SendResult(request)

   async def send_opaque_status(self, request_id: str, status: str) -> None:
+    await self._ensure_connected()
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-    await self.stub.SendOpaqueStatus(request)
+    await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)

   def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
     proto_inference_state = node_service_pb2.InferenceState()
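Note: asyncio.wait_for cancels the call from the client side after 10 s. The same bound can be expressed as a gRPC-native per-call deadline, which also propagates to the server (an alternative, not what this diff uses):

await self.stub.SendOpaqueStatus(request, timeout=10.0)  # gRPC deadline instead of an asyncio timeout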
@@ -40,6 +40,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         ("grpc.max_concurrent_streams", 100),
         ("grpc.tcp_nodelay", 1),
         ("grpc.optimization_target", "throughput"),
+        ("grpc.keepalive_permit_without_calls", 1),
+        ("grpc.http2.max_concurrent_streams", 0),  # Unlimited concurrent streams
       ],
     )
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
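Note on the server options: ("grpc.keepalive_permit_without_calls", 1) makes the server accept the clients' idle-channel keepalive pings rather than rejecting them, and ("grpc.http2.max_concurrent_streams", 0) lifts the per-connection stream cap. A sketch of the server construction these options belong to, assuming the grpc.aio API used by this file (the executor size is hypothetical):

import grpc
from concurrent import futures

server = grpc.aio.server(
  futures.ThreadPoolExecutor(max_workers=10),  # hypothetical worker count
  options=[
    ("grpc.keepalive_permit_without_calls", 1),  # ack keepalive pings between RPCs
    ("grpc.http2.max_concurrent_streams", 0),    # 0 here means no per-connection stream limit
  ],
)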