mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
code-breaking typo
oops
This commit is contained in:
@@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]):
|
||||
}
|
||||
if getenv("SUPPORT_BF16", 1):
|
||||
# TODO: without casting to float16, 70B llama OOM on tinybox.
|
||||
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()
|
||||
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
# TODO: check if device supports bf16
|
||||
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
|
||||
Reference in New Issue
Block a user