mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
fix modelpool, add tests in test/test_model_helpers.py
This commit is contained in:
@@ -126,6 +126,7 @@ jobs:
|
||||
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
|
||||
echo "Running tokenizer tests..."
|
||||
python3 ./test/test_tokenizers.py
|
||||
python3 ./test/test_model_helpers.py
|
||||
|
||||
discovery_integration_test:
|
||||
macos:
|
||||
|
||||
@@ -13,11 +13,9 @@ import sys
|
||||
from exo import DEBUG, VERSION
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.helpers import PrefixDict, shutdown
|
||||
from exo.inference.inference_engine import inference_engine_classes
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.orchestration import Node
|
||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
|
||||
from typing import Callable, Optional
|
||||
|
||||
class Message:
|
||||
@@ -218,18 +216,7 @@ class ChatGPTAPI:
|
||||
return web.json_response({
|
||||
"model pool": {
|
||||
model_name: pretty_name.get(model_name, model_name)
|
||||
for model_name in [
|
||||
model_id for model_id, model_info in model_cards.items()
|
||||
if all(map(
|
||||
lambda engine: engine in model_info["repo"],
|
||||
list(dict.fromkeys([
|
||||
inference_engine_classes.get(engine_name, None)
|
||||
for engine_list in self.node.topology_inference_engines_pool
|
||||
for engine_name in engine_list
|
||||
if engine_name is not None
|
||||
] + [self.inference_engine_classname]))
|
||||
))
|
||||
]
|
||||
for model_name in get_supported_models(self.node.topology_inference_engines_pool)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from exo.inference.shard import Shard
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
model_cards = {
|
||||
### llama
|
||||
@@ -125,3 +125,24 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
|
||||
return None
|
||||
return Shard(model_id, 0, 0, n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
  """Return the model ids from model_cards that every engine list can serve.

  Each inner list is a set of alternatives (OR): a model qualifies for that
  list if its repo supports at least one engine in it. A model must qualify
  for every inner list (AND across lists). An empty outer list means "no
  constraint", so all model ids are returned.
  """
  if not supported_inference_engine_lists:
    return list(model_cards.keys())

  # Imported locally, mirroring the original, to avoid a module-level cycle.
  from exo.inference.inference_engine import inference_engine_classes

  # Normalize short engine names (e.g. "mlx") to class names; unknown names
  # pass through unchanged and simply never match a repo key.
  normalized_lists = [
    [inference_engine_classes.get(engine, engine) for engine in engine_list]
    for engine_list in supported_inference_engine_lists
  ]

  def model_is_supported(model_info: dict) -> bool:
    # OR within each engine list, AND across the lists.
    repo = model_info.get("repo", {})
    return all(
      any(engine in repo for engine in engine_list)
      for engine_list in normalized_lists
    )

  return [
    model_id
    for model_id, model_info in model_cards.items()
    if model_is_supported(model_info)
  ]
|
||||
|
||||
121
test/test_model_helpers.py
Normal file
121
test/test_model_helpers.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import unittest
|
||||
from exo.models import get_supported_models, model_cards
|
||||
from exo.inference.inference_engine import inference_engine_classes
|
||||
from typing import NamedTuple
|
||||
|
||||
class TestCase(NamedTuple):
  """One get_supported_models scenario: input engine lists plus result checks."""
  # Scenario label, used in subTest names (and special-cased by the verifier
  # for "distinct_engine_lists").
  name: str
  engine_lists: list # Will contain short names, will be mapped to class names
  # Model ids that must appear in the result; None or [] skips the check.
  expected_models_contains: list
  # Result length must be strictly greater than this; None skips the check.
  min_count: int | None
  # Result length must equal this exactly; None skips the check.
  exact_count: int | None
  # Result length must be strictly less than this; None skips the check.
  max_count: int | None
|
||||
|
||||
# Helper function to map short names to class names
def expand_engine_lists(engine_lists):
  """Translate engine short names to class names, keeping the nested shape.

  Names not present in inference_engine_classes pass through unchanged.
  """
  return [
    [inference_engine_classes.get(name, name) for name in group]
    for group in engine_lists
  ]
|
||||
|
||||
# Table-driven scenarios for TestModelHelpers. Each case is executed twice:
# once with the short engine names below, once after expand_engine_lists maps
# them to inference-engine class names.
test_cases = [
  # A single engine list: result is every model whose repo supports mlx.
  TestCase(
    name="single_mlx_engine",
    engine_lists=[["mlx"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  TestCase(
    name="single_tinygrad_engine",
    engine_lists=[["tinygrad"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
    min_count=5,
    exact_count=None,
    max_count=10
  ),
  # Two lists where each allows either engine (OR within a list).
  TestCase(
    name="multiple_engines_or",
    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  TestCase(
    name="multiple_engines_all",
    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
    min_count=10,
    exact_count=None,
    max_count=None
  ),
  # Disjoint single-engine lists: only models supporting BOTH engines remain
  # (AND across lists) — the verifier also checks mistral-nemo is excluded.
  TestCase(
    name="distinct_engine_lists",
    engine_lists=[["mlx"], ["tinygrad"]],
    expected_models_contains=["llama-3.2-1b"],
    min_count=5,
    exact_count=None,
    max_count=10
  ),
  # No constraint: get_supported_models returns every model card.
  TestCase(
    name="no_engines",
    engine_lists=[],
    expected_models_contains=None,
    min_count=None,
    exact_count=len(model_cards),
    max_count=None
  ),
  # An engine name no repo declares: nothing matches.
  TestCase(
    name="nonexistent_engine",
    engine_lists=[["NonexistentEngine"]],
    expected_models_contains=[],
    min_count=None,
    exact_count=0,
    max_count=None
  ),
  TestCase(
    name="dummy_engine",
    engine_lists=[["dummy"]],
    expected_models_contains=["dummy"],
    min_count=None,
    exact_count=1,
    max_count=None
  ),
]
|
||||
|
||||
class TestModelHelpers(unittest.TestCase):
|
||||
def test_get_supported_models(self):
|
||||
for case in test_cases:
|
||||
with self.subTest(f"{case.name}_short_names"):
|
||||
result = get_supported_models(case.engine_lists)
|
||||
self._verify_results(case, result)
|
||||
|
||||
with self.subTest(f"{case.name}_class_names"):
|
||||
class_name_lists = expand_engine_lists(case.engine_lists)
|
||||
result = get_supported_models(class_name_lists)
|
||||
self._verify_results(case, result)
|
||||
|
||||
def _verify_results(self, case, result):
|
||||
if case.expected_models_contains:
|
||||
for model in case.expected_models_contains:
|
||||
self.assertIn(model, result)
|
||||
|
||||
if case.min_count:
|
||||
self.assertGreater(len(result), case.min_count)
|
||||
|
||||
if case.exact_count is not None:
|
||||
self.assertEqual(len(result), case.exact_count)
|
||||
|
||||
# Special case for distinct lists test
|
||||
if case.name == "distinct_engine_lists":
|
||||
self.assertLess(len(result), 10)
|
||||
self.assertNotIn("mistral-nemo", result)
|
||||
|
||||
if case.max_count:
|
||||
self.assertLess(len(result), case.max_count)
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()
|
||||
Reference in New Issue
Block a user