mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
update tinygrad version
This commit is contained in:
@@ -40,8 +40,7 @@ MODEL_PARAMS = {
|
||||
def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
|
||||
# build model
|
||||
linear = nn.Linear
|
||||
with Context(THREEFRY=0):
|
||||
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
|
||||
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
|
||||
|
||||
# load weights
|
||||
if model_path.is_dir():
|
||||
|
||||
@@ -225,9 +225,9 @@ class Transformer:
|
||||
h = inputs
|
||||
return h
|
||||
|
||||
def forward(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|
||||
if x.shape[0:2] == (1, 1) and self.forward_jit is not None:
|
||||
return self.forward_jit(x, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
|
||||
def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
|
||||
if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
|
||||
return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
|
||||
return self.forward_base(x, start_pos, cache=cache)
|
||||
|
||||
def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|
||||
|
||||
2
setup.py
2
setup.py
@@ -26,7 +26,7 @@ install_requires = [
|
||||
"tqdm==4.66.4",
|
||||
"transformers==4.46.3",
|
||||
"uuid==1.30",
|
||||
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
|
||||
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
|
||||
Reference in New Issue
Block a user