Formatting

This commit is contained in:
Sandesh Bharadwaj
2025-01-17 05:40:42 -05:00
parent 5f06aa2759
commit b9eccedc3d
6 changed files with 186 additions and 243 deletions

View File

@@ -34,6 +34,7 @@ import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
@@ -47,7 +48,6 @@ class Message:
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
@@ -138,11 +138,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {
"conversation": [m.to_dict() for m in messages],
"tokenize": False,
"add_generation_prompt": True
}
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools: chat_template_args["tools"] = tools
prompt = tokenizer.apply_chat_template(**chat_template_args)
@@ -171,8 +167,17 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
def __init__(
self,
node: Node,
inference_engine_classname: str,
response_timeout: int = 90,
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
default_model: Optional[str] = None,
system_prompt: Optional[str] = None
):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -208,7 +213,6 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
@@ -219,7 +223,7 @@ class ChatGPTAPI:
self.app.middlewares.append(self.log_request)
async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
if DEBUG >= 1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
@@ -249,61 +253,48 @@ class ChatGPTAPI:
async def handle_model_support(self, request):
try:
response = web.StreamResponse(
status=200,
reason='OK',
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
await response.prepare(request)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
})
await response.prepare(request)
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
model_data = {
model_name: {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage,
"total_size": total_size,
"total_downloaded": total_downloaded
}
}
model_data = {
model_name: {
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
"total_downloaded": total_downloaded
}
}
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
# Process all models in parallel
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
await response.write(b"data: [DONE]\n\n")
return response
await response.write(b"data: [DONE]\n\n")
return response
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_models(self, request):
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -472,7 +463,6 @@ class ChatGPTAPI:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_post_image_generations(self, request):
data = await request.json()
@@ -485,7 +475,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -497,77 +487,73 @@ class ChatGPTAPI:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'application/octet-stream',
"Cache-Control": "no-cache",
})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step) / total_steps
percent = float(current_step)/total_steps
# Calculate the number of hashes to display
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
spaces = ' ' * (bar_length - len(arrow))
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
spaces = ' '*(bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder / image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder/image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
if not model_name or model_name not in model_cards:
return web.json_response(
{"detail": f"Invalid model name: {model_name}"},
status=400
)
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response(
{"detail": "Could not build shard for model"},
status=400
)
return web.json_response({"detail": "Could not build shard for model"}, status=400)
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -582,38 +568,28 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(cache_dir)
})
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
else:
return web.json_response({
"detail": f"Model files not found at {cache_dir}"
}, status=404)
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({
"detail": f"Server error: {str(e)}"
}, status=500)
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
return web.json_response(model_data)
async def handle_create_animation(self, request):
@@ -639,17 +615,9 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
# Create the animation
create_animation_mp4(
replacement_image_path,
output_path,
device_name,
prompt_text
)
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
return web.json_response({
"status": "success",
"output_path": output_path
})
return web.json_response({"status": "success", "output_path": output_path})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
@@ -665,10 +633,7 @@ class ChatGPTAPI:
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",
"message": f"Download started for model: {model_name}"
})
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
@@ -682,10 +647,7 @@ class ChatGPTAPI:
return web.json_response({})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error getting topology: {str(e)}"},
status=500
)
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
@@ -696,15 +658,14 @@ class ChatGPTAPI:
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
W, H = (dim - dim%64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
img = img[None]
return img

View File

@@ -245,15 +245,13 @@ def get_all_ip_addresses_and_interfaces():
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)
data = json.loads(output)
@@ -279,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
return None
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
@@ -286,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
if macos_type is not None: return macos_type
# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")
# Loopback interface
@@ -313,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")
async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -332,16 +331,16 @@ def is_frozen():
def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else: docs_folder = Path.home() / "Documents"
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder / "Exo"
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir

View File

@@ -14,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass
@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -33,13 +33,13 @@ class InferenceEngine(ABC):
async def save_checkpoint(self, shard: Shard, path: str):
pass
async def save_session(self, key, value):
self.session[key] = value
async def clear_session(self):
self.session.empty()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
@@ -50,12 +50,14 @@ class InferenceEngine(ABC):
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
}
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")

View File

@@ -19,6 +19,7 @@ if platform.system().lower() == "darwin" and platform.machine().lower() == "arm6
else:
import numpy as mx
class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
self._id = _id
@@ -42,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -114,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -136,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
return loss, grads
else:
return loss
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -171,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -198,28 +194,20 @@ class GRPCPeerHandle(PeerHandle):
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state

View File

@@ -80,7 +80,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -102,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)
async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -118,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -137,7 +132,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
@@ -151,21 +146,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state

View File

@@ -195,7 +195,7 @@ def linux_device_capabilities() -> DeviceCapabilities:
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
pynvml.nvmlShutdown()
return DeviceCapabilities(
@@ -207,22 +207,22 @@ def linux_device_capabilities() -> DeviceCapabilities:
elif Device.DEFAULT == "AMD":
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Linux Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -234,30 +234,31 @@ def linux_device_capabilities() -> DeviceCapabilities:
def windows_device_capabilities() -> DeviceCapabilities:
import psutil
def get_gpu_info():
import win32com.client # install pywin32
import win32com.client # install pywin32
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
gpu_info = []
for gpu in gpus:
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)
return gpu_info
gpus_info = get_gpu_info()
gpu_names = [gpu['Name'] for gpu in gpus_info]
contains_nvidia = any('nvidia' in gpu_name.lower()for gpu_name in gpu_names)
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
if contains_nvidia:
import pynvml
@@ -266,7 +267,7 @@ def windows_device_capabilities() -> DeviceCapabilities:
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
return DeviceCapabilities(
@@ -278,15 +279,15 @@ def windows_device_capabilities() -> DeviceCapabilities:
elif contains_amd:
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Windows Box ({gpu_name})",
chip={gpu_name},
@@ -299,4 +300,4 @@ def windows_device_capabilities() -> DeviceCapabilities:
chip=f"Unknown Chip (Device(s): {gpu_names})",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
)