mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update model training code

This commit is contained in:
Jack Morris
2020-07-27 10:40:19 -04:00
parent 29432d88c8
commit 618e815eea
9 changed files with 74 additions and 28 deletions

View File

@@ -303,6 +303,15 @@ def parse_model_from_args(args):
        model = textattack.shared.utils.load_textattack_model_from_path(
            args.model, model_path
        )
+       # Choose the appropriate model wrapper (based on whether or not this is
+       # a HuggingFace model).
+       if isinstance(
+           model, textattack.models.helpers.BERTForClassification
+       ) or isinstance(model, textattack.models.helpers.T5ForTextToText):
+           model = textattack.models.wrappers.HuggingFaceModelWrapper(
+               model, model.tokenizer, batch_size=args.model_batch_size
+           )
+       else:
+           model = textattack.models.wrappers.PyTorchModelWrapper(
+               model, model.tokenizer, batch_size=args.model_batch_size
+           )
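
For context: both wrapper classes give attacks the same strings-in, scores-out interface, so downstream code no longer needs to know which framework the model came from. A minimal usage sketch (the checkpoint name and path are hypothetical; the load helper and wrapper classes are the ones used in the hunk above):

    import textattack

    model = textattack.shared.utils.load_textattack_model_from_path(
        "bert-base-uncased-imdb", "/path/to/checkpoint"  # hypothetical
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, model.tokenizer, batch_size=32
    )
    # Either wrapper exposes the same callable interface: a list of raw
    # strings in, an array of model scores out.
    scores = model_wrapper(["an excellent movie", "a terrible movie"])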

View File

@@ -153,14 +153,17 @@ def _get_eval_score(model, eval_dataloader, do_regression):
     logits = []
     labels = []
     for input_ids, batch_labels in eval_dataloader:
-        if isinstance(input_ids, dict):
-            ## HACK: dataloader collates dict backwards. This is a temporary
-            # workaround to get ids in the right shape
-            input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
         batch_labels = batch_labels.to(device)
+        if isinstance(input_ids, dict):
+            ## dataloader collates dict backwards. This is a workaround to get
+            # ids in the right shape for HuggingFace models
+            input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
-        with torch.no_grad():
-            batch_logits = textattack.shared.utils.model_predict(model, input_ids)
+            with torch.no_grad():
+                batch_logits = model(**input_ids)[0]
+        else:
+            input_ids = input_ids.to(device)
+            with torch.no_grad():
+                batch_logits = model(input_ids)
         logits.extend(batch_logits.cpu().squeeze().tolist())
         labels.extend(batch_labels)
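
The "collates dict backwards" comment refers to how PyTorch's default collate_fn handles dict values that are plain Python lists: it groups by token position rather than by example. A standalone sketch of the behavior on toy data (shapes hypothetical):

    import torch
    from torch.utils.data import DataLoader

    # Toy dataset: each example is (encoding_dict, label) with 4 token ids.
    data = [({"input_ids": [1, 2, 3, 4]}, 0) for _ in range(8)]
    batch, labels = next(iter(DataLoader(data, batch_size=8)))
    # Default collation yields a list of 4 tensors of shape (8,), one per
    # token position, instead of a single (8, 4) tensor per key.
    fixed = torch.stack(batch["input_ids"]).T  # shape (8, 4), as models expect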
@@ -327,10 +330,13 @@ def train_model(args):
             f"Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
         )
-    model = model_from_args(args, args.num_labels)
-    tokenizer = model.tokenizer
+    model_wrapper = model_from_args(args, args.num_labels)
+    model = model_wrapper.model
+    tokenizer = model_wrapper.tokenizer
     attackCls = attack_from_args(args)
+    # We are adversarial training if the user specified an attack along with
+    # the training args.
     adversarial_training = attackCls is not None

     # multi-gpu training
@@ -463,13 +469,16 @@ def train_model(args):
             input_ids, labels = batch
             labels = labels.to(device)
             if isinstance(input_ids, dict):
-                ## HACK: dataloader collates dict backwards. This is a temporary
-                # workaround to get ids in the right shape
+                ## dataloader collates dict backwards. This is a workaround to get
+                # ids in the right shape for HuggingFace models
                 input_ids = {
                     k: torch.stack(v).T.to(device) for k, v in input_ids.items()
                 }
+                logits = model(**input_ids)[0]
+            else:
-            logits = textattack.shared.utils.model_predict(model, input_ids)
+                input_ids = input_ids.to(device)
+                logits = model(input_ids)

             if args.do_regression:
                 # TODO integrate with textattack `metrics` package
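
The hunk cuts off inside the if args.do_regression: branch. A plausible continuation, consistent with standard fine-tuning loops, looks like the following (a sketch, not the commit's actual code; args, logits, and labels come from the surrounding loop):

    import torch

    if args.do_regression:
        # Regression: compare squeezed logits against float labels.
        loss = torch.nn.MSELoss()(logits.squeeze(), labels.float())
    else:
        # Classification: cross-entropy over per-class scores.
        loss = torch.nn.CrossEntropyLoss()(logits, labels)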
@@ -535,11 +544,12 @@ def train_model(args):
     # read the saved model and report its eval performance
     logger.info("Finished training. Re-loading and evaluating model from disk.")
-    model = model_from_args(args, args.num_labels)
+    model_wrapper = model_from_args(args, args.num_labels)
+    model = model_wrapper.model
     model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
     eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
     logger.info(
-        f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
+        f"Saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
     )

     if args.save_last:

View File

@@ -101,6 +101,8 @@ def model_from_args(train_args, num_labels, model_path=None):
         )
         if model_path:
             model.load_from_disk(model_path)
+        model = textattack.models.wrappers.PyTorchModelWrapper(model, model.tokenizer)
+
     elif train_args.model == "cnn":
         textattack.shared.logger.info(
             "Loading textattack model: WordCNNForClassification"

View File

@@ -148,6 +148,10 @@ class GoalFunction(ABC):
+        outputs = self.model(inputs)
+        assert len(inputs) == len(
+            outputs
+        ), f"Got {len(outputs)} outputs for {len(inputs)} inputs"
         return self._process_model_outputs(attacked_text_list, outputs)

     def _call_model(self, attacked_text_list):

View File

@@ -14,11 +14,8 @@ class TextToTextGoalFunction(GoalFunction):
         return TextToTextGoalFunctionResult

     def _process_model_outputs(self, _, outputs):
-        """Processes and validates a list of model outputs.
-
-        Flatten list of lists to a single list.
-        """
-        return [output for batch in outputs for output in batch]
+        """Processes and validates a list of model outputs."""
+        return outputs.flatten()

     def _get_displayed_output(self, raw_output):
         return raw_output
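
outputs.flatten() works here because batch_model_predict (changed below) now returns an np.ndarray rather than a list of lists. A toy sketch with hypothetical T5 translations:

    import numpy as np

    outputs = np.array([["Das ist gut"], ["Das ist schlecht"]])
    outputs.flatten()  # 1-D array: ['Das ist gut', 'Das ist schlecht']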

View File

@@ -1,10 +1,11 @@
+import torch
 import transformers

 from textattack.models.tokenizers import T5Tokenizer
 from textattack.shared import utils


-class T5ForTextToText:
+class T5ForTextToText(torch.nn.Module):
     """A T5 model trained to generate text from text.

     For more information, please see the T5 paper, "Exploring the Limits of
@@ -29,6 +30,7 @@ class T5ForTextToText:
     def __init__(
         self, mode="english_to_german", max_length=20, num_beams=1, early_stopping=True
     ):
+        super().__init__()
         self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base")
         self.model.to(utils.device)
         self.model.eval()
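
Subclassing torch.nn.Module, with super().__init__() called before self.model is assigned (nn.Module requires this before any submodule assignment), registers the inner T5 model as a submodule. That is what the training script above relies on; a sketch of what the subclass enables:

    import torch
    from textattack.models.helpers import T5ForTextToText

    t5 = T5ForTextToText()
    t5.eval()                       # now recurses into self.model
    params = list(t5.parameters())  # yields the underlying T5 weights
    t5.load_state_dict(torch.load("t5.pt"))  # hypothetical checkpoint path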

View File

@@ -27,6 +27,16 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
                 for k, v in input_dict.items()
             }
             outputs = self.model(**input_dict)
+            if isinstance(outputs[0], str):
+                # HuggingFace sequence-to-sequence models return a list of
+                # string predictions as output. In this case, return the full
+                # list of outputs.
+                return outputs
+            else:
+                # HuggingFace classification models return a tuple as output
+                # where the first item in the tuple corresponds to the list of
+                # scores for each input.
+                return outputs[0]

         with torch.no_grad():
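
A sketch of the two output shapes the new branch distinguishes (values hypothetical):

    import torch

    # Classification head: a tuple whose first element is the score tensor.
    outputs = (torch.randn(2, 3),)
    assert not isinstance(outputs[0], str)
    scores = outputs[0]                 # shape (batch, num_classes)

    # Sequence-to-sequence model: a list of decoded strings.
    outputs = ["Das ist gut", "Das ist schlecht"]
    assert isinstance(outputs[0], str)  # return the whole list unchanged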

View File

@@ -1,3 +1,4 @@
+import numpy as np
 import torch
@@ -11,9 +12,20 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
     while i < len(inputs):
         batch = inputs[i : i + batch_size]
         batch_preds = model_predict(batch)
-        if not isinstance(batch_preds, torch.Tensor):
-            batch_preds = torch.Tensor(batch_preds)
+
+        # Some seq-to-seq models will return a single string as a prediction
+        # for a single-string list. Wrap these in a list.
+        if isinstance(batch_preds, str):
+            batch_preds = [batch_preds]
+
+        # Get PyTorch tensors off of other devices.
+        if isinstance(batch_preds, torch.Tensor):
+            batch_preds = batch_preds.cpu()
+
+        # Cast all prediction iterables to ``np.ndarray`` types.
+        if not isinstance(batch_preds, np.ndarray):
+            batch_preds = np.array(batch_preds)
+
         outputs.append(batch_preds)
         i += batch_size
-    return torch.cat(outputs, dim=0)
+    return np.concatenate(outputs, axis=0)
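
With the np.ndarray contract, tensor and string predictions now concatenate the same way. A toy sketch (the two predict functions are hypothetical; batch_model_predict is the function above):

    import numpy as np

    def fake_classifier(batch):
        return np.zeros((len(batch), 2))   # scores per input

    def fake_translator(batch):
        return ["übersetzt"] * len(batch)  # one string per input

    batch_model_predict(fake_classifier, list(range(70))).shape  # (70, 2)
    batch_model_predict(fake_translator, list(range(70))).shape  # (70,)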

View File

@@ -35,7 +35,7 @@ class WordSwapInflections(WordSwap):
     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
-        sentence = Sentence(current_text.text)
+        sentence = Sentence(" ".join(words))
         self._flair_pos_tagger.predict(sentence)
         word_list, pos_list = zip_flair_result(sentence)
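
The change matters because current_text.words strips punctuation while flair tokenizes the raw text, so tagging current_text.text can yield tokens that do not line up one-to-one with TextAttack's word list, misaligning POS tags with indices_to_modify. Sketch (toy input):

    from flair.data import Sentence

    words = ["Hello", "world"]     # roughly what current_text.words gives
    Sentence("Hello, world!")      # flair tokens keep punctuation ("Hello," ...)
    Sentence(" ".join(words))      # one flair token per TextAttack word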