mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
add a test for hf get_weight_map
This commit is contained in:
26
test/test_hf.py
Normal file
26
test/test_hf.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the project root to the Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import asyncio
|
||||
from exo.download.hf.hf_helpers import get_weight_map
|
||||
|
||||
async def test_get_weight_map():
|
||||
repo_ids = [
|
||||
"mlx-community/quantized-gemma-2b",
|
||||
"mlx-community/Meta-Llama-3.1-8B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-70B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-405B-4bit",
|
||||
]
|
||||
for repo_id in repo_ids:
|
||||
weight_map = await get_weight_map(repo_id)
|
||||
assert weight_map is not None, "Weight map should not be None"
|
||||
assert isinstance(weight_map, dict), "Weight map should be a dictionary"
|
||||
assert len(weight_map) > 0, "Weight map should not be empty"
|
||||
print(f"OK: {repo_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_get_weight_map())
|
||||
Reference in New Issue
Block a user