mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
fix inference engine for inference state
This commit is contained in:
@@ -43,9 +43,11 @@ class InferenceEngine(ABC):
|
||||
tokens = await self.encode(shard, prompt)
|
||||
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||
x = tokens.reshape(1, -1)
|
||||
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||
else:
|
||||
x = tokens
|
||||
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||
|
||||
return output_data, inference_state
|
||||
|
||||
inference_engine_classes = {
|
||||
|
||||
@@ -82,7 +82,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
loop = asyncio.get_running_loop()
|
||||
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
||||
x = mx.array(input_data)
|
||||
output_data,inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
|
||||
if self.model.model_type != 'StableDiffusionPipeline':
|
||||
output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
|
||||
else:
|
||||
output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
|
||||
output_data = np.array(output_data)
|
||||
return output_data, inference_state
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,7 @@
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from exo.networking.grpc import node_service_pb2 as node__service__pb2
|
||||
from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.68.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
@@ -18,7 +18,7 @@ except ImportError:
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in node_service_pb2_grpc.py depends on'
|
||||
+ f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
|
||||
"""
|
||||
self.SendPrompt = channel.unary_unary(
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTensor = channel.unary_unary(
|
||||
'/node_service.NodeService/SendTensor',
|
||||
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.SendExample = channel.unary_unary(
|
||||
'/node_service.NodeService/SendExample',
|
||||
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Loss.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
|
||||
_registered_method=True)
|
||||
self.GetInferenceResult = channel.unary_unary(
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.InferenceResult.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
|
||||
_registered_method=True)
|
||||
self.CollectTopology = channel.unary_unary(
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Topology.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
|
||||
_registered_method=True)
|
||||
self.SendResult = channel.unary_unary(
|
||||
'/node_service.NodeService/SendResult',
|
||||
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendOpaqueStatus = channel.unary_unary(
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.HealthCheck = channel.unary_unary(
|
||||
'/node_service.NodeService/HealthCheck',
|
||||
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
|
||||
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendPrompt': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPrompt,
|
||||
request_deserializer=node__service__pb2.PromptRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendTensor': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTensor,
|
||||
request_deserializer=node__service__pb2.TensorRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendExample': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendExample,
|
||||
request_deserializer=node__service__pb2.ExampleRequest.FromString,
|
||||
response_serializer=node__service__pb2.Loss.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
|
||||
),
|
||||
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetInferenceResult,
|
||||
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
|
||||
),
|
||||
'CollectTopology': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectTopology,
|
||||
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
|
||||
response_serializer=node__service__pb2.Topology.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
|
||||
),
|
||||
'SendResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendResult,
|
||||
request_deserializer=node__service__pb2.SendResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendOpaqueStatus,
|
||||
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'HealthCheck': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.HealthCheck,
|
||||
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
|
||||
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
@@ -196,8 +196,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
node__service__pb2.PromptRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -223,8 +223,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendTensor',
|
||||
node__service__pb2.TensorRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -250,8 +250,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendExample',
|
||||
node__service__pb2.ExampleRequest.SerializeToString,
|
||||
node__service__pb2.Loss.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -277,8 +277,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
node__service__pb2.InferenceResult.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -304,8 +304,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
node__service__pb2.Topology.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -331,8 +331,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendResult',
|
||||
node__service__pb2.SendResultRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -358,8 +358,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
@@ -385,8 +385,8 @@ class NodeService(object):
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/HealthCheck',
|
||||
node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
node__service__pb2.HealthCheckResponse.FromString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
|
||||
Reference in New Issue
Block a user