fix modelpool, add tests in test/test_model_helpers.py

This commit is contained in:
Alex Cheema
2024-11-18 20:52:06 +04:00
committed by josh
parent 559f12e7d0
commit 1b7e67832c
4 changed files with 148 additions and 18 deletions

View File

@@ -126,6 +126,7 @@ jobs:
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
echo "Running tokenizer tests..." echo "Running tokenizer tests..."
python3 ./test/test_tokenizers.py python3 ./test/test_tokenizers.py
python3 ./test/test_model_helpers.py
discovery_integration_test: discovery_integration_test:
macos: macos:

View File

@@ -13,11 +13,9 @@ import sys
from exo import DEBUG, VERSION from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown from exo.helpers import PrefixDict, shutdown
from exo.inference.inference_engine import inference_engine_classes
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from typing import Callable, Optional from typing import Callable, Optional
class Message: class Message:
@@ -218,18 +216,7 @@ class ChatGPTAPI:
return web.json_response({ return web.json_response({
"model pool": { "model pool": {
model_name: pretty_name.get(model_name, model_name) model_name: pretty_name.get(model_name, model_name)
for model_name in [ for model_name in get_supported_models(self.node.topology_inference_engines_pool)
model_id for model_id, model_info in model_cards.items()
if all(map(
lambda engine: engine in model_info["repo"],
list(dict.fromkeys([
inference_engine_classes.get(engine_name, None)
for engine_list in self.node.topology_inference_engines_pool
for engine_name in engine_list
if engine_name is not None
] + [self.inference_engine_classname]))
))
]
} }
}) })

View File

@@ -1,11 +1,11 @@
from exo.inference.shard import Shard from exo.inference.shard import Shard
from typing import Optional from typing import Optional, List
model_cards = { model_cards = {
### llama ### llama
"llama-3.2-1b": { "llama-3.2-1b": {
"layers": 16, "layers": 16,
"repo": { "repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit", "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct", "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
}, },
@@ -124,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
if repo is None or n_layers < 1: if repo is None or n_layers < 1:
return None return None
return Shard(model_id, 0, 0, n_layers) return Shard(model_id, 0, 0, n_layers)
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
  """Return the ids of models every peer in the topology can serve.

  Each inner list holds the inference engines one peer supports, given
  either as short names (e.g. "mlx") or as engine class names.  A model
  qualifies when, for every peer, at least one of that peer's engines has
  a repo entry in the model's card.  An empty outer list means there is
  no constraint, so every model id is returned.
  """
  if not supported_inference_engine_lists:
    return list(model_cards.keys())

  # Imported lazily to avoid a circular import at module load time.
  from exo.inference.inference_engine import inference_engine_classes

  # Normalize short engine names to class names; names without a known
  # mapping pass through unchanged.
  normalized_lists = [
    [inference_engine_classes.get(engine, engine) for engine in engine_list]
    for engine_list in supported_inference_engine_lists
  ]

  def peer_can_serve(model_info: dict, engine_list: List[str]) -> bool:
    # True when any of this peer's engines has a repo for the model.
    return any(engine in model_info.get("repo", {}) for engine in engine_list)

  return [
    model_id
    for model_id, model_info in model_cards.items()
    if all(peer_can_serve(model_info, engine_list) for engine_list in normalized_lists)
  ]

121
test/test_model_helpers.py Normal file
View File

@@ -0,0 +1,121 @@
import unittest
from exo.models import get_supported_models, model_cards
from exo.inference.inference_engine import inference_engine_classes
from typing import NamedTuple
class TestCase(NamedTuple):
  """One table-driven scenario for get_supported_models.

  Bound fields left as None are skipped by the verification harness.
  """
  name: str
  # Peer engine lists; contain short names, also re-run mapped to class names.
  engine_lists: list
  # Model ids that must appear in the result.
  expected_models_contains: list
  # Result length must be strictly greater than this.
  min_count: int | None
  # Result length must equal this exactly.
  exact_count: int | None
  # Result length must be strictly less than this.
  max_count: int | None
# Helper function to map short names to class names
def expand_engine_lists(engine_lists):
  """Translate short engine names (e.g. "mlx") to inference-engine class
  names, leaving any name without a known mapping unchanged."""
  return [
    [inference_engine_classes.get(name, name) for name in sublist]
    for sublist in engine_lists
  ]
# Table of scenarios.  Each engine_lists entry models one peer's supported
# engines; get_supported_models should return the models every peer can run.
# The count bounds reflect the contents of model_cards at time of writing.
test_cases = [
  TestCase(
    name="single_mlx_engine",
    engine_lists=[["mlx"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  TestCase(
    name="single_tinygrad_engine",
    engine_lists=[["tinygrad"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
    min_count=5,
    exact_count=None,
    max_count=10
  ),
  TestCase(
    name="multiple_engines_or",
    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  TestCase(
    name="multiple_engines_all",
    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  TestCase(
    name="distinct_engine_lists",
    engine_lists=[["mlx"], ["tinygrad"]],
    expected_models_contains=["llama-3.2-1b"],
    min_count=5,
    exact_count=None,
    max_count=10
  ),
  TestCase(
    # No peers at all: every model in model_cards should be returned.
    name="no_engines",
    engine_lists=[],
    expected_models_contains=None,
    min_count=None,
    exact_count=len(model_cards),
    max_count=None
  ),
  TestCase(
    # An engine name with no repo entries must match nothing.
    name="nonexistent_engine",
    engine_lists=[["NonexistentEngine"]],
    expected_models_contains=[],
    min_count=None,
    exact_count=0,
    max_count=None
  ),
  TestCase(
    name="dummy_engine",
    engine_lists=[["dummy"]],
    expected_models_contains=["dummy"],
    min_count=None,
    exact_count=1,
    max_count=None
  ),
]
class TestModelHelpers(unittest.TestCase):
def test_get_supported_models(self):
for case in test_cases:
with self.subTest(f"{case.name}_short_names"):
result = get_supported_models(case.engine_lists)
self._verify_results(case, result)
with self.subTest(f"{case.name}_class_names"):
class_name_lists = expand_engine_lists(case.engine_lists)
result = get_supported_models(class_name_lists)
self._verify_results(case, result)
def _verify_results(self, case, result):
if case.expected_models_contains:
for model in case.expected_models_contains:
self.assertIn(model, result)
if case.min_count:
self.assertGreater(len(result), case.min_count)
if case.exact_count is not None:
self.assertEqual(len(result), case.exact_count)
# Special case for distinct lists test
if case.name == "distinct_engine_lists":
self.assertLess(len(result), 10)
self.assertNotIn("mistral-nemo", result)
if case.max_count:
self.assertLess(len(result), case.max_count)
# Allow running this file directly: python3 ./test/test_model_helpers.py
if __name__ == '__main__':
  unittest.main()