Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23)
Merge pull request #364 from rahat2134/DummyInferenceEngine
Implementation of DummyInferenceEngine
@@ -44,6 +44,13 @@ commands:
 # Check processes before proceeding
 check_processes

+# Special handling for dummy engine
+if [ "<<parameters.inference_engine>>" = "dummy" ]; then
+  expected_content="This is a dummy response"
+else
+  expected_content="Michael Jackson"
+fi
+
 echo "Sending request to first instance..."
 response_1=$(curl -s http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \

@@ -74,8 +81,8 @@ commands:
 kill $PID1 $PID2

 echo ""
-if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
-  echo "Test failed: Response does not contain 'Michael Jackson'"
+if ! echo "$response_1" | grep -q "$expected_content" || ! echo "$response_2" | grep -q "$expected_content"; then
+  echo "Test failed: Response does not contain '$expected_content'"
   echo "Response 1: $response_1"
   echo ""
   echo "Response 2: $response_2"

@@ -85,7 +92,7 @@ commands:
   cat output2.log
   exit 1
 else
-  echo "Test passed: Response from both nodes contains 'Michael Jackson'"
+  echo "Test passed: Response from both nodes contains '$expected_content'"
 fi

 jobs:

@@ -178,6 +185,28 @@ jobs:
          inference_engine: mlx
          model_id: llama-3.2-1b

+  chatgpt_api_integration_test_dummy:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: dummy
+          model_id: dummy-model
+
   test_macos_m1:
     macos:
       xcode: "16.0.0"

@@ -215,5 +244,6 @@ workflows:
       - unit_test
       - discovery_integration_test
       - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_dummy
       - test_macos_m1
       # - chatgpt_api_integration_test_tinygrad
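For local debugging of this check outside CI, here is a minimal Python sketch of the same assertion against a running node. The request body is an assumption modeled on the OpenAI-style chat/completions endpoint the job calls; the diff truncates the payload the job actually sends, so the model name and prompt below are illustrative only.

import json
import urllib.request

# Hypothetical payload -- the diff shows only the URL and headers, not the body sent by the CI job.
payload = {
  "model": "dummy-model",  # matches the model_id passed to run_chatgpt_api_test above
  "messages": [{"role": "user", "content": "Who was the king of pop?"}],
}
req = urllib.request.Request(
  "http://localhost:8000/v1/chat/completions",
  data=json.dumps(payload).encode(),
  headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req, timeout=90) as resp:
  body = resp.read().decode()

expected = "This is a dummy response"  # "Michael Jackson" when testing the real engines, per the logic above
print("PASS" if expected in body else f"FAIL: {body}")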
.gitignore (2 lines changed, vendored)
@@ -1,5 +1,5 @@
 __pycache__/
-.venv
+.venv*
 test_weights.npz
 .exo_used_ports
 .exo_node_id
exo/inference/dummy_inference_engine.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import Optional, Tuple, TYPE_CHECKING
import numpy as np
import asyncio
import json
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard

class DummyInferenceEngine(InferenceEngine):
  def __init__(self, shard_downloader):
    self.shard = None
    self.shard_downloader = shard_downloader
    self.vocab_size = 1000
    self.eos_token_id = 0
    self.latency_mean = 0.1
    self.latency_stddev = 0.02

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    try:
      await self.ensure_shard(shard)

      # Generate random tokens
      output_length = np.random.randint(1, 10)
      output = np.random.randint(1, self.vocab_size, size=(1, output_length))

      # Simulate latency
      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))

      # Randomly decide if finished
      is_finished = np.random.random() < 0.2
      if is_finished:
        output = np.array([[self.eos_token_id]])

      new_state = json.dumps({"dummy_state": "some_value"})

      return output, new_state, is_finished
    except Exception as e:
      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    await self.ensure_shard(shard)
    state = json.loads(inference_state or "{}")
    start_pos = state.get("start_pos", 0)

    output_length = np.random.randint(1, 10)
    output = np.random.randint(1, self.vocab_size, size=(1, output_length))

    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))

    is_finished = np.random.random() < 0.2
    if is_finished:
      output = np.array([[self.eos_token_id]])

    start_pos += input_data.shape[1] + output_length
    new_state = json.dumps({"start_pos": start_pos})

    return output, new_state, is_finished

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return
    # Simulate shard loading without making any API calls
    await asyncio.sleep(0.1)  # Simulate a short delay
    self.shard = shard
    print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")
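Both the token sampling and the is_finished decision above go through numpy's global RNG, so responses vary from run to run. A minimal sketch of driving the engine directly and seeding the RNG for repeatable output; the no-op downloader below is a stand-in for illustration (the dummy engine never calls it) and is not part of this PR.

import asyncio
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

class NoopShardDownloader:  # stand-in only; DummyInferenceEngine.ensure_shard never touches it
  async def ensure_shard(self, shard):
    pass

async def main():
  np.random.seed(0)  # fixes output length, token ids and is_finished across runs
  engine = DummyInferenceEngine(NoopShardDownloader())
  shard = Shard(model_id="dummy-model", start_layer=0, end_layer=1, n_layers=1)
  output, state, is_finished = await engine.infer_prompt("req-1", shard, "hello")
  print(output.shape, state, is_finished)

asyncio.run(main())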
exo/inference/inference_engine.py
@@ -1,5 +1,6 @@
 import numpy as np
 import os
+from exo.helpers import DEBUG  # Make sure to import DEBUG

 from typing import Tuple, Optional
 from abc import ABC, abstractmethod

@@ -8,7 +9,7 @@ from .shard import Shard

 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass

   @abstractmethod

@@ -17,6 +18,8 @@ class InferenceEngine(ABC):


 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if DEBUG >= 2:
+    print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

@@ -27,5 +30,7 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

     return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+  elif inference_engine_name == "dummy":
+    from exo.inference.dummy_inference_engine import DummyInferenceEngine
+    return DummyInferenceEngine(shard_downloader)
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
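The net effect of this hunk is that the factory accepts a third name alongside mlx and tinygrad, and any unrecognized name now falls through to the final raise. A minimal sketch of the selection path, assuming a shard downloader instance such as the HFShardDownloader constructed later in this diff:

from exo.inference.inference_engine import get_inference_engine

engine = get_inference_engine("dummy", shard_downloader)  # returns DummyInferenceEngine; no model download happens
get_inference_engine("not-an-engine", shard_downloader)   # raises ValueError: Unsupported inference engine: not-an-engine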
exo/inference/test_dummy_inference_engine.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import pytest
import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

class MockShardDownloader:
  async def ensure_shard(self, shard):
    pass

@pytest.mark.asyncio
async def test_dummy_inference_specific():
  engine = DummyInferenceEngine(MockShardDownloader())
  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
  test_prompt = "This is a test prompt"

  result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)

  print(f"Inference result shape: {result.shape}")
  print(f"Inference state: {state}")
  print(f"Is finished: {is_finished}")

  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
  assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
  assert isinstance(is_finished, bool), "is_finished should be a boolean"

@pytest.mark.asyncio
async def test_dummy_inference_engine():
  # Initialize the DummyInferenceEngine
  engine = DummyInferenceEngine(MockShardDownloader())

  # Create a test shard
  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)

  # Test infer_prompt
  output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")

  assert isinstance(output, np.ndarray), "Output should be a numpy array"
  assert output.ndim == 2, "Output should be 2-dimensional"
  assert isinstance(state, str), "State should be a string"
  assert isinstance(is_finished, bool), "is_finished should be a boolean"

  # Test infer_tensor
  input_tensor = np.array([[1, 2, 3]])
  output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)

  assert isinstance(output, np.ndarray), "Output should be a numpy array"
  assert output.ndim == 2, "Output should be 2-dimensional"
  assert isinstance(state, str), "State should be a string"
  assert isinstance(is_finished, bool), "is_finished should be a boolean"

  print("All tests passed!")

if __name__ == "__main__":
  import asyncio
  asyncio.run(test_dummy_inference_engine())
  asyncio.run(test_dummy_inference_specific())
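A note on running this file: the @pytest.mark.asyncio marker relies on the pytest-asyncio plugin, which is an assumption about the test environment not visible in this diff; the __main__ block gives a plugin-free fallback.

# Two ways to run the new tests, using the path added by this PR:
#   python exo/inference/test_dummy_inference_engine.py          # uses the __main__ block, no pytest needed
#   python -m pytest exo/inference/test_dummy_inference_engine.py  # assumes pytest and pytest-asyncio are installed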
@@ -20,6 +20,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards

@@ -44,13 +45,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
+

 print_yellow_exo()

@@ -60,6 +63,8 @@ print(f"Detected system: {system_info}")

 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")