fix inference engine for inference state

This commit is contained in:
Pranav Veldurthi
2024-12-30 18:36:53 -05:00
parent 54605299b8
commit fff8a1a690
4 changed files with 113 additions and 100 deletions

View File

@@ -43,9 +43,11 @@ class InferenceEngine(ABC):
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
else:
x = tokens
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
return output_data, inference_state
inference_engine_classes = {

View File

@@ -82,7 +82,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
output_data,inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
if self.model.model_type != 'StableDiffusionPipeline':
output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
else:
output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
output_data = np.array(output_data)
return output_data, inference_state

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from exo.networking.grpc import node_service_pb2 as node__service__pb2
from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
GRPC_GENERATED_VERSION = '1.68.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=node__service__pb2.Loss.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.HealthCheck = channel.unary_unary(
'/node_service.NodeService/HealthCheck',
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
_registered_method=True)
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
request_deserializer=node__service__pb2.ExampleRequest.FromString,
response_serializer=node__service__pb2.Loss.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
),
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +196,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -223,8 +223,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -250,8 +250,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendExample',
node__service__pb2.ExampleRequest.SerializeToString,
node__service__pb2.Loss.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
options,
channel_credentials,
insecure,
@@ -277,8 +277,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
@@ -304,8 +304,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/CollectTopology',
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
@@ -331,8 +331,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -358,8 +358,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -385,8 +385,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/HealthCheck',
node__service__pb2.HealthCheckRequest.SerializeToString,
node__service__pb2.HealthCheckResponse.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,