mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
dont broadcast every single process_tensor
This commit is contained in:
@@ -156,6 +156,7 @@ class Node:
|
||||
request_id: Optional[str] = None,
|
||||
) -> None:
|
||||
shard = self.get_current_shard(base_shard)
|
||||
start_time = time.perf_counter_ns()
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
@@ -170,7 +171,6 @@ class Node:
|
||||
}),
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
await self._process_prompt(base_shard, prompt, request_id)
|
||||
end_time = time.perf_counter_ns()
|
||||
elapsed_time_ns = end_time - start_time
|
||||
@@ -351,39 +351,11 @@ class Node:
|
||||
request_id: Optional[str] = None,
|
||||
) -> None:
|
||||
shard = self.get_current_shard(base_shard)
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"tensor_size": tensor.size,
|
||||
"tensor_shape": tensor.shape,
|
||||
"request_id": request_id,
|
||||
}),
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
await self._process_tensor(shard, tensor, request_id)
|
||||
end_time = time.perf_counter_ns()
|
||||
elapsed_time_ns = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
}),
|
||||
)
|
||||
)
|
||||
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
|
||||
|
||||
async def _process_tensor(
|
||||
self,
|
||||
@@ -395,7 +367,6 @@ class Node:
|
||||
request_id = str(uuid.uuid4())
|
||||
shard = self.get_current_shard(base_shard)
|
||||
|
||||
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
|
||||
try:
|
||||
self.outstanding_requests[request_id] = "processing"
|
||||
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
|
||||
|
||||
Reference in New Issue
Block a user