diff --git a/exo/inference/tinygrad/models/llama.py b/exo/inference/tinygrad/models/llama.py index d0cb1050..bc99fbfd 100644 --- a/exo/inference/tinygrad/models/llama.py +++ b/exo/inference/tinygrad/models/llama.py @@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]): } if getenv("SUPPORT_BF16", 1): # TODO: without casting to float16, 70B llama OOM on tinybox. - return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()} + return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()} # TODO: check if device supports bf16 return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}