workaround f16 cast ambiguity
For reasons I don't understand, without this change, trying to run "Llama 3.2 1B" fails with the error below. Fwiw I don't know the performance impact of this change. I still can't get exo fully running, but it lets me get further (before hitting a second issue that looks like a VRAM allocation problem; a story for another day, I suppose).
error:
Failed to fetch completions: Error processing prompt (see logs with DEBUG>=2): Nvrtc Error 6, NVRTC_ERROR_COMPILATION
<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
*((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
^
<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
*((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
^
<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
*((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
^
<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
*((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
^
4 errors detected in the compilation of "<null>".
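
As a side note (not part of the commit), here is a minimal tinygrad sketch of the two cast paths involved; the tensor is a stand-in for a bfloat16 checkpoint weight, not exo's actual loading code, and whether the direct cast actually trips this NVRTC error will depend on the backend and the kernel it ends up in:

# Minimal sketch, assuming a tinygrad install with a CUDA device; not exo's loading code.
from tinygrad import Tensor, dtypes

w = Tensor.randn(4, 4).cast(dtypes.bfloat16)  # stand-in for a bf16 weight

old_path = w.cast(dtypes.float16)                       # direct bf16 -> f16 cast (the one that hit the ambiguity here); left unrealized
new_path = w.cast(dtypes.float32).cast(dtypes.float16)  # bf16 -> f32 -> f16, as in this workaround

print(new_path.numpy())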
@@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]):
     }
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+    return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
   return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
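
For anyone who wants to check what tinygrad actually generates after this change, a hedged sketch (mine, not from the commit): tinygrad prints generated kernel source at DEBUG>=4, so something like the following should show whether the bf16 value now goes through float before becoming half. The DEBUG and CUDA env vars are tinygrad's; the tensors are made up for illustration.

# Hedged sketch: inspect the generated CUDA source for the cast path.
# DEBUG=4 should make tinygrad print kernel code; CUDA=1 selects the CUDA backend.
# Set the env vars before importing tinygrad.
import os
os.environ["DEBUG"] = "4"
os.environ["CUDA"] = "1"

from tinygrad import Tensor, dtypes

w = Tensor.randn(8, 8).cast(dtypes.bfloat16).realize()  # stand-in bf16 weight buffer
w.cast(dtypes.float32).cast(dtypes.float16).realize()   # generated source is printed to stdout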