scaffolding for networking, inference and orchestration

Alex Cheema
2024-06-23 23:28:10 +01:00
commit a21f59ff45
20 changed files with 779 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
__pycache__/
.venv

31
inference/inference_engine.py Normal file
View File

@@ -0,0 +1,31 @@
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from abc import ABC, abstractmethod
from .shard import Shard
class InferenceEngine(ABC):
@abstractmethod
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
pass
@abstractmethod
async def reset_shard(self, shard: Shard):
pass
class MLXFixedShardInferenceEngine(InferenceEngine):
def __init__(self, model: nn.Module, shard: Shard):
self.model = model
self.shard = shard
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
        output_data = np.array(self.model(mx.array(input_data)))  # nn.Module has no .process(); call the module directly on an mx.array
print("Processed data through model shard")
return output_data
async def reset_shard(self, shard: Shard):
# TODO
print(f"Resetting shard: {shard}")

8
inference/shard.py Normal file
View File

@@ -0,0 +1,8 @@
from dataclasses import dataclass
@dataclass
class Shard:
model_id: str
n_layers: int
start_layer: int
end_layer: int
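
A sketch of the intent as I read these fields (the split below is hypothetical, not part of this commit): a 32-layer model spread over two nodes would be described by two Shard values with disjoint, inclusive layer ranges:

    first_half = Shard(model_id="test", n_layers=32, start_layer=0, end_layer=15)
    second_half = Shard(model_id="test", n_layers=32, start_layer=16, end_layer=31)
    # end_layer looks inclusive: main.py covers all 32 layers with start_layer=0, end_layer=31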

71
main.py Normal file
View File

@@ -0,0 +1,71 @@
import argparse
import asyncio
import signal
import mlx.core as mx
import mlx.nn as nn
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.inference_engine import MLXFixedShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery
class SimpleMLXModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)  # Example dimensions
    def __call__(self, x):
        # MLX modules are invoked via __call__ rather than a forward() method
        return self.linear(x)
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
args = parser.parse_args()
mlx_model = SimpleMLXModel()
inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    print(f"Cancelling {len(server_tasks)} outstanding tasks")
    for task in server_tasks:
        task.cancel()
await asyncio.gather(*server_tasks, return_exceptions=True)
    await server.stop()
loop.stop()
async def main():
loop = asyncio.get_running_loop()
    # Install handlers on the loop so each signal reports itself to shutdown()
    for s in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(s, lambda s=s: asyncio.ensure_future(shutdown(s, loop)))
await node.start()
    await asyncio.sleep(5)
    if node.peers:
        print("Sending reset shard request")
        await node.peers[0].reset_shard(f"regards from {node.id}")
await asyncio.Event().wait()
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.close()
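
To watch discovery work, two instances can be run on one machine with mirrored discovery ports, the same pairing the discovery test below uses (a hypothetical invocation assembled from the flags above):

    python main.py --node-id node1 --node-port 8080 --listen-port 5678 --broadcast-port 5679
    python main.py --node-id node2 --node-port 8081 --listen-port 5679 --broadcast-port 5678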

5
networking/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .discovery import Discovery
from .peer_handle import PeerHandle
from .server import Server
__all__ = ['Discovery', 'PeerHandle', 'Server']

16
networking/discovery.py Normal file
View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
from .peer_handle import PeerHandle
class Discovery(ABC):
@abstractmethod
async def start(self) -> None:
pass
@abstractmethod
async def stop(self) -> None:
pass
@abstractmethod
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
pass

0
networking/grpc/__init__.py Normal file
View File

105
networking/grpc/grpc_discovery.py Normal file
View File

@@ -0,0 +1,105 @@
import asyncio
import json
import socket
import time
from typing import List, Dict
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1):
self.node_id = node_id
self.node_port = node_port
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, GRPCPeerHandle] = {}
self.peer_last_seen: Dict[str, float] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
async def start(self):
self.broadcast_task = asyncio.create_task(self._broadcast_presence())
self.listen_task = asyncio.create_task(self._listen_for_peers())
self.cleanup_task = asyncio.create_task(self._cleanup_peers())
    async def stop(self):
        tasks = [t for t in (self.broadcast_task, self.listen_task, self.cleanup_task) if t is not None]
        for t in tasks:
            t.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
print("Starting peer discovery process...")
if wait_for_peers > 0:
while not self.known_peers:
print("No peers discovered yet, retrying in 1 second...")
await asyncio.sleep(1) # Keep trying to find peers
print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
grace_period = 5 # seconds
while True:
initial_peer_count = len(self.known_peers)
print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
await asyncio.sleep(grace_period)
if len(self.known_peers) == initial_peer_count:
if wait_for_peers > 0:
print(f"Waiting additional {wait_for_peers} seconds for more peers.")
await asyncio.sleep(wait_for_peers)
wait_for_peers = 0
else:
print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done
return list(self.known_peers.values())
async def _broadcast_presence(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.settimeout(0.5)
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port
}).encode('utf-8')
while True:
sock.sendto(message, ('<broadcast>', self.broadcast_port))
await asyncio.sleep(self.broadcast_interval)
async def _listen_for_peers(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('', self.listen_port))
sock.setblocking(False)
while True:
try:
                data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)  # loop.sock_recvfrom requires Python 3.11+
message = json.loads(data.decode('utf-8'))
if message['type'] == 'discovery' and message['node_id'] != self.node_id:
peer_id = message['node_id']
peer_host = addr[0]
peer_port = message['grpc_port']
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}")
self.peer_last_seen[peer_id] = time.time()
except Exception as e:
print(f"Error in peer discovery: {e}")
await asyncio.sleep(self.broadcast_interval / 2)
async def _cleanup_peers(self):
while True:
current_time = time.time()
timeout = 5 * self.broadcast_interval
peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
for peer_id in peers_to_remove:
del self.known_peers[peer_id]
del self.peer_last_seen[peer_id]
print(f"Removed peer {peer_id} due to inactivity.")
await asyncio.sleep(self.broadcast_interval)
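
On the wire, discovery is one self-describing JSON datagram per broadcast interval; a receiver learns the sender's gRPC port from the payload and its host from the UDP source address:

    {"type": "discovery", "node_id": "node1", "grpc_port": 8080}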

47
networking/grpc/grpc_peer_handle.py Normal file
View File

@@ -0,0 +1,47 @@
import grpc
import numpy as np
from typing import Optional
# Generated from node_service.proto by the protobuf / gRPC compilers
from . import node_service_pb2
from . import node_service_pb2_grpc
from ..peer_handle import PeerHandle
class GRPCPeerHandle(PeerHandle):
    def __init__(self, id: str, address: str):
        self._id = id
        self.address = address
        self.channel = None  # set by connect()
        self.stub = None
def id(self) -> str:
return self._id
async def connect(self):
self.channel = grpc.aio.insecure_channel(self.address)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
async def disconnect(self):
await self.channel.close()
async def send_prompt(self, prompt: str) -> None:
request = node_service_pb2.PromptRequest(prompt=prompt)
await self.stub.SendPrompt(request)
print(f"Sent prompt to {self.address}: {prompt}")
async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
request = node_service_pb2.TensorRequest(
tensor_data=tensor.tobytes(),
shape=tensor.shape,
dtype=str(tensor.dtype),
target=target
)
await self.stub.SendTensor(request)
if target:
print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
else:
print(f"Sent tensor to {self.address}: shape {tensor.shape}")
async def reset_shard(self, shard_id: str) -> None:
request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
await self.stub.ResetShard(request)
print(f"Reset shard {shard_id} on {self.address}")

57
networking/grpc/grpc_server.py Normal file
View File

@@ -0,0 +1,57 @@
import grpc
from concurrent import futures
import numpy as np
from . import node_service_pb2
from . import node_service_pb2_grpc
from orchestration import Node
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
def __init__(self, node: Node, host: str, port: int):
self.node = node
self.host = host
self.port = port
self.server = None
async def start(self) -> None:
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
listen_addr = f'{self.host}:{self.port}'
self.server.add_insecure_port(listen_addr)
await self.server.start()
print(f"Server started, listening on {listen_addr}")
async def stop(self) -> None:
if self.server:
await self.server.stop(5) # 5 seconds grace period
print("Server stopped")
    async def SendPrompt(self, request, context):
        prompt = request.prompt
        target = request.target if request.HasField('target') else None
        if target and target != self.node.id:
            await self.node.process_prompt(prompt, target)
        else:
            # Process the prompt locally
            await self.node.process_prompt(prompt)
        return node_service_pb2.Empty()
    async def SendTensor(self, request, context):
        tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
        target = request.target if request.HasField('target') else None
        if target and target != self.node.id:
            await self.node.process_tensor(tensor, target)
        else:
            # Process the tensor locally
            await self.node.process_tensor(tensor)
        return node_service_pb2.Empty()
    async def ResetShard(self, request, context):
        print(f"Received ResetShard request: {request}")
        # TODO: wire through to the node once shard lookup by id exists,
        # e.g. await self.node.reset_shard(request.shard_id)
        return node_service_pb2.Empty()
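
The tensor round trip depends on dtype and shape travelling alongside the raw bytes; the symmetry between GRPCPeerHandle.send_tensor and SendTensor above boils down to:

    original = np.arange(6, dtype=np.float32).reshape(2, 3)
    data, shape, dtype = original.tobytes(), original.shape, str(original.dtype)
    restored = np.frombuffer(data, dtype=np.dtype(dtype)).reshape(shape)
    assert (restored == original).all()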

27
networking/grpc/node_service.proto Normal file
View File

@@ -0,0 +1,27 @@
syntax = "proto3";
package node_service;
service NodeService {
rpc SendPrompt (PromptRequest) returns (Empty) {}
rpc SendTensor (TensorRequest) returns (Empty) {}
rpc ResetShard (ResetShardRequest) returns (Empty) {}
}
message PromptRequest {
string prompt = 1;
optional string target = 2;
}
message TensorRequest {
bytes tensor_data = 1;
repeated int32 shape = 2;
string dtype = 3;
optional string target = 4;
}
message ResetShardRequest {
string shard_id = 1;
}
message Empty {}
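
The two node_service_pb2*.py files below are compiler output; they can be regenerated from this definition with grpcio-tools (one plausible invocation, run from networking/grpc):

    from grpc_tools import protoc
    protoc.main([
        'protoc', '-I.', '--python_out=.', '--grpc_python_out=.',
        'node_service.proto',
    ])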

34
networking/grpc/node_service_pb2.py Normal file
View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: node_service.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_PROMPTREQUEST']._serialized_start=36
_globals['_PROMPTREQUEST']._serialized_end=99
_globals['_TENSORREQUEST']._serialized_start=101
_globals['_TENSORREQUEST']._serialized_end=199
_globals['_RESETSHARDREQUEST']._serialized_start=201
_globals['_RESETSHARDREQUEST']._serialized_end=238
_globals['_EMPTY']._serialized_start=240
_globals['_EMPTY']._serialized_end=247
_globals['_NODESERVICE']._serialized_start=250
_globals['_NODESERVICE']._serialized_end=465
# @@protoc_insertion_point(module_scope)

188
networking/grpc/node_service_pb2_grpc.py Normal file
View File

@@ -0,0 +1,188 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import node_service_pb2 as node__service__pb2
GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
warnings.warn(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
RuntimeWarning
)
class NodeServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.ResetShard = channel.unary_unary(
'/node_service.NodeService/ResetShard',
request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
class NodeServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def SendPrompt(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTensor(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ResetShard(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'ResetShard': grpc.unary_unary_rpc_method_handler(
servicer.ResetShard,
request_deserializer=node__service__pb2.ResetShardRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'node_service.NodeService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class NodeService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def SendPrompt(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendTensor(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def ResetShard(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/ResetShard',
node__service__pb2.ResetShardRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

21
networking/grpc/test_grpc_discovery.py Normal file
View File

@@ -0,0 +1,21 @@
import asyncio
import unittest
from .grpc_discovery import GRPCDiscovery
class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
await self.node1.start()
await self.node2.start()
async def asyncTearDown(self):
await self.node1.stop()
await self.node2.stop()
    async def test_discovery(self):
        await asyncio.sleep(4)
        # Check discovered peers: each node should have seen the other's broadcasts
        print("Node1 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
        print("Node2 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))
        self.assertIn("node2", self.node1.known_peers)
        self.assertIn("node1", self.node2.known_peers)

26
networking/peer_handle.py Normal file
View File

@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any
class PeerHandle(ABC):
    @abstractmethod
    def id(self) -> str:
        pass
@abstractmethod
async def connect(self) -> None:
pass
@abstractmethod
async def disconnect(self) -> None:
pass
@abstractmethod
async def send_prompt(self, prompt: str) -> None:
pass
@abstractmethod
async def send_tensor(self, tensor: Any) -> None:
pass
@abstractmethod
async def reset_shard(self, shard_id: str) -> None:
pass

10
networking/server.py Normal file
View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
class Server(ABC):
@abstractmethod
async def start(self) -> None:
pass
@abstractmethod
async def stop(self) -> None:
pass

4
orchestration/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .node import Node
from .standard_node import StandardNode
__all__ = ["Node", "StandardNode"]

24
orchestration/node.py Normal file
View File

@@ -0,0 +1,24 @@
from typing import Optional
import numpy as np
from abc import ABC, abstractmethod
class Node(ABC):
    @abstractmethod
    async def start(self) -> None:
        pass
    @abstractmethod
    async def stop(self) -> None:
        pass
    @abstractmethod
    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        pass
    @abstractmethod
    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
        pass
    @abstractmethod
    async def reset_shard(self, shard_id: str) -> None:
        pass

47
orchestration/standard_node.py Normal file
View File

@@ -0,0 +1,47 @@
from typing import List, Optional
import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from .node import Node
class StandardNode(Node):
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
self.id = id
self.inference_engine = inference_engine
self.server = server
self.discovery = discovery
        self.peers: List[PeerHandle] = []
self.ring_order: List[str] = []
async def start(self) -> None:
await self.server.start()
await self.discovery.start()
self.peers = await self.discovery.discover_peers()
print(f"Starting with the following peers: {self.peers}")
print("Connecting to peers...")
for peer in self.peers:
await peer.connect()
print(f"Connected to {peer.id()}")
async def stop(self) -> None:
await self.discovery.stop()
await self.server.stop()
    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        result = await self.inference_engine.process_shard(tensor)  # TODO: align with InferenceEngine.infer_shard
        if target:
            # filter() always returns a truthy iterator, so resolve the peer explicitly
            peer = next((p for p in self.peers if p.id() == target), None)
            if peer is None:
                raise ValueError(f"Peer {target} not found")
            await peer.send_tensor(result)
    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
        # Placeholder: prompt processing / inference would be initiated from here
        print(f"Processing prompt: {prompt}")
async def reset_shard(self, shard: Shard) -> None:
# Implement shard reset logic
print(f"Resetting shard: {shard}")
await self.inference_engine.reset_shard(shard)
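
A sketch of the forwarding path (the peer id and tensor here are hypothetical): once start() has connected the node, a tensor can be run through the local engine and the result handed to a named peer, while omitting target keeps the result local:

    await node.process_tensor(np.random.randn(1, 10).astype(np.float32), target="node2")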

56
orchestration/test_node.py Normal file
View File

@@ -0,0 +1,56 @@
import unittest
from unittest.mock import Mock, AsyncMock
import numpy as np
from .standard_node import StandardNode
from networking.peer_handle import PeerHandle
class TestNode(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_inference_engine = AsyncMock()
self.mock_server = AsyncMock()
self.mock_server.start = AsyncMock()
self.mock_server.stop = AsyncMock()
self.mock_discovery = AsyncMock()
self.mock_discovery.start = AsyncMock()
self.mock_discovery.stop = AsyncMock()
mock_peer1 = Mock(spec=PeerHandle)
mock_peer1.id.return_value = "peer1"
mock_peer2 = Mock(spec=PeerHandle)
mock_peer2.id.return_value = "peer2"
self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
        self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, self.mock_discovery)
async def asyncSetUp(self):
await self.node.start()
async def asyncTearDown(self):
await self.node.stop()
    async def test_node_initialization(self):
        self.assertEqual(self.node.id, "test_node")
    async def test_node_start(self):
        self.mock_server.start.assert_called_once()
async def test_node_stop(self):
await self.node.stop()
self.mock_server.stop.assert_called_once()
    async def test_discover_and_connect_to_peers(self):
        # start() in asyncSetUp already ran discovery and connected the peers
self.assertEqual(len(self.node.peers), 2)
self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
async def test_process_tensor_calls_inference_engine(self):
mock_peer = Mock()
self.node.peers = [mock_peer]
input_tensor = np.array([69, 1, 2])
await self.node.process_tensor(input_tensor, None)
self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)