fix(server): t5 cannot run in f16 (#356)

Fix #349
OlivierDehaene
2023-05-23 12:15:54 +02:00
committed by GitHub
parent 91d9beec90
commit 4f4c9c1665


@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             device = torch.device("cpu")
             dtype = torch.float32
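
Context for the change: float16's 5-bit exponent caps finite values at 65504, and T5's activations can exceed that range, producing inf/NaN; bfloat16 trades mantissa precision for float32's full 8-bit exponent, so it does not overflow the same way. A minimal standalone sketch of the same dtype-selection logic (the choose_dtype helper is illustrative, not part of the server):

import torch

def choose_dtype() -> torch.dtype:
    # Mirror the fix: on GPU, prefer bfloat16 where the hardware supports
    # it; otherwise fall back to float32. float16 is avoided because T5
    # activations can exceed its ~6.5e4 finite range.
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    return torch.float32

# Demonstrate the overflow that motivates the fix:
x = torch.tensor([70000.0])
print(x.to(torch.float16))   # inf -- above float16's max finite value (65504)
print(x.to(torch.bfloat16))  # finite -- bfloat16 shares float32's exponent range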