mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
fix modelpool, add tests in test/test_model_helpers.py
This commit is contained in:
@@ -126,6 +126,7 @@ jobs:
|
|||||||
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
|
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
|
||||||
echo "Running tokenizer tests..."
|
echo "Running tokenizer tests..."
|
||||||
python3 ./test/test_tokenizers.py
|
python3 ./test/test_tokenizers.py
|
||||||
|
python3 ./test/test_model_helpers.py
|
||||||
|
|
||||||
discovery_integration_test:
|
discovery_integration_test:
|
||||||
macos:
|
macos:
|
||||||
|
|||||||
@@ -13,11 +13,9 @@ import sys
|
|||||||
from exo import DEBUG, VERSION
|
from exo import DEBUG, VERSION
|
||||||
from exo.download.download_progress import RepoProgressEvent
|
from exo.download.download_progress import RepoProgressEvent
|
||||||
from exo.helpers import PrefixDict, shutdown
|
from exo.helpers import PrefixDict, shutdown
|
||||||
from exo.inference.inference_engine import inference_engine_classes
|
|
||||||
from exo.inference.shard import Shard
|
|
||||||
from exo.inference.tokenizers import resolve_tokenizer
|
from exo.inference.tokenizers import resolve_tokenizer
|
||||||
from exo.orchestration import Node
|
from exo.orchestration import Node
|
||||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
@@ -218,18 +216,7 @@ class ChatGPTAPI:
|
|||||||
return web.json_response({
|
return web.json_response({
|
||||||
"model pool": {
|
"model pool": {
|
||||||
model_name: pretty_name.get(model_name, model_name)
|
model_name: pretty_name.get(model_name, model_name)
|
||||||
for model_name in [
|
for model_name in get_supported_models(self.node.topology_inference_engines_pool)
|
||||||
model_id for model_id, model_info in model_cards.items()
|
|
||||||
if all(map(
|
|
||||||
lambda engine: engine in model_info["repo"],
|
|
||||||
list(dict.fromkeys([
|
|
||||||
inference_engine_classes.get(engine_name, None)
|
|
||||||
for engine_list in self.node.topology_inference_engines_pool
|
|
||||||
for engine_name in engine_list
|
|
||||||
if engine_name is not None
|
|
||||||
] + [self.inference_engine_classname]))
|
|
||||||
))
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from exo.inference.shard import Shard
|
from exo.inference.shard import Shard
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
model_cards = {
|
model_cards = {
|
||||||
### llama
|
### llama
|
||||||
"llama-3.2-1b": {
|
"llama-3.2-1b": {
|
||||||
"layers": 16,
|
"layers": 16,
|
||||||
"repo": {
|
"repo": {
|
||||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
|
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
|
||||||
},
|
},
|
||||||
@@ -124,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
|
|||||||
if repo is None or n_layers < 1:
|
if repo is None or n_layers < 1:
|
||||||
return None
|
return None
|
||||||
return Shard(model_id, 0, 0, n_layers)
|
return Shard(model_id, 0, 0, n_layers)
|
||||||
|
|
||||||
|
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
  """Return the ids of every model in model_cards that all nodes can serve.

  Each inner list holds the inference-engine names available on one node;
  names may be short aliases (e.g. "mlx") or full class names. A model is
  supported when every node has at least one engine with a repo entry for it.

  Args:
    supported_inference_engine_lists: one list of engine names per node.

  Returns:
    List of supported model ids; all model ids when the input is empty.
  """
  if not supported_inference_engine_lists:
    return list(model_cards.keys())

  # Imported locally — presumably to avoid a circular import with
  # exo.inference (TODO confirm); kept function-scoped for that reason.
  from exo.inference.inference_engine import inference_engine_classes

  # Normalize short engine names to class names; unknown names pass through
  # unchanged. dict.get replaces the original membership-test ternary, and a
  # new name avoids rebinding the parameter.
  normalized_engine_lists = [
    [inference_engine_classes.get(engine, engine) for engine in engine_list]
    for engine_list in supported_inference_engine_lists
  ]

  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
    # One node can serve the model if any of its engines has a repo entry.
    return any(engine in model_info.get("repo", {}) for engine in engine_list)

  def supports_all_engine_lists(model_info: dict) -> bool:
    # Every node must be able to serve the model.
    return all(has_any_engine(model_info, engine_list) for engine_list in normalized_engine_lists)

  return [
    model_id for model_id, model_info in model_cards.items()
    if supports_all_engine_lists(model_info)
  ]
|
||||||
|
|||||||
121
test/test_model_helpers.py
Normal file
121
test/test_model_helpers.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import unittest
|
||||||
|
from exo.models import get_supported_models, model_cards
|
||||||
|
from exo.inference.inference_engine import inference_engine_classes
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
class TestCase(NamedTuple):
  """One scenario for get_supported_models: input engine lists plus expectations."""
  name: str  # subtest label
  engine_lists: list  # Will contain short names, will be mapped to class names
  expected_models_contains: list  # model ids that must appear in the result (falsy to skip)
  min_count: int | None  # lower bound: assert len(result) > min_count (skipped when falsy)
  exact_count: int | None  # assert len(result) == exact_count when not None
  max_count: int | None  # upper bound: assert len(result) < max_count (skipped when falsy)
|
||||||
|
|
||||||
|
# Helper function to map short names to class names
def expand_engine_lists(engine_lists):
  """Translate short engine names (e.g. "mlx") into engine class names.

  Names with no entry in inference_engine_classes pass through unchanged.
  """
  return [
    [inference_engine_classes.get(engine, engine) for engine in sublist]
    for sublist in engine_lists
  ]
|
||||||
|
|
||||||
|
# Scenario table consumed by TestModelHelpers. Each row is run twice: once
# with the short engine names as written, once expanded to class names.
# Columns: name, engine_lists, expected_models_contains,
#          min_count, exact_count, max_count
test_cases = [
  TestCase("single_mlx_engine", [["mlx"]],
           ["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"], 10, None, None),
  TestCase("single_tinygrad_engine", [["tinygrad"]],
           ["llama-3.2-1b", "llama-3.2-3b"], 5, None, 10),
  TestCase("multiple_engines_or", [["mlx", "tinygrad"], ["mlx"]],
           ["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"], 10, None, None),
  TestCase("multiple_engines_all", [["mlx", "tinygrad"], ["mlx", "tinygrad"]],
           ["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"], 10, None, None),
  TestCase("distinct_engine_lists", [["mlx"], ["tinygrad"]],
           ["llama-3.2-1b"], 5, None, 10),
  TestCase("no_engines", [], None, None, len(model_cards), None),
  TestCase("nonexistent_engine", [["NonexistentEngine"]], [], None, 0, None),
  TestCase("dummy_engine", [["dummy"]], ["dummy"], None, 1, None),
]
|
||||||
|
|
||||||
|
class TestModelHelpers(unittest.TestCase):
|
||||||
|
def test_get_supported_models(self):
|
||||||
|
for case in test_cases:
|
||||||
|
with self.subTest(f"{case.name}_short_names"):
|
||||||
|
result = get_supported_models(case.engine_lists)
|
||||||
|
self._verify_results(case, result)
|
||||||
|
|
||||||
|
with self.subTest(f"{case.name}_class_names"):
|
||||||
|
class_name_lists = expand_engine_lists(case.engine_lists)
|
||||||
|
result = get_supported_models(class_name_lists)
|
||||||
|
self._verify_results(case, result)
|
||||||
|
|
||||||
|
def _verify_results(self, case, result):
|
||||||
|
if case.expected_models_contains:
|
||||||
|
for model in case.expected_models_contains:
|
||||||
|
self.assertIn(model, result)
|
||||||
|
|
||||||
|
if case.min_count:
|
||||||
|
self.assertGreater(len(result), case.min_count)
|
||||||
|
|
||||||
|
if case.exact_count is not None:
|
||||||
|
self.assertEqual(len(result), case.exact_count)
|
||||||
|
|
||||||
|
# Special case for distinct lists test
|
||||||
|
if case.name == "distinct_engine_lists":
|
||||||
|
self.assertLess(len(result), 10)
|
||||||
|
self.assertNotIn("mistral-nemo", result)
|
||||||
|
|
||||||
|
if case.max_count:
|
||||||
|
self.assertLess(len(result), case.max_count)
|
||||||
|
|
||||||
|
# Allow running this test module directly: python3 ./test/test_model_helpers.py
if __name__ == "__main__":
  unittest.main()
|
||||||
Reference in New Issue
Block a user