diff --git a/.circleci/config.yml b/.circleci/config.yml
index fc4b8b4b..83654c97 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -126,6 +126,7 @@ jobs:
           METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
           echo "Running tokenizer tests..."
           python3 ./test/test_tokenizers.py
+          python3 ./test/test_model_helpers.py
 
   discovery_integration_test:
     macos:
diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
index 0a1c85a5..979e2d4b 100644
--- a/exo/api/chatgpt_api.py
+++ b/exo/api/chatgpt_api.py
@@ -13,11 +13,9 @@ import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
-from exo.inference.inference_engine import inference_engine_classes
-from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
 from typing import Callable, Optional
 
 class Message:
@@ -218,18 +216,7 @@ class ChatGPTAPI:
     return web.json_response({
       "model pool": {
         model_name: pretty_name.get(model_name, model_name)
-        for model_name in [
-          model_id for model_id, model_info in model_cards.items()
-          if all(map(
-            lambda engine: engine in model_info["repo"],
-            list(dict.fromkeys([
-              inference_engine_classes.get(engine_name, None)
-              for engine_list in self.node.topology_inference_engines_pool
-              for engine_name in engine_list
-              if engine_name is not None
-            ] + [self.inference_engine_classname]))
-          ))
-        ]
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
       }
     })
 
diff --git a/exo/models.py b/exo/models.py
index 0561220f..1fb567a6 100644
--- a/exo/models.py
+++ b/exo/models.py
@@ -1,11 +1,11 @@
 from exo.inference.shard import Shard
-from typing import Optional
+from typing import Optional, List
 
 model_cards = {
   ### llama
   "llama-3.2-1b": {
     "layers": 16,
-    "repo": { 
+    "repo": {
       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
     },
@@ -124,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
   if repo is None or n_layers < 1:
     return None
   return Shard(model_id, 0, 0, n_layers)
- 
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+               for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]
diff --git a/test/test_model_helpers.py b/test/test_model_helpers.py
new file mode 100644
index 00000000..8c01104e
--- /dev/null
+++ b/test/test_model_helpers.py
@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count is not None:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count is not None:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()