mirror of https://github.com/exo-explore/exo.git

Commit: scaffolding for networking, inference and orchestration

.gitignore (vendored, new file)
__pycache__/
.venv

inference/inference_engine.py (new file)
import numpy as np
import mlx.core as mx
import mlx.nn as nn

from abc import ABC, abstractmethod
from .shard import Shard


class InferenceEngine(ABC):
    @abstractmethod
    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        pass

    @abstractmethod
    async def reset_shard(self, shard: Shard):
        pass


class MLXFixedShardInferenceEngine(InferenceEngine):
    def __init__(self, model: nn.Module, shard: Shard):
        self.model = model
        self.shard = shard

    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        if shard != self.shard:
            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

        # run the wrapped MLX module, converting to and from numpy at the boundary
        output_data = np.array(self.model(mx.array(input_data)))
        print("Processed data through model shard")
        return output_data

    async def reset_shard(self, shard: Shard):
        # TODO: clear any cached state held for this shard
        print(f"Resetting shard: {shard}")

inference/shard.py (new file)
from dataclasses import dataclass


@dataclass
class Shard:
    model_id: str
    n_layers: int
    start_layer: int
    end_layer: int
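
Shard describes a contiguous slice of a model's layers. A minimal sketch of how two shards could partition the 32-layer model built in main.py below (the inclusive end_layer convention is inferred from Shard(..., start_layer=0, end_layer=31) there):

from inference.shard import Shard

# split a 32-layer model across two nodes; end_layer is inclusive
first_half = Shard(model_id="test", n_layers=32, start_layer=0, end_layer=15)
second_half = Shard(model_id="test", n_layers=32, start_layer=16, end_layer=31)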

main.py (new file)
import argparse
import asyncio
import signal
import mlx.nn as nn
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.inference_engine import MLXFixedShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery


class SimpleMLXModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)  # example dimensions

    def __call__(self, x):  # MLX modules are invoked via __call__
        return self.linear(x)


# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
args = parser.parse_args()

mlx_model = SimpleMLXModel()
inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server  # node and server reference each other; wire up after construction


async def shutdown(signal, loop):
    """Gracefully shut down the server and close the asyncio loop."""
    print(f"Received exit signal {signal.name}...")
    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    print(f"Cancelling {len(server_tasks)} outstanding tasks")
    for task in server_tasks:
        task.cancel()
    await asyncio.gather(*server_tasks, return_exceptions=True)
    await server.stop()
    loop.stop()


async def main():
    loop = asyncio.get_running_loop()

    # register signal handlers directly on the running loop
    def handle_exit():
        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

    for s in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(s, handle_exit)

    await node.start()

    await asyncio.sleep(5)
    if node.peers:  # guard: discovery may have found no peers
        print("Sending reset shard request")
        await node.peers[0].reset_shard(f"regards from {node.id}")

    await asyncio.Event().wait()


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.close()
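
A hedged sketch of exercising this scaffolding on one machine, using the flags defined above (the crossed listen/broadcast ports mirror test_grpc_discovery.py; the exact port split is an assumption, not documented usage):

# terminal 1
python main.py --node-id node1 --node-port 8080 --listen-port 5678 --broadcast-port 5679
# terminal 2
python main.py --node-id node2 --node-port 8081 --listen-port 5679 --broadcast-port 5678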

networking/__init__.py (new file)
from .discovery import Discovery
from .peer_handle import PeerHandle
from .server import Server

__all__ = ['Discovery', 'PeerHandle', 'Server']

networking/discovery.py (new file)
from abc import ABC, abstractmethod
from typing import List
from .peer_handle import PeerHandle


class Discovery(ABC):
    @abstractmethod
    async def start(self) -> None:
        pass

    @abstractmethod
    async def stop(self) -> None:
        pass

    @abstractmethod
    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
        pass

networking/grpc/__init__.py (new empty file)

networking/grpc/grpc_discovery.py (new file)
import asyncio
import json
import socket
import time
from typing import List, Dict, Optional
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle


class GRPCDiscovery(Discovery):
    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: Optional[int] = None, broadcast_interval: int = 1):
        self.node_id = node_id
        self.node_port = node_port
        self.listen_port = listen_port
        self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
        self.broadcast_interval = broadcast_interval
        self.known_peers: Dict[str, GRPCPeerHandle] = {}
        self.peer_last_seen: Dict[str, float] = {}
        self.broadcast_task = None
        self.listen_task = None
        self.cleanup_task = None

    async def start(self):
        self.broadcast_task = asyncio.create_task(self._broadcast_presence())
        self.listen_task = asyncio.create_task(self._listen_for_peers())
        self.cleanup_task = asyncio.create_task(self._cleanup_peers())

    async def stop(self):
        # cancel and await only the tasks that were actually created
        tasks = [t for t in (self.broadcast_task, self.listen_task, self.cleanup_task) if t]
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
        print("Starting peer discovery process...")

        if wait_for_peers > 0:
            while not self.known_peers:
                print("No peers discovered yet, retrying in 1 second...")
                await asyncio.sleep(1)  # keep polling until the first peer shows up
            print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")

        grace_period = 5  # seconds
        while True:
            initial_peer_count = len(self.known_peers)
            print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
            await asyncio.sleep(grace_period)
            if len(self.known_peers) == initial_peer_count:
                if wait_for_peers > 0:
                    # wait_for_peers doubles as extra seconds to wait once the peer set is stable
                    print(f"Waiting additional {wait_for_peers} seconds for more peers.")
                    await asyncio.sleep(wait_for_peers)
                    wait_for_peers = 0
                else:
                    print("No new peers discovered in the last grace period. Ending discovery process.")
                    break  # no new peers found in the grace period, we are done

        return list(self.known_peers.values())

    async def _broadcast_presence(self):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.settimeout(0.5)
        message = json.dumps({
            "type": "discovery",
            "node_id": self.node_id,
            "grpc_port": self.node_port
        }).encode('utf-8')

        while True:
            sock.sendto(message, ('<broadcast>', self.broadcast_port))
            await asyncio.sleep(self.broadcast_interval)

    async def _listen_for_peers(self):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.bind(('', self.listen_port))
        sock.setblocking(False)

        while True:
            try:
                data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
                message = json.loads(data.decode('utf-8'))
                if message['type'] == 'discovery' and message['node_id'] != self.node_id:
                    peer_id = message['node_id']
                    peer_host = addr[0]
                    peer_port = message['grpc_port']
                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}")
                    self.peer_last_seen[peer_id] = time.time()
            except Exception as e:
                print(f"Error in peer discovery: {e}")
                await asyncio.sleep(self.broadcast_interval / 2)

    async def _cleanup_peers(self):
        while True:
            current_time = time.time()
            timeout = 5 * self.broadcast_interval
            peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
            for peer_id in peers_to_remove:
                del self.known_peers[peer_id]
                del self.peer_last_seen[peer_id]
                print(f"Removed peer {peer_id} due to inactivity.")
            await asyncio.sleep(self.broadcast_interval)
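
For reference, the discovery datagram is plain UTF-8 JSON. With main.py's default flags, a node1 broadcast would look like this (a sketch of the message _broadcast_presence builds above, not captured traffic):

{"type": "discovery", "node_id": "node1", "grpc_port": 8080}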

networking/grpc/grpc_peer_handle.py (new file)
import grpc
import numpy as np
from typing import Optional

# These are generated from the .proto file
from . import node_service_pb2
from . import node_service_pb2_grpc

from ..peer_handle import PeerHandle


class GRPCPeerHandle(PeerHandle):
    def __init__(self, id: str, address: str):
        self._id = id
        self.address = address

    def id(self) -> str:
        return self._id

    async def connect(self):
        self.channel = grpc.aio.insecure_channel(self.address)
        self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)

    async def disconnect(self):
        await self.channel.close()

    async def send_prompt(self, prompt: str) -> None:
        request = node_service_pb2.PromptRequest(prompt=prompt)
        await self.stub.SendPrompt(request)
        print(f"Sent prompt to {self.address}: {prompt}")

    async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        request = node_service_pb2.TensorRequest(
            tensor_data=tensor.tobytes(),
            shape=tensor.shape,
            dtype=str(tensor.dtype),
            target=target
        )
        await self.stub.SendTensor(request)
        if target:
            print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
        else:
            print(f"Sent tensor to {self.address}: shape {tensor.shape}")

    async def reset_shard(self, shard_id: str) -> None:
        request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
        await self.stub.ResetShard(request)
        print(f"Reset shard {shard_id} on {self.address}")
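
send_tensor above serializes an array as raw bytes plus a shape and a dtype string; GRPCServer.SendTensor in grpc_server.py reverses it. A minimal round-trip sketch of that encoding in plain numpy:

import numpy as np

original = np.arange(6, dtype=np.float32).reshape(2, 3)
# sender side (send_tensor): flatten to bytes, carry shape and dtype alongside
payload, shape, dtype = original.tobytes(), original.shape, str(original.dtype)
# receiver side (GRPCServer.SendTensor): rebuild the array from the three fields
restored = np.frombuffer(payload, dtype=np.dtype(dtype)).reshape(shape)
assert np.array_equal(original, restored)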

networking/grpc/grpc_server.py (new file)
import grpc
from concurrent import futures
import numpy as np

from . import node_service_pb2
from . import node_service_pb2_grpc

from orchestration import Node


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    def __init__(self, node: Node, host: str, port: int):
        self.node = node
        self.host = host
        self.port = port
        self.server = None

    async def start(self) -> None:
        self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
        listen_addr = f'{self.host}:{self.port}'
        self.server.add_insecure_port(listen_addr)
        await self.server.start()
        print(f"Server started, listening on {listen_addr}")

    async def stop(self) -> None:
        if self.server:
            await self.server.stop(5)  # 5 second grace period
            print("Server stopped")

    async def SendPrompt(self, request, context):
        prompt = request.prompt
        target = request.target if request.HasField('target') else None
        if target and target != self.node.id:
            # forward toward the targeted node
            await self.node.process_prompt(prompt, target)
        else:
            # process the prompt locally
            await self.node.process_prompt(prompt)
        return node_service_pb2.Empty()

    async def SendTensor(self, request, context):
        tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
        target = request.target if request.HasField('target') else None
        if target and target != self.node.id:
            await self.node.process_tensor(tensor, target)
        else:
            # process the tensor locally
            await self.node.process_tensor(tensor)
        return node_service_pb2.Empty()

    async def ResetShard(self, request, context):
        print(f"Received ResetShard request: {request}")
        # TODO: implement, e.g. await self.node.reset_shard(request.shard_id)
        return node_service_pb2.Empty()

networking/grpc/node_service.proto (new file)
syntax = "proto3";

package node_service;

service NodeService {
  rpc SendPrompt (PromptRequest) returns (Empty) {}
  rpc SendTensor (TensorRequest) returns (Empty) {}
  rpc ResetShard (ResetShardRequest) returns (Empty) {}
}

message PromptRequest {
  string prompt = 1;
  optional string target = 2;
}

message TensorRequest {
  bytes tensor_data = 1;
  repeated int32 shape = 2;
  string dtype = 3;
  optional string target = 4;
}

message ResetShardRequest {
  string shard_id = 1;
}

message Empty {}
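
The two node_service_pb2 modules below are compiler output from this proto. A sketch of the regeneration command, assuming grpcio-tools is installed and run from networking/grpc/:

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto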

networking/grpc/node_service_pb2.py (new file, generated)
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: node_service.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_PROMPTREQUEST']._serialized_start=36
  _globals['_PROMPTREQUEST']._serialized_end=99
  _globals['_TENSORREQUEST']._serialized_start=101
  _globals['_TENSORREQUEST']._serialized_end=199
  _globals['_RESETSHARDREQUEST']._serialized_start=201
  _globals['_RESETSHARDREQUEST']._serialized_end=238
  _globals['_EMPTY']._serialized_start=240
  _globals['_EMPTY']._serialized_end=247
  _globals['_NODESERVICE']._serialized_start=250
  _globals['_NODESERVICE']._serialized_end=465
# @@protoc_insertion_point(module_scope)

networking/grpc/node_service_pb2_grpc.py (new file, generated)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from . import node_service_pb2 as node__service__pb2

GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in node_service_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        RuntimeWarning
    )


class NodeServiceStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.SendPrompt = channel.unary_unary(
            '/node_service.NodeService/SendPrompt',
            request_serializer=node__service__pb2.PromptRequest.SerializeToString,
            response_deserializer=node__service__pb2.Empty.FromString,
            _registered_method=True)
        self.SendTensor = channel.unary_unary(
            '/node_service.NodeService/SendTensor',
            request_serializer=node__service__pb2.TensorRequest.SerializeToString,
            response_deserializer=node__service__pb2.Empty.FromString,
            _registered_method=True)
        self.ResetShard = channel.unary_unary(
            '/node_service.NodeService/ResetShard',
            request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
            response_deserializer=node__service__pb2.Empty.FromString,
            _registered_method=True)


class NodeServiceServicer(object):
    """Missing associated documentation comment in .proto file."""

    def SendPrompt(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendTensor(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ResetShard(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_NodeServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
        'SendPrompt': grpc.unary_unary_rpc_method_handler(
            servicer.SendPrompt,
            request_deserializer=node__service__pb2.PromptRequest.FromString,
            response_serializer=node__service__pb2.Empty.SerializeToString,
        ),
        'SendTensor': grpc.unary_unary_rpc_method_handler(
            servicer.SendTensor,
            request_deserializer=node__service__pb2.TensorRequest.FromString,
            response_serializer=node__service__pb2.Empty.SerializeToString,
        ),
        'ResetShard': grpc.unary_unary_rpc_method_handler(
            servicer.ResetShard,
            request_deserializer=node__service__pb2.ResetShardRequest.FromString,
            response_serializer=node__service__pb2.Empty.SerializeToString,
        ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
        'node_service.NodeService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class NodeService(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def SendPrompt(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendPrompt',
            node__service__pb2.PromptRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendTensor(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendTensor',
            node__service__pb2.TensorRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def ResetShard(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/ResetShard',
            node__service__pb2.ResetShardRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

networking/grpc/test_grpc_discovery.py (new file)
import asyncio
import unittest
from .grpc_discovery import GRPCDiscovery


class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # crossed listen/broadcast ports so the two nodes hear each other
        self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
        self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
        await self.node1.start()
        await self.node2.start()

    async def asyncTearDown(self):
        await self.node1.stop()
        await self.node2.stop()

    async def test_discovery(self):
        await asyncio.sleep(4)

        # each node should have discovered the other
        print("Node1 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
        print("Node2 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))
        self.assertIn("node2", self.node1.known_peers)
        self.assertIn("node1", self.node2.known_peers)

networking/peer_handle.py (new file)
from abc import ABC, abstractmethod
from typing import Any


class PeerHandle(ABC):
    @abstractmethod
    def id(self) -> str:
        pass

    @abstractmethod
    async def connect(self) -> None:
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        pass

    @abstractmethod
    async def send_prompt(self, prompt: str) -> None:
        pass

    @abstractmethod
    async def send_tensor(self, tensor: Any) -> None:
        pass

    @abstractmethod
    async def reset_shard(self, shard_id: str) -> None:
        pass

networking/server.py (new file)
from abc import ABC, abstractmethod


class Server(ABC):
    @abstractmethod
    async def start(self) -> None:
        pass

    @abstractmethod
    async def stop(self) -> None:
        pass

orchestration/__init__.py (new file)
from .node import Node
from .standard_node import StandardNode

__all__ = ["Node", "StandardNode"]

orchestration/node.py (new file)
from typing import Optional
import numpy as np
from abc import ABC, abstractmethod


class Node(ABC):
    @abstractmethod
    async def start(self) -> None:
        pass

    @abstractmethod
    async def stop(self) -> None:
        pass

    @abstractmethod
    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        pass

    @abstractmethod
    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
        pass

    @abstractmethod
    async def reset_shard(self, shard_id: str) -> None:
        pass

orchestration/standard_node.py (new file)
from typing import List, Optional
import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from .node import Node


class StandardNode(Node):
    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
        self.id = id
        self.inference_engine = inference_engine
        self.server = server
        self.discovery = discovery
        self.peers: List[PeerHandle] = []
        self.ring_order: List[str] = []

    async def start(self) -> None:
        await self.server.start()
        await self.discovery.start()
        self.peers = await self.discovery.discover_peers()
        print(f"Starting with the following peers: {self.peers}")
        print("Connecting to peers...")
        for peer in self.peers:
            await peer.connect()
            print(f"Connected to {peer.id()}")

    async def stop(self) -> None:
        await self.discovery.stop()
        await self.server.stop()

    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        # TODO: align with InferenceEngine.infer_shard
        result = await self.inference_engine.process_shard(tensor)

        if target:
            # look up the target peer by id; self.peers is a list, not a dict
            target_peer = next((p for p in self.peers if p.id() == target), None)
            if target_peer is None:
                raise ValueError(f"Peer {target} not found")
            await target_peer.send_tensor(result)

    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
        # TODO: implement prompt processing / initiate inference here
        print(f"Processing prompt: {prompt}")

    async def reset_shard(self, shard: Shard) -> None:
        print(f"Resetting shard: {shard}")
        await self.inference_engine.reset_shard(shard)

orchestration/test_node.py (new file)
import unittest
from unittest.mock import Mock, AsyncMock
import numpy as np

from .standard_node import StandardNode
from networking.peer_handle import PeerHandle


class TestNode(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.mock_inference_engine = AsyncMock()
        self.mock_server = AsyncMock()
        self.mock_server.start = AsyncMock()
        self.mock_server.stop = AsyncMock()
        self.mock_discovery = AsyncMock()
        self.mock_discovery.start = AsyncMock()
        self.mock_discovery.stop = AsyncMock()
        mock_peer1 = Mock(spec=PeerHandle)
        mock_peer1.id.return_value = "peer1"
        mock_peer2 = Mock(spec=PeerHandle)
        mock_peer2.id.return_value = "peer2"
        self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])

        self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, self.mock_discovery)

    async def asyncSetUp(self):
        await self.node.start()

    async def asyncTearDown(self):
        await self.node.stop()

    async def test_node_initialization(self):
        self.assertEqual(self.node.id, "test_node")

    async def test_node_start(self):
        self.mock_server.start.assert_called_once()

    async def test_node_stop(self):
        await self.node.stop()
        self.mock_server.stop.assert_called_once()

    async def test_discover_and_connect_to_peers(self):
        # start() in asyncSetUp already discovered and connected the peers
        self.assertEqual(len(self.node.peers), 2)
        self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
        self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))

    async def test_process_tensor_calls_inference_engine(self):
        mock_peer = Mock()
        self.node.peers = [mock_peer]

        input_tensor = np.array([69, 1, 2])
        await self.node.process_tensor(input_tensor, None)

        self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)