DummyInferenceEngine commit 1

This commit is contained in:
rahat2134
2024-10-19 22:47:44 +05:30
parent 1e4524b5c0
commit 7d6104750a
5 changed files with 140 additions and 4 deletions

2
.gitignore vendored
View File

@@ -1,5 +1,5 @@
__pycache__/
.venv
.venv*
test_weights.npz
.exo_used_ports
.exo_node_id

View File

@@ -0,0 +1,68 @@
from typing import Optional, Tuple, TYPE_CHECKING
import numpy as np
import asyncio
import json
from typing import Optional, Tuple
if TYPE_CHECKING:
from exo.inference.inference_engine import InferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
class DummyInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader):
self.shard = None
self.shard_downloader = shard_downloader
self.vocab_size = 1000
self.eos_token_id = 0
self.latency_mean = 0.1
self.latency_stddev = 0.02
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
try:
await self.ensure_shard(shard)
# Generate random tokens
output_length = np.random.randint(1, 10)
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
# Simulate latency
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
# Randomly decide if finished
is_finished = np.random.random() < 0.2
if is_finished:
output = np.array([[self.eos_token_id]])
new_state = json.dumps({"dummy_state": "some_value"})
return output, new_state, is_finished
except Exception as e:
print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
await self.ensure_shard(shard)
state = json.loads(inference_state or "{}")
start_pos = state.get("start_pos", 0)
output_length = np.random.randint(1, 10)
output = np.random.randint(1, self.vocab_size, size=(1, output_length))
await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
is_finished = np.random.random() < 0.2
if is_finished:
output = np.array([[self.eos_token_id]])
start_pos += input_data.shape[1] + output_length
new_state = json.dumps({"start_pos": start_pos})
return output, new_state, is_finished
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return
# Simulate shard loading without making any API calls
await asyncio.sleep(0.1) # Simulate a short delay
self.shard = shard
print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

View File

@@ -8,7 +8,7 @@ from .shard import Shard
class InferenceEngine(ABC):
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass
@abstractmethod
@@ -17,6 +17,7 @@ class InferenceEngine(ABC):
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
print(f"get_inference_engine called with: {inference_engine_name}") # Debug print
if inference_engine_name == "mlx":
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
@@ -27,5 +28,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
return TinygradDynamicShardInferenceEngine(shard_downloader)
elif inference_engine_name == "dummy":
from exo.inference.dummy_inference_engine import DummyInferenceEngine
return DummyInferenceEngine(shard_downloader)
else:
raise ValueError(f"Inference engine {inference_engine_name} not supported")
raise ValueError(f"Inference engine {inference_engine_name} not supported. Supported engines are 'mlx', 'tinygrad', and 'dummy'.")

View File

@@ -0,0 +1,40 @@
import pytest
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard
@pytest.mark.asyncio
async def test_dummy_inference_engine():
# Create a mock shard downloader
class MockShardDownloader:
async def ensure_shard(self, shard):
pass
# Initialize the DummyInferenceEngine
engine = DummyInferenceEngine(MockShardDownloader())
# Create a test shard
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
# Test infer_prompt
output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
assert isinstance(state, str), "State should be a string"
assert isinstance(is_finished, bool), "is_finished should be a boolean"
# Test infer_tensor
input_tensor = np.array([[1, 2, 3]])
output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
assert isinstance(state, str), "State should be a string"
assert isinstance(is_finished, bool), "is_finished should be a boolean"
print("All tests passed!")
if __name__ == "__main__":
import asyncio
asyncio.run(test_dummy_inference_engine())

View File

@@ -18,6 +18,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import model_base_shards
@@ -41,13 +42,15 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
print_yellow_exo()
@@ -56,6 +59,15 @@ print(f"Detected system: {system_info}")
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")
if inference_engine_name not in ["mlx", "tinygrad", "dummy"]:
print(f"Warning: Unknown inference engine '{inference_engine_name}'. Defaulting to 'tinygrad'.")
inference_engine_name = "tinygrad"
else:
print(f"Using selected inference engine: {inference_engine_name}")
print(f"About to call get_inference_engine with: {inference_engine_name}")
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
@@ -173,6 +185,16 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
node.on_token.deregister(callback_id)
async def test_dummy_inference(inference_engine):
print("Testing DummyInferenceEngine...")
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
test_prompt = "This is a test prompt"
result, state, is_finished = await inference_engine.infer_prompt("test_request", test_shard, test_prompt)
print(f"Inference result shape: {result.shape}")
print(f"Inference state: {state}")
print(f"Is finished: {is_finished}")
async def main():
loop = asyncio.get_running_loop()
@@ -193,6 +215,8 @@ async def main():
await run_model_cli(node, inference_engine, model_name, args.prompt)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
if isinstance(node.inference_engine, DummyInferenceEngine):
await test_dummy_inference(node.inference_engine)
await asyncio.Event().wait()