Merge pull request #364 from rahat2134/DummyInferenceEngine

Implementation of DummyInferenceEngine
This commit is contained in:
Alex Cheema
2024-10-23 18:29:05 -07:00
committed by GitHub
6 changed files with 169 additions and 8 deletions

View File

@@ -44,6 +44,13 @@ commands:
# Check processes before proceeding
check_processes
# Special handling for dummy engine
if [ "<<parameters.inference_engine>>" = "dummy" ]; then
expected_content="This is a dummy response"
else
expected_content="Michael Jackson"
fi
echo "Sending request to first instance..."
response_1=$(curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
@@ -74,8 +81,8 @@ commands:
kill $PID1 $PID2
echo ""
if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
echo "Test failed: Response does not contain 'Michael Jackson'"
if ! echo "$response_1" | grep -q "$expected_content" || ! echo "$response_2" | grep -q "$expected_content"; then
echo "Test failed: Response does not contain '$expected_content'"
echo "Response 1: $response_1"
echo ""
echo "Response 2: $response_2"
@@ -85,7 +92,7 @@ commands:
cat output2.log
exit 1
else
echo "Test passed: Response from both nodes contains 'Michael Jackson'"
echo "Test passed: Response from both nodes contains '$expected_content'"
fi
jobs:
@@ -178,6 +185,28 @@ jobs:
inference_engine: mlx
model_id: llama-3.2-1b
chatgpt_api_integration_test_dummy:
macos:
xcode: "16.0.0"
resource_class: m2pro.large
steps:
- checkout
- run:
name: Set up Python
command: |
brew install python@3.12
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: dummy
model_id: dummy-model
test_macos_m1:
macos:
xcode: "16.0.0"
@@ -215,5 +244,6 @@ workflows:
- unit_test
- discovery_integration_test
- chatgpt_api_integration_test_mlx
- chatgpt_api_integration_test_dummy
- test_macos_m1
# - chatgpt_api_integration_test_tinygrad

2
.gitignore vendored
View File

@@ -1,5 +1,5 @@
__pycache__/
.venv
.venv*
test_weights.npz
.exo_used_ports
.exo_node_id

View File

@@ -0,0 +1,65 @@
from typing import Optional, Tuple, TYPE_CHECKING
import numpy as np
import asyncio
import json
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
class DummyInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader):
self.shard = None
self.shard_downloader = shard_downloader
self.vocab_size = 1000
self.eos_token_id = 0
self.latency_mean = 0.1
self.latency_stddev = 0.02
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
try:
await self.ensure_shard(shard)
# Generate random tokens
output_length = np.random.randint(1, 10)
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
# Simulate latency
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
# Randomly decide if finished
is_finished = np.random.random() < 0.2
if is_finished:
output = np.array([[self.eos_token_id]])
new_state = json.dumps({"dummy_state": "some_value"})
return output, new_state, is_finished
except Exception as e:
print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
await self.ensure_shard(shard)
state = json.loads(inference_state or "{}")
start_pos = state.get("start_pos", 0)
output_length = np.random.randint(1, 10)
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
is_finished = np.random.random() < 0.2
if is_finished:
output = np.array([[self.eos_token_id]])
start_pos += input_data.shape[1] + output_length
new_state = json.dumps({"start_pos": start_pos})
return output, new_state, is_finished
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return
# Simulate shard loading without making any API calls
await asyncio.sleep(0.1) # Simulate a short delay
self.shard = shard
print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

View File

@@ -1,5 +1,6 @@
import numpy as np
import os
from exo.helpers import DEBUG # Make sure to import DEBUG
from typing import Tuple, Optional
from abc import ABC, abstractmethod
@@ -8,7 +9,7 @@ from .shard import Shard
class InferenceEngine(ABC):
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass
@abstractmethod
@@ -17,6 +18,8 @@ class InferenceEngine(ABC):
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
@@ -27,5 +30,7 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
return TinygradDynamicShardInferenceEngine(shard_downloader)
else:
raise ValueError(f"Inference engine {inference_engine_name} not supported")
elif inference_engine_name == "dummy":
from exo.inference.dummy_inference_engine import DummyInferenceEngine
return DummyInferenceEngine(shard_downloader)
raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

View File

@@ -0,0 +1,56 @@
import pytest
import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard
class MockShardDownloader:
async def ensure_shard(self, shard):
pass
@pytest.mark.asyncio
async def test_dummy_inference_specific():
engine = DummyInferenceEngine(MockShardDownloader())
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
test_prompt = "This is a test prompt"
result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
print(f"Inference result shape: {result.shape}")
print(f"Inference state: {state}")
print(f"Is finished: {is_finished}")
assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
assert isinstance(is_finished, bool), "is_finished should be a boolean"
@pytest.mark.asyncio
async def test_dummy_inference_engine():
# Initialize the DummyInferenceEngine
engine = DummyInferenceEngine(MockShardDownloader())
# Create a test shard
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
# Test infer_prompt
output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
assert isinstance(state, str), "State should be a string"
assert isinstance(is_finished, bool), "is_finished should be a boolean"
# Test infer_tensor
input_tensor = np.array([[1, 2, 3]])
output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
assert isinstance(state, str), "State should be a string"
assert isinstance(is_finished, bool), "is_finished should be a boolean"
print("All tests passed!")
if __name__ == "__main__":
import asyncio
asyncio.run(test_dummy_inference_engine())
asyncio.run(test_dummy_inference_specific())

View File

@@ -20,6 +20,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import model_base_shards
@@ -44,13 +45,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
print_yellow_exo()
@@ -60,6 +63,8 @@ print(f"Detected system: {system_info}")
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")