mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2023-08-15 01:09:35 +03:00
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
Reference in New Issue
Block a user