Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
Change model to Huggingface
@@ -341,7 +341,8 @@ def parse_model_from_args(args):
             num_labels,
             model_path=args.model,
         )
-        model = textattack.models.wrappers.PyTorchModelWrapper(
+        # Logic to change this according to model being loaded
+        model = textattack.models.wrappers.HuggingFaceModelWrapper(
             model.model, model.tokenizer, batch_size=args.model_batch_size
         )
     else:
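For reference, a minimal sketch of how the new wrapper would be built on its own, assuming a standard transformers sequence-classification checkpoint. The checkpoint name below is only illustrative, and the batch_size keyword is included solely because the call site in the hunk above passes one.

    import textattack
    import transformers

    # Illustrative checkpoint; any sequence-classification model/tokenizer pair works.
    checkpoint = "textattack/bert-base-uncased-imdb"
    hf_model = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint)
    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    # HuggingFaceModelWrapper lets TextAttack call the model on raw strings.
    # batch_size is passed here only because the hunk above does so; other
    # versions of the wrapper may not accept it.
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        hf_model, hf_tokenizer, batch_size=32
    )

    logits = model_wrapper(["this movie was surprisingly good"])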
@@ -40,7 +40,7 @@ class EvalModelCommand(TextAttackCommand):
 
     def get_preds(self, model, inputs):
         with torch.no_grad():
-            preds = model(inputs)
+            preds = textattack.shared.utils.batch_model_predict(model, inputs)
         return preds
 
     def test_model_on_dataset(self, args):
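Continuing the sketch above, the eval command now hands the raw input strings to the shared batching helper instead of calling the model on the whole list at once. The sample inputs and the expected output shape below are assumptions.

    import torch
    import textattack

    # model_wrapper is the wrapper built in the previous sketch.
    sample_inputs = ["an excellent movie", "a dull, lifeless movie"]
    with torch.no_grad():
        preds = textattack.shared.utils.batch_model_predict(model_wrapper, sample_inputs)
    # For a binary sentiment classifier this is expected to be a (2, 2) tensor
    # of logits, one row per input string.
    print(preds.shape)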
@@ -33,7 +33,10 @@ class PyTorchModelWrapper(ModelWrapper):
 
     def __call__(self, text_input_list):
         model_device = next(self.model.parameters()).device
+<<<<<<< HEAD
         ids = self.encode(text_input_list)
+=======
+>>>>>>> Change model to Huggingface
         ids = torch.tensor(ids).to(model_device)
 
         with torch.no_grad():
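The hunk above checks in unresolved merge-conflict markers. For clarity, a sketch of the method with the conflict resolved in favor of the HEAD side, assuming self.encode and a logits-returning model as suggested by the surrounding context lines:

    import torch  # imported at module level in the real file

    # Method sketch; lives on PyTorchModelWrapper.
    def __call__(self, text_input_list):
        model_device = next(self.model.parameters()).device
        # HEAD side of the conflict kept: encode the raw strings to id sequences.
        ids = self.encode(text_input_list)
        ids = torch.tensor(ids).to(model_device)

        with torch.no_grad():
            # Assumption: the wrapped model returns logits for the id tensor.
            outputs = self.model(ids)

        return outputs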
@@ -11,7 +11,7 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
     i = 0
     while i < len(inputs):
         batch = inputs[i : i + batch_size]
-        batch_preds = model_predict(batch)[0]
+        batch_preds = model_predict(batch)
 
         # Some seq-to-seq models will return a single string as a prediction
         # for a single-string list. Wrap these in a list.
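Finally, a condensed sketch of the batching helper being edited, assuming each call to model_predict returns a torch tensor of logits for its batch. The accumulation and concatenation go beyond the context lines shown, and the seq-to-seq string case flagged in the trailing comment is omitted.

    import torch

    def batch_model_predict(model_predict, inputs, batch_size=32):
        # Run the model over `inputs` in fixed-size chunks and stitch the
        # per-batch predictions back together.
        outputs = []
        i = 0
        while i < len(inputs):
            batch = inputs[i : i + batch_size]
            batch_preds = model_predict(batch)  # no longer indexed with [0]
            outputs.append(batch_preds)
            i += batch_size
        return torch.cat(outputs, dim=0)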