update tinygrad version

This commit is contained in:
Rory Clear
2024-11-22 21:12:25 +00:00
parent f38cd55565
commit 3384fc7294
3 changed files with 5 additions and 6 deletions

View File

@@ -40,8 +40,7 @@ MODEL_PARAMS = {
def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
# build model
linear = nn.Linear
-with Context(THREEFRY=0):
-model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
# load weights
if model_path.is_dir():

View File

@@ -225,9 +225,9 @@ class Transformer:
h = inputs
return h
-def forward(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
-if x.shape[0:2] == (1, 1) and self.forward_jit is not None:
-return self.forward_jit(x, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
return self.forward_base(x, start_pos, cache=cache)
def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):

View File

@@ -26,7 +26,7 @@ install_requires = [
"tqdm==4.66.4",
"transformers==4.46.3",
"uuid==1.30",
-"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
+"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
]
extras_require = {