Minor fix for Shard typing

This commit is contained in:
Sandesh Bharadwaj
2025-01-16 14:36:46 -05:00
parent df3624d27a
commit 349b5344eb

View File

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from exo.download.shard_download import ShardDownloader
class InferenceEngine(ABC):
@@ -55,7 +56,7 @@ inference_engine_classes = {
"dummy": "DummyInferenceEngine",
}
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":