mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
DummyInferenceEngine commit 1
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,5 @@
|
||||
__pycache__/
|
||||
.venv
|
||||
.venv*
|
||||
test_weights.npz
|
||||
.exo_used_ports
|
||||
.exo_node_id
|
||||
|
||||
68
exo/inference/dummy_inference_engine.py
Normal file
68
exo/inference/dummy_inference_engine.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional, Tuple
|
||||
if TYPE_CHECKING:
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
class DummyInferenceEngine(InferenceEngine):
|
||||
def __init__(self, shard_downloader):
|
||||
self.shard = None
|
||||
self.shard_downloader = shard_downloader
|
||||
self.vocab_size = 1000
|
||||
self.eos_token_id = 0
|
||||
self.latency_mean = 0.1
|
||||
self.latency_stddev = 0.02
|
||||
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
|
||||
try:
|
||||
await self.ensure_shard(shard)
|
||||
|
||||
# Generate random tokens
|
||||
output_length = np.random.randint(1, 10)
|
||||
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
|
||||
|
||||
# Simulate latency
|
||||
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
|
||||
|
||||
# Randomly decide if finished
|
||||
is_finished = np.random.random() < 0.2
|
||||
if is_finished:
|
||||
output = np.array([[self.eos_token_id]])
|
||||
|
||||
new_state = json.dumps({"dummy_state": "some_value"})
|
||||
|
||||
return output, new_state, is_finished
|
||||
except Exception as e:
|
||||
print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
|
||||
return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
|
||||
await self.ensure_shard(shard)
|
||||
state = json.loads(inference_state or "{}")
|
||||
start_pos = state.get("start_pos", 0)
|
||||
|
||||
output_length = np.random.randint(1, 10)
|
||||
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
|
||||
|
||||
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
|
||||
|
||||
is_finished = np.random.random() < 0.2
|
||||
if is_finished:
|
||||
output = np.array([[self.eos_token_id]])
|
||||
|
||||
start_pos += input_data.shape[1] + output_length
|
||||
new_state = json.dumps({"start_pos": start_pos})
|
||||
|
||||
return output, new_state, is_finished
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard:
|
||||
return
|
||||
# Simulate shard loading without making any API calls
|
||||
await asyncio.sleep(0.1) # Simulate a short delay
|
||||
self.shard = shard
|
||||
print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")
|
||||
@@ -8,7 +8,7 @@ from .shard import Shard
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
@abstractmethod
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -17,6 +17,7 @@ class InferenceEngine(ABC):
|
||||
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
|
||||
print(f"get_inference_engine called with: {inference_engine_name}") # Debug print
|
||||
if inference_engine_name == "mlx":
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
|
||||
@@ -27,5 +28,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
|
||||
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
||||
|
||||
return TinygradDynamicShardInferenceEngine(shard_downloader)
|
||||
elif inference_engine_name == "dummy":
|
||||
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
||||
return DummyInferenceEngine(shard_downloader)
|
||||
else:
|
||||
raise ValueError(f"Inference engine {inference_engine_name} not supported")
|
||||
raise ValueError(f"Inference engine {inference_engine_name} not supported. Supported engines are 'mlx', 'tinygrad', and 'dummy'.")
|
||||
40
exo/inference/test_dummy_inference_engine.py
Normal file
40
exo/inference/test_dummy_inference_engine.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_inference_engine():
|
||||
# Create a mock shard downloader
|
||||
class MockShardDownloader:
|
||||
async def ensure_shard(self, shard):
|
||||
pass
|
||||
|
||||
# Initialize the DummyInferenceEngine
|
||||
engine = DummyInferenceEngine(MockShardDownloader())
|
||||
|
||||
# Create a test shard
|
||||
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
||||
|
||||
# Test infer_prompt
|
||||
output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
|
||||
|
||||
assert isinstance(output, np.ndarray), "Output should be a numpy array"
|
||||
assert output.ndim == 2, "Output should be 2-dimensional"
|
||||
assert isinstance(state, str), "State should be a string"
|
||||
assert isinstance(is_finished, bool), "is_finished should be a boolean"
|
||||
|
||||
# Test infer_tensor
|
||||
input_tensor = np.array([[1, 2, 3]])
|
||||
output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
|
||||
|
||||
assert isinstance(output, np.ndarray), "Output should be a numpy array"
|
||||
assert output.ndim == 2, "Output should be 2-dimensional"
|
||||
assert isinstance(state, str), "State should be a string"
|
||||
assert isinstance(is_finished, bool), "is_finished should be a boolean"
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(test_dummy_inference_engine())
|
||||
26
exo/main.py
26
exo/main.py
@@ -18,6 +18,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
||||
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.orchestration.node import Node
|
||||
from exo.models import model_base_shards
|
||||
@@ -41,13 +42,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
|
||||
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
|
||||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
|
||||
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
|
||||
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
|
||||
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
|
||||
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
|
||||
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
|
||||
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
|
||||
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||
args = parser.parse_args()
|
||||
print(f"Selected inference engine: {args.inference_engine}")
|
||||
|
||||
|
||||
print_yellow_exo()
|
||||
|
||||
@@ -56,6 +59,15 @@ print(f"Detected system: {system_info}")
|
||||
|
||||
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
|
||||
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
||||
print(f"Inference engine name after selection: {inference_engine_name}")
|
||||
|
||||
if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
|
||||
print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
|
||||
inference_engine_name = "tinygrad"
|
||||
else:
|
||||
print(f"Using selected inference engine: {inference_engine_name}")
|
||||
|
||||
print(f"About to call get_inference_engine with: {inference_engine_name}")
|
||||
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
|
||||
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
|
||||
|
||||
@@ -173,6 +185,16 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
|
||||
node.on_token.deregister(callback_id)
|
||||
|
||||
|
||||
async def test_dummy_inference(inference_engine):
|
||||
print("Testing DummyInferenceEngine...")
|
||||
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
||||
test_prompt = "This is a test prompt"
|
||||
result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
|
||||
print(f"Inference result shape: {result.shape}")
|
||||
print(f"Inference state: {state}")
|
||||
print(f"Is finished: {is_finished}")
|
||||
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
@@ -193,6 +215,8 @@ async def main():
|
||||
await run_model_cli(node, inference_engine, model_name, args.prompt)
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
if isinstance(node.inference_engine, DummyInferenceEngine):
|
||||
await test_dummy_inference(node.inference_engine)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user