Don't broadcast every single process_tensor

This commit is contained in:
Alex Cheema
2024-12-16 20:54:38 +00:00
parent 35d90d947c
commit b17faa8199

View File

@@ -156,6 +156,7 @@ class Node:
request_id: Optional[str] = None,
) -> None:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
@@ -170,7 +171,6 @@ class Node:
}),
)
)
start_time = time.perf_counter_ns()
await self._process_prompt(base_shard, prompt, request_id)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
@@ -351,39 +351,11 @@ class Node:
request_id: Optional[str] = None,
) -> None:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "start_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"tensor_size": tensor.size,
"tensor_shape": tensor.shape,
"request_id": request_id,
}),
)
)
start_time = time.perf_counter_ns()
await self._process_tensor(shard, tensor, request_id)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "end_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
}),
)
)
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
async def _process_tensor(
self,
@@ -395,7 +367,6 @@ class Node:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)