mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
generate text in batch
This commit is contained in:
@@ -76,7 +76,6 @@ class GoalFunction:
|
||||
model_device = next(self.model.model.parameters()).device
|
||||
else:
|
||||
model_device = next(self.model.parameters()).device
|
||||
print('ids:', ids)
|
||||
ids = torch.tensor(ids).to(model_device)
|
||||
#
|
||||
# shape of `ids` is (n, m, d)
|
||||
|
||||
Reference in New Issue
Block a user