fix ruff lint errors
@@ -1 +1 @@
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
@@ -1 +1 @@
-from exo.api.chatgpt_api import ChatGPTAPI
+from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI
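Note: the `import X as X` form is the explicit re-export convention from PEP 484. A plain `from m import X` in a package `__init__.py` looks like an unused import to checkers such as Ruff's F401 or mypy's implicit-re-export check, while the redundant alias marks the name as intentionally exported. A minimal sketch of the pattern, using hypothetical module names rather than anything from this repo:

# mypkg/__init__.py
# Without the alias (or an __all__ entry), a strict linter may flag Client as
# unused, and an autofix could drop the line, breaking `from mypkg import Client`.
from mypkg.client import Client as Client  # explicit re-export

__all__ = ["Client"]  # listing the name in __all__ is an equivalent signal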
@@ -85,20 +85,18 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")

   try:
     if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
     return resolve_tinygrad_tokenizer(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")

   if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
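Note: replacing the bare `except:` clears Ruff's E722 (bare except). A bare handler also swallows `KeyboardInterrupt` and `SystemExit`, and it gives no exception object to report. A small sketch of the fallback-chain shape used above, with made-up loader names:

def load_with_fallback(loaders, name):
  # Catch Exception rather than everything, so Ctrl-C and interpreter shutdown
  # still propagate, and bind the error so it can be logged like the diff does.
  for load in loaders:
    try:
      return load(name)
    except Exception as e:
      print(f"{load.__name__} failed for {name}: {e}")
  raise RuntimeError(f"no loader could handle {name}")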
@@ -312,7 +310,7 @@ class ChatGPTAPI:
     if (
       request_id in self.stream_tasks
     ): # in case there is still a stream task running, wait for it to complete
-      if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+      if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
       try:
         await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
       except asyncio.TimeoutError:
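Note: dropping the `f` prefix from a string with no placeholders fixes Ruff's F541 (f-string without placeholders); the prefix is harmless at runtime, but it advertises interpolation that never happens. For example:

pending = 2
print(f"waiting on {pending} stream tasks")  # placeholder present: keep the f-string
print("Pending stream task. Waiting for stream task to complete.")  # none: plain literal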
@@ -1,4 +1,3 @@
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

@@ -19,7 +18,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,

@@ -41,7 +40,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
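Note: prefixing the unused parts of a tuple unpacking with an underscore (`_next_inference_state_full`, `_inference_state_4`) is the conventional dummy-variable pattern. It tells readers, and linters that check for unused bindings (Ruff's dummy-variable regex defaults to names starting with `_`), that the value is deliberately discarded. A small illustrative sketch:

def fetch_result():
  return "payload", {"step": 1}, 0.12

# Only the payload is needed; the underscore names mark the rest as intentionally unused.
payload, _state, _elapsed = fetch_result()
print(payload)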
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -32,11 +32,11 @@ class NormalModelArgs(BaseModelArgs):
       self.num_key_value_heads = self.num_attention_heads

     if self.rope_scaling:
-      if not "factor" in self.rope_scaling:
-        raise ValueError(f"rope_scaling must contain 'factor'")
+      if "factor" not in self.rope_scaling:
+        raise ValueError("rope_scaling must contain 'factor'")
       rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
       if rope_type is None:
-        raise ValueError(f"rope_scaling must contain either 'type' or 'rope_type'")
+        raise ValueError("rope_scaling must contain either 'type' or 'rope_type'")
       if rope_type not in ["linear", "dynamic", "llama3"]:
         raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")

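Note: `if not "factor" in self.rope_scaling:` trips the membership-test style rule (pycodestyle/Ruff E713). `not x in y` evaluates as `not (x in y)`, so the meaning is unchanged, but `x not in y` states the test directly; the same hunk also drops `f` prefixes from messages with no placeholders. A tiny example of the preferred form:

config = {"type": "linear"}
if "factor" not in config:  # same truth value as `not ("factor" in config)`
  print("missing required key: 'factor'")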
@@ -186,7 +186,7 @@ class Attention(nn.Module):
     mask: Optional[mx.array] = None,
     cache: Optional[KVCache] = None,
   ) -> mx.array:
-    B, L, D = x.shape
+    B, L, _D = x.shape

     queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, Optional, Tuple
+from typing import Dict, Generator, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
@@ -1,5 +1,4 @@
 from exo.inference.shard import Shard
-from exo.inference.mlx.sharded_model import StatefulShardedModel
 import mlx.core as mx
 import mlx.nn as nn
 from typing import Optional
@@ -1,7 +1,6 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np

@@ -14,7 +13,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,

@@ -36,7 +35,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
@@ -2,12 +2,12 @@ import asyncio
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
-import json, argparse, random, time
+import json
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, nn, Context, GlobalCounters
 from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
@@ -13,8 +13,8 @@ from exo.topology.device_capabilities import DeviceCapabilities


 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
-    self._id = id
+  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = _id
     self.address = address
     self._device_capabilities = device_capabilities
     self.channel = None
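Note: renaming the constructor parameter `id` to `_id` (and, in the stats module below, the local `type` to `_type`) avoids shadowing Python builtins. Flake8-builtins-style rules (A002 for arguments, available in Ruff) flag this, and shadowing can bite when the builtin is needed later in the same scope. A minimal sketch with hypothetical names:

class Peer:
  def __init__(self, _id: str):
    # Had the parameter been named `id`, the builtin id() would be shadowed for
    # the whole method body; with `_id` it stays reachable.
    self._id = _id
    print(id(self))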
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, List, Callable
+from typing import Optional, Tuple, List
 import numpy as np
 from abc import ABC, abstractmethod
 from exo.helpers import AsyncCallbackSystem
@@ -18,7 +18,7 @@ from exo.viz.topology_viz import TopologyViz
 class StandardNode(Node):
   def __init__(
     self,
-    id: str,
+    _id: str,
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,

@@ -28,7 +28,7 @@ class StandardNode(Node):
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
   ):
-    self.id = id
+    self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery

@@ -358,7 +358,7 @@ class StandardNode(Node):
         continue

       if max_depth <= 0:
-        if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+        if DEBUG >= 2: print("Max depth reached. Skipping...")
         continue

       try:
@@ -1,7 +1,6 @@
 from exo.orchestration import Node
 from prometheus_client import start_http_server, Counter, Histogram
 import json
-from typing import List

 # Create metrics to track time spent and requests made.
 PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])

@@ -14,9 +13,9 @@ def start_metrics_server(node: Node, port: int):

   def _on_opaque_status(request_id, opaque_status: str):
     status_data = json.loads(opaque_status)
-    type = status_data.get("type", "")
+    _type = status_data.get("type", "")
     node_id = status_data.get("node_id", "")
-    if type != "node_status":
+    if _type != "node_status":
       return
     status = status_data.get("status", "")

@@ -116,8 +116,8 @@ def device_capabilities() -> DeviceCapabilities:
     return linux_device_capabilities()
   else:
     return DeviceCapabilities(
-      model=f"Unknown Device",
-      chip=f"Unknown Chip",
+      model="Unknown Device",
+      chip="Unknown Chip",
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )

@@ -151,7 +151,7 @@ def linux_device_capabilities() -> DeviceCapabilities:

   if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
   if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
-    import pynvml, pynvml_utils
+    import pynvml

     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
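Note: in the GPU branch only `pynvml` is actually used, so the combined `import pynvml, pynvml_utils` is trimmed; Ruff reports each unused name on a multi-import line separately (F401). A stdlib-only illustration of the same situation:

# Before: `import os, sys` -- if nothing below touches sys, F401 flags it and the
# fix keeps only the name that is used.
import os

print(os.getcwd())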
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List
 from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard

@@ -1,6 +1,5 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition

@@ -5,7 +5,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa

 class TestMacDeviceCapabilities(unittest.TestCase):
   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:

@@ -40,7 +40,7 @@ Activation Lock Status: Enabled
     )

   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_air(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
@@ -8,7 +8,6 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from rich.color import Color
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES


@@ -20,7 +19,7 @@ class TopologyViz:
     self.partitions: List[Partition] = []

     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()

@@ -47,9 +47,9 @@ def adjust_indentation(content):
 def process_file(file_path, process_func):
   with open(file_path, 'r') as file:
     content = file.read()
-
+
   modified_content = process_func(content)
-
+
   if content != modified_content:
     with open(file_path, 'w') as file:
       file.write(modified_content)
main.py (15 changed lines)
@@ -2,7 +2,6 @@ import argparse
 import asyncio
 import signal
 import uuid
-from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -41,11 +40,21 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")

 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
+node = StandardNode(
+  args.node_id,
+  None,
+  inference_engine,
+  discovery,
+  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+  chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
+  web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
+  disable_tui=args.disable_tui,
+  max_generate_tokens=args.max_generate_tokens,
+)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 if args.prometheus_client_port:
   from exo.stats.metrics import start_metrics_server
   start_metrics_server(node, args.prometheus_client_port)
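Note: splitting the long `StandardNode(...)` call into one argument per line keeps it under a line-length limit, and the comma after the final argument is the "magic trailing comma" that Black and the Ruff formatter treat as a request to keep the call expanded rather than collapsing it back onto one line. A small runnable sketch with made-up names:

def build_node(node_id, host, port):
  return {"id": node_id, "host": host, "port": port}

node = build_node(
  "node-1",
  host="0.0.0.0",
  port=8080,
)  # trailing comma: formatters keep one argument per line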