adjust grpc settings, ensure connected before sending any grpc commands

This commit is contained in:
Alex Cheema
2025-02-28 20:52:12 +00:00
parent 36a6389af0
commit 4081305e60
2 changed files with 22 additions and 13 deletions

View File

@@ -29,15 +29,16 @@ class GRPCPeerHandle(PeerHandle):
self.channel = None
self.stub = None
self.channel_options = [
("grpc.max_metadata_size", 64 * 1024 * 1024),
("grpc.max_metadata_size", 32 * 1024 * 1024),
("grpc.max_receive_message_length", 256 * 1024 * 1024),
("grpc.max_send_message_length", 256 * 1024 * 1024),
("grpc.max_concurrent_streams", 100),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.keepalive_time_ms", 20000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_time_ms", 10000),
("grpc.keepalive_timeout_ms", 5000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_pings_without_data", 0),
("grpc.http2.min_ping_interval_without_data_ms", 5000),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
]
@@ -55,14 +56,13 @@ class GRPCPeerHandle(PeerHandle):
return self._device_capabilities
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)
async def is_connected(self) -> bool:
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
@@ -74,7 +74,7 @@ class GRPCPeerHandle(PeerHandle):
self.stub = None
async def _ensure_connected(self):
if not await self.is_connected():
if not (await self.is_connected()):
try:
await asyncio.wait_for(self.connect(), timeout=10.0)
except asyncio.TimeoutError:
@@ -98,6 +98,7 @@ class GRPCPeerHandle(PeerHandle):
return False
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -112,6 +113,7 @@ class GRPCPeerHandle(PeerHandle):
await self.stub.SendPrompt(request)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -131,6 +133,7 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -153,6 +156,7 @@ class GRPCPeerHandle(PeerHandle):
return loss
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -171,6 +175,7 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
await self._ensure_connected()
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
response = await self.stub.CollectTopology(request)
topology = Topology()
@@ -185,6 +190,7 @@ class GRPCPeerHandle(PeerHandle):
return topology
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
await self._ensure_connected()
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
@@ -193,8 +199,9 @@ class GRPCPeerHandle(PeerHandle):
await self.stub.SendResult(request)
async def send_opaque_status(self, request_id: str, status: str) -> None:
await self._ensure_connected()
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()

View File

@@ -40,6 +40,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
("grpc.max_concurrent_streams", 100),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_concurrent_streams", 0), # Unlimited concurrent streams
],
)
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)