simplify tinygrad non blocking

This commit is contained in:
Alex Cheema
2024-09-05 16:39:11 +01:00
parent 6342384df4
commit 4ec613d4e8

View File

@@ -56,7 +56,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.model_lock = threading.Lock()
self.executor = ThreadPoolExecutor(max_workers=1)
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -64,9 +63,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
toks = self.tokenizer.encode(prompt)
toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
input_tensor = Tensor([toks])
h = await self._run_inference(Tensor([toks]), start_pos)
h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
if h.shape == (1,):
start_pos += len(toks)
@@ -82,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
h = await self._run_inference(Tensor(input_data), start_pos)
h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
if h.shape == (1,):
start_pos += n_captured_toks
@@ -92,19 +92,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
else:
return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
async def _run_inference(self, input_tensor, start_pos):
with self.model_lock:
return await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return
model_path = await self.shard_downloader.ensure_shard(shard)
with self.model_lock:
if self.shard != shard:
self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
self.tokenizer = await resolve_tokenizer(tokenizer_path)
self.shard = shard
if self.shard != shard:
self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
self.tokenizer = await resolve_tokenizer(tokenizer_path)
self.shard = shard