fix ruff lint errors

This commit is contained in:
Alex Cheema
2024-07-27 17:08:32 -07:00
parent ce761038ac
commit 57b2f2a4e2
20 changed files with 49 additions and 48 deletions

View File

@@ -1 +1 @@
from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION
from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

View File

@@ -1 +1 @@
from exo.api.chatgpt_api import ChatGPTAPI
from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI

View File

@@ -85,20 +85,18 @@ async def resolve_tokenizer(model_id: str):
try:
if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
return AutoTokenizer.from_pretrained(model_id)
except:
except Exception as e:
if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
import traceback
if DEBUG >= 2: print(traceback.format_exc())
if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
try:
if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
return resolve_tinygrad_tokenizer(model_id)
except:
except Exception as e:
if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
import traceback
if DEBUG >= 2: print(traceback.format_exc())
if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
@@ -312,7 +310,7 @@ class ChatGPTAPI:
if (
request_id in self.stream_tasks
): # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
except asyncio.TimeoutError:

View File

@@ -1,4 +1,3 @@
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
@@ -19,7 +18,7 @@ async def test_inference_engine(
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
"A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
)
next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=resp_full,
@@ -41,7 +40,7 @@ async def test_inference_engine(
input_data=resp2,
inference_state=inference_state_2,
)
resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp3,

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
@@ -32,11 +32,11 @@ class NormalModelArgs(BaseModelArgs):
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
if "factor" not in self.rope_scaling:
raise ValueError("rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
if rope_type is None:
raise ValueError(f"rope_scaling must contain either 'type' or 'rope_type'")
raise ValueError("rope_scaling must contain either 'type' or 'rope_type'")
if rope_type not in ["linear", "dynamic", "llama3"]:
raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")
@@ -186,7 +186,7 @@ class Attention(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
B, L, _D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Generator, Optional, Tuple
from typing import Dict, Generator, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn

View File

@@ -1,5 +1,4 @@
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_model import StatefulShardedModel
import mlx.core as mx
import mlx.nn as nn
from typing import Optional

View File

@@ -1,7 +1,6 @@
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
import asyncio
import numpy as np
@@ -14,7 +13,7 @@ async def test_inference_engine(
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
"A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
)
next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=resp_full,
@@ -36,7 +35,7 @@ async def test_inference_engine(
input_data=resp2,
inference_state=inference_state_2,
)
resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp3,

View File

@@ -2,12 +2,12 @@ import asyncio
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
import json, argparse, random, time
import json
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
from tinygrad import Tensor, nn, Context, GlobalCounters
from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine

View File

@@ -13,8 +13,8 @@ from exo.topology.device_capabilities import DeviceCapabilities
class GRPCPeerHandle(PeerHandle):
def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
self._id = id
def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
self._id = _id
self.address = address
self._device_capabilities = device_capabilities
self.channel = None

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, List, Callable
from typing import Optional, Tuple, List
import numpy as np
from abc import ABC, abstractmethod
from exo.helpers import AsyncCallbackSystem

View File

@@ -18,7 +18,7 @@ from exo.viz.topology_viz import TopologyViz
class StandardNode(Node):
def __init__(
self,
id: str,
_id: str,
server: Server,
inference_engine: InferenceEngine,
discovery: Discovery,
@@ -28,7 +28,7 @@ class StandardNode(Node):
web_chat_url: Optional[str] = None,
disable_tui: Optional[bool] = False,
):
self.id = id
self.id = _id
self.inference_engine = inference_engine
self.server = server
self.discovery = discovery
@@ -358,7 +358,7 @@ class StandardNode(Node):
continue
if max_depth <= 0:
if DEBUG >= 2: print(f"Max depth reached. Skipping...")
if DEBUG >= 2: print("Max depth reached. Skipping...")
continue
try:

View File

@@ -1,7 +1,6 @@
from exo.orchestration import Node
from prometheus_client import start_http_server, Counter, Histogram
import json
from typing import List
# Create metrics to track time spent and requests made.
PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
@@ -14,9 +13,9 @@ def start_metrics_server(node: Node, port: int):
def _on_opaque_status(request_id, opaque_status: str):
status_data = json.loads(opaque_status)
type = status_data.get("type", "")
_type = status_data.get("type", "")
node_id = status_data.get("node_id", "")
if type != "node_status":
if _type != "node_status":
return
status = status_data.get("status", "")

View File

@@ -116,8 +116,8 @@ def device_capabilities() -> DeviceCapabilities:
return linux_device_capabilities()
else:
return DeviceCapabilities(
model=f"Unknown Device",
chip=f"Unknown Chip",
model="Unknown Device",
chip="Unknown Chip",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
@@ -151,7 +151,7 @@ def linux_device_capabilities() -> DeviceCapabilities:
if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
import pynvml, pynvml_utils
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import List
from dataclasses import dataclass
from .topology import Topology
from exo.inference.shard import Shard

View File

@@ -1,6 +1,5 @@
from typing import List
from .partitioning_strategy import PartitioningStrategy
from exo.inference.shard import Shard
from .topology import Topology
from .partitioning_strategy import Partition

View File

@@ -5,7 +5,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
class TestMacDeviceCapabilities(unittest.TestCase):
@patch("subprocess.check_output")
def test_mac_device_capabilities(self, mock_check_output):
def test_mac_device_capabilities_pro(self, mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
@@ -40,7 +40,7 @@ Activation Lock Status: Enabled
)
@patch("subprocess.check_output")
def test_mac_device_capabilities(self, mock_check_output):
def test_mac_device_capabilities_air(self, mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:

View File

@@ -8,7 +8,6 @@ from rich.panel import Panel
from rich.text import Text
from rich.live import Live
from rich.style import Style
from rich.color import Color
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
@@ -20,7 +19,7 @@ class TopologyViz:
self.partitions: List[Partition] = []
self.console = Console()
self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
self.live_panel.start()

View File

@@ -47,9 +47,9 @@ def adjust_indentation(content):
def process_file(file_path, process_func):
with open(file_path, 'r') as file:
content = file.read()
modified_content = process_func(content)
if content != modified_content:
with open(file_path, 'w') as file:
file.write(modified_content)

15
main.py
View File

@@ -2,7 +2,6 @@ import argparse
import asyncio
import signal
import uuid
from typing import List
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -41,11 +40,21 @@ if args.node_port is None:
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
node = StandardNode(
args.node_id,
None,
inference_engine,
discovery,
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
disable_tui=args.disable_tui,
max_generate_tokens=args.max_generate_tokens,
)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
if args.prometheus_client_port:
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)