From 84187113de3c5d8171c987a03061ba1d7ca7c7a4 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Sep 2024 16:39:40 +0100 Subject: [PATCH] add a test for hf get_weight_map --- test/test_hf.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 test/test_hf.py diff --git a/test/test_hf.py b/test/test_hf.py new file mode 100644 index 00000000..0477d132 --- /dev/null +++ b/test/test_hf.py @@ -0,0 +1,26 @@ +import os +import sys + +# Add the project root to the Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + +import asyncio +from exo.download.hf.hf_helpers import get_weight_map + +async def test_get_weight_map(): + repo_ids = [ + "mlx-community/quantized-gemma-2b", + "mlx-community/Meta-Llama-3.1-8B-4bit", + "mlx-community/Meta-Llama-3.1-70B-4bit", + "mlx-community/Meta-Llama-3.1-405B-4bit", + ] + for repo_id in repo_ids: + weight_map = await get_weight_map(repo_id) + assert weight_map is not None, "Weight map should not be None" + assert isinstance(weight_map, dict), "Weight map should be a dictionary" + assert len(weight_map) > 0, "Weight map should not be empty" + print(f"OK: {repo_id}") + +if __name__ == "__main__": + asyncio.run(test_get_weight_map())