1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Change model to Huggingface

This commit is contained in:
sanchit97
2020-10-25 19:51:32 -04:00
parent e04c3777b3
commit f0e27d129c
4 changed files with 7 additions and 3 deletions

View File

@@ -341,7 +341,8 @@ def parse_model_from_args(args):
num_labels,
model_path=args.model,
)
model = textattack.models.wrappers.PyTorchModelWrapper(
# Logic to change this according to model being loaded
model = textattack.models.wrappers.HuggingFaceModelWrapper(
model.model, model.tokenizer, batch_size=args.model_batch_size
)
else:

View File

@@ -40,7 +40,7 @@ class EvalModelCommand(TextAttackCommand):
def get_preds(self, model, inputs):
with torch.no_grad():
preds = model(inputs)
preds = textattack.shared.utils.batch_model_predict(model, inputs)
return preds
def test_model_on_dataset(self, args):

View File

@@ -33,7 +33,10 @@ class PyTorchModelWrapper(ModelWrapper):
def __call__(self, text_input_list):
model_device = next(self.model.parameters()).device
        ids = self.encode(text_input_list)
ids = torch.tensor(ids).to(model_device)
with torch.no_grad():

View File

@@ -11,7 +11,7 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
i = 0
while i < len(inputs):
batch = inputs[i : i + batch_size]
batch_preds = model_predict(batch)[0]
batch_preds = model_predict(batch)
# Some seq-to-seq models will return a single string as a prediction
# for a single-string list. Wrap these in a list.