mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge branch 'main' into package-exo-app
This commit is contained in:
@@ -425,7 +425,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
elif shard.is_last_layer():
|
||||
shard_specific_patterns.add(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = set("*.safetensors")
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
from tinygrad import Tensor, nn, Context
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
|
||||
logits = x[:, -1, :]
|
||||
def sample_wrapper():
|
||||
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
|
||||
out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
|
||||
return out.numpy().astype(int)
|
||||
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
|
||||
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
|
||||
return np.array(tokens)
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
|
||||
|
||||
async def decode(self, shard: Shard, tokens) -> str:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
||||
return tokens
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
|
||||
return output_data.numpy()
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard:
|
||||
|
||||
@@ -208,7 +208,7 @@ async def main():
|
||||
{"❌ No read access" if not has_read else ""}
|
||||
{"❌ No write access" if not has_write else ""}
|
||||
""")
|
||||
|
||||
|
||||
if not args.models_seed_dir is None:
|
||||
try:
|
||||
await move_models_to_hf(args.models_seed_dir)
|
||||
|
||||
7
setup.py
7
setup.py
@@ -8,8 +8,8 @@ install_requires = [
|
||||
"aiohttp==3.10.11",
|
||||
"aiohttp_cors==0.7.0",
|
||||
"aiofiles==24.1.0",
|
||||
"grpcio==1.64.1",
|
||||
"grpcio-tools==1.64.1",
|
||||
"grpcio==1.68.0",
|
||||
"grpcio-tools==1.68.0",
|
||||
"Jinja2==3.1.4",
|
||||
"netifaces==0.11.0",
|
||||
"numpy==2.0.0",
|
||||
@@ -22,10 +22,9 @@ install_requires = [
|
||||
"pydantic==2.9.2",
|
||||
"requests==2.32.3",
|
||||
"rich==13.7.1",
|
||||
"safetensors==0.4.3",
|
||||
"tenacity==9.0.0",
|
||||
"tqdm==4.66.4",
|
||||
"transformers==4.43.3",
|
||||
"transformers==4.46.3",
|
||||
"uuid==1.30",
|
||||
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user