mirror of https://github.com/exo-explore/exo.git
simplify tinygrad non blocking
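Remove the model_lock and the _run_inference wrapper: tokenizer encoding and model calls now go straight to the single-worker executor via run_in_executor, and ensure_shard rebuilds the model without taking a lock. With max_workers=1 the executor already runs jobs one at a time, so the explicit lock was redundant.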
@@ -56,7 +56,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.model_lock = threading.Lock()
     self.executor = ThreadPoolExecutor(max_workers=1)
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -64,9 +63,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    toks = self.tokenizer.encode(prompt)
+    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    input_tensor = Tensor([toks])
 
-    h = await self._run_inference(Tensor([toks]), start_pos)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -82,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = await self._run_inference(Tensor(input_data), start_pos)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
 
     if h.shape == (1,):
       start_pos += n_captured_toks
@@ -92,19 +92,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     else:
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
-  async def _run_inference(self, input_tensor, start_pos):
-    with self.model_lock:
-      return await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
-
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
 
-    with self.model_lock:
-      if self.shard != shard:
-        self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
-        tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
-        self.tokenizer = await resolve_tokenizer(tokenizer_path)
-        self.shard = shard
+    if self.shard != shard:
+      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+
+      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+      self.tokenizer = await resolve_tokenizer(tokenizer_path)
+      self.shard = shard
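The pattern the change leans on is standard asyncio: hand blocking work to a ThreadPoolExecutor with run_in_executor and await the result, so the event loop stays responsive. With max_workers=1 the executor runs jobs strictly one at a time in submission order, which is what lets the commit drop model_lock: the single worker itself serializes access to the model. A minimal sketch of the idea, with a hypothetical slow_encode standing in for exo's real tokenizer and model calls:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

# One worker thread: jobs run sequentially in submission order, so the
# executor doubles as a lock around anything the jobs share.
executor = ThreadPoolExecutor(max_workers=1)

def slow_encode(prompt: str) -> list:
  # Hypothetical stand-in for a blocking tokenizer/model call.
  time.sleep(0.1)
  return [ord(c) for c in prompt]

async def infer(prompt: str) -> list:
  # The event loop is free to run other coroutines while the worker
  # thread executes the blocking call.
  return await asyncio.get_event_loop().run_in_executor(executor, slow_encode, prompt)

async def main():
  # Two concurrent requests interleave on the event loop, but the
  # encode calls execute back to back on the single worker thread.
  print(await asyncio.gather(infer("hi"), infer("yo")))

asyncio.run(main())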
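ensure_shard keeps a check/await/re-check shape even without the lock: the early return covers the common case, and the second `if self.shard != shard` after the download await avoids rebuilding when a concurrent request finished loading the same shard while this coroutine was suspended. A schematic of that control flow, with hypothetical download_shard and load_model helpers in place of the real exo calls:

import asyncio

# Hypothetical helpers standing in for shard download and model build:
async def download_shard(shard: str) -> str:
  await asyncio.sleep(0.1)
  return f"/models/{shard}"

async def load_model(path: str) -> object:
  await asyncio.sleep(0.1)
  return object()

class ShardedEngine:
  def __init__(self):
    self.shard = None
    self.model = None

  async def ensure_shard(self, shard: str) -> None:
    if self.shard == shard:
      return  # fast path: already loaded

    # Other coroutines can run (and load shards) during this await.
    path = await download_shard(shard)

    # Re-check: a concurrent caller may have loaded this shard while
    # we were suspended above.
    if self.shard != shard:
      self.model = await load_model(path)
      self.shard = shard

async def main():
  engine = ShardedEngine()
  await engine.ensure_shard("8b")  # downloads and builds
  await engine.ensure_shard("8b")  # returns immediately via the fast path
  print(engine.shard)

asyncio.run(main())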