mirror of https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update model training code
@@ -303,6 +303,15 @@ def parse_model_from_args(args):
         model = textattack.shared.utils.load_textattack_model_from_path(
             args.model, model_path
         )
+        # Choose the appropriate model wrapper (based on whether or not this is
+        # a HuggingFace model).
+        if isinstance(
+            model, textattack.models.helpers.BERTForClassification
+        ) or isinstance(model, textattack.models.helpers.T5ForTextToText):
+            model = textattack.models.wrappers.HuggingFaceModelWrapper(
+                model, model.tokenizer, batch_size=args.model_batch_size
+            )
+        else:
+            model = textattack.models.wrappers.PyTorchModelWrapper(
+                model, model.tokenizer, batch_size=args.model_batch_size
+            )

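For readers unfamiliar with the wrapper split above: both wrapper classes expose the same callable interface over raw strings, so downstream attack code never needs to know which backend it is talking to. A minimal sketch of that interface, using a hypothetical ToyModelWrapper and toy model/tokenizer stand-ins rather than TextAttack's real classes:

# Hypothetical sketch of the wrapper interface assumed above: a wrapper takes
# raw strings, tokenizes them, runs the underlying model, and returns one
# score row per input. Not TextAttack's actual classes.
from typing import List

class ToyModelWrapper:
    def __init__(self, model, tokenizer, batch_size=32):
        self.model = model
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def __call__(self, text_input_list: List[str]):
        ids = [self.tokenizer(t) for t in text_input_list]
        return [self.model(x) for x in ids]

# Attack code only ever sees the callable, so swapping one wrapper for the
# other is transparent to it.
wrapper = ToyModelWrapper(model=lambda ids: [0.1, 0.9], tokenizer=lambda t: t.split())
print(wrapper(["a great movie", "a terrible movie"]))
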
@@ -153,14 +153,17 @@ def _get_eval_score(model, eval_dataloader, do_regression):
     logits = []
     labels = []
     for input_ids, batch_labels in eval_dataloader:
-        if isinstance(input_ids, dict):
-            ## HACK: dataloader collates dict backwards. This is a temporary
-            # workaround to get ids in the right shape
-            input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
         batch_labels = batch_labels.to(device)

-        with torch.no_grad():
-            batch_logits = textattack.shared.utils.model_predict(model, input_ids)
+        if isinstance(input_ids, dict):
+            ## dataloader collates dict backwards. This is a workaround to get
+            # ids in the right shape for HuggingFace models
+            input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
+            with torch.no_grad():
+                batch_logits = model(**input_ids)[0]
+        else:
+            input_ids = input_ids.to(device)
+            with torch.no_grad():
+                batch_logits = model(input_ids)

         logits.extend(batch_logits.cpu().squeeze().tolist())
         labels.extend(batch_labels)

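The "collates dict backwards" comment refers to PyTorch's default collate function: when dataset items are dicts of per-token id lists, the collated batch maps each key to a list of per-position tensors, so torch.stack(v).T is what restores the usual [batch, seq_len] layout. A standalone sketch (not from the repo) that reproduces the effect:

# Standalone demo of the "dict collated backwards" behavior: default collate
# turns a batch of {"input_ids": [ids...]} items into a dict whose values are
# per-position tensors, so stacking and transposing recovers [batch, seq_len].
import torch
from torch.utils.data import DataLoader

data = [({"input_ids": [1, 2, 3]}, 0), ({"input_ids": [4, 5, 6]}, 1)]
loader = DataLoader(data, batch_size=2)

for input_ids, batch_labels in loader:
    print(input_ids["input_ids"])      # [tensor([1, 4]), tensor([2, 5]), tensor([3, 6])]
    fixed = {k: torch.stack(v).T for k, v in input_ids.items()}
    print(fixed["input_ids"])          # tensor([[1, 2, 3], [4, 5, 6]]) -- batch-major again
    print(batch_labels)                # tensor([0, 1])
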
@@ -327,10 +330,13 @@ def train_model(args):
             f"Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
         )

-    model = model_from_args(args, args.num_labels)
-    tokenizer = model.tokenizer
+    model_wrapper = model_from_args(args, args.num_labels)
+    model = model_wrapper.model
+    tokenizer = model_wrapper.tokenizer

     attackCls = attack_from_args(args)
     # We are adversarial training if the user specified an attack along with
     # the training args.
     adversarial_training = attackCls is not None

     # multi-gpu training

@@ -463,13 +469,16 @@ def train_model(args):
             input_ids, labels = batch
             labels = labels.to(device)
             if isinstance(input_ids, dict):
-                ## HACK: dataloader collates dict backwards. This is a temporary
-                # workaround to get ids in the right shape
+                ## dataloader collates dict backwards. This is a workaround to get
+                # ids in the right shape for HuggingFace models
                 input_ids = {
                     k: torch.stack(v).T.to(device) for k, v in input_ids.items()
                 }
-            logits = textattack.shared.utils.model_predict(model, input_ids)
+                logits = model(**input_ids)[0]
+            else:
+                input_ids = input_ids.to(device)
+                logits = model(input_ids)

             if args.do_regression:
                 # TODO integrate with textattack `metrics` package

@@ -535,11 +544,12 @@ def train_model(args):

     # read the saved model and report its eval performance
     logger.info("Finished training. Re-loading and evaluating model from disk.")
-    model = model_from_args(args, args.num_labels)
+    model_wrapper = model_from_args(args, args.num_labels)
+    model = model_wrapper.model
     model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
     eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
     logger.info(
-        f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
+        f"Saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
     )

     if args.save_last:

@@ -101,6 +101,8 @@ def model_from_args(train_args, num_labels, model_path=None):
         )
         if model_path:
             model.load_from_disk(model_path)
+
+        model = textattack.models.wrappers.PyTorchModelWrapper(model, model.tokenizer)
     elif train_args.model == "cnn":
         textattack.shared.logger.info(
             "Loading textattack model: WordCNNForClassification"

@@ -148,6 +148,10 @@ class GoalFunction(ABC):

         outputs = self.model(inputs)

+        assert len(inputs) == len(
+            outputs
+        ), f"Got {len(outputs)} outputs for {len(inputs)} inputs"
+
         return self._process_model_outputs(attacked_text_list, outputs)

     def _call_model(self, attacked_text_list):

@@ -14,11 +14,8 @@ class TextToTextGoalFunction(GoalFunction):
         return TextToTextGoalFunctionResult

     def _process_model_outputs(self, _, outputs):
-        """Processes and validates a list of model outputs.
-
-        Flatten list of lists to a single list.
-        """
-        return [output for batch in outputs for output in batch]
+        """Processes and validates a list of model outputs."""
+        return outputs.flatten()

     def _get_displayed_output(self, raw_output):
         return raw_output

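The switch to outputs.flatten() leans on the companion change further down, where batched predictions arrive as a single np.ndarray rather than a list of per-batch lists. A small illustration of the two flattening styles, with made-up string predictions:

# Illustration of the old vs. new flattening, using made-up predictions.
import numpy as np

batched = [["Das ist gut", "Sehr gut"], ["Nicht gut"]]
old_style = [output for batch in batched for output in batch]  # flatten list of lists

array = np.array(["Das ist gut", "Sehr gut", "Nicht gut"])     # what the new code receives
new_style = array.flatten()                                     # ndarray flatten works for strings too

print(old_style)           # ['Das ist gut', 'Sehr gut', 'Nicht gut']
print(new_style.tolist())  # ['Das ist gut', 'Sehr gut', 'Nicht gut']
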
@@ -1,10 +1,11 @@
+import torch
 import transformers

 from textattack.models.tokenizers import T5Tokenizer
 from textattack.shared import utils


-class T5ForTextToText:
+class T5ForTextToText(torch.nn.Module):
     """A T5 model trained to generate text from text.

     For more information, please see the T5 paper, "Exploring the Limits of

@@ -29,6 +30,7 @@ class T5ForTextToText:
     def __init__(
         self, mode="english_to_german", max_length=20, num_beams=1, early_stopping=True
     ):
+        super().__init__()
         self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base")
         self.model.to(utils.device)
         self.model.eval()

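Subclassing torch.nn.Module (and calling super().__init__() before any submodule is assigned) is what lets this wrapper behave like an ordinary PyTorch model: .to(), .eval(), .parameters(), and isinstance(model, torch.nn.Module) checks all work on it. A minimal sketch of the pattern with a stand-in submodule, not the actual T5 class:

# Minimal sketch of the pattern above (stand-in submodule, not the real T5 model).
import torch


class TextToTextStub(torch.nn.Module):
    def __init__(self):
        super().__init__()                   # must run before assigning submodules
        self.model = torch.nn.Linear(4, 2)   # stand-in for the wrapped seq2seq model

    def forward(self, x):
        return self.model(x)


stub = TextToTextStub().eval()
print(isinstance(stub, torch.nn.Module))          # True
print(sum(p.numel() for p in stub.parameters()))  # 10 -- submodule params are registered
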
@@ -27,6 +27,16 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
             for k, v in input_dict.items()
         }
         outputs = self.model(**input_dict)

+        if isinstance(outputs[0], str):
+            # HuggingFace sequence-to-sequence models return a list of
+            # string predictions as output. In this case, return the full
+            # list of outputs.
+            return outputs
+        else:
+            # HuggingFace classification models return a tuple as output
+            # where the first item in the tuple corresponds to the list of
+            # scores for each input.
+            return outputs[0]

         with torch.no_grad():

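The branch above separates two output conventions: sequence-to-sequence models hand back a list of decoded strings, while classification models return a tuple whose first element is the logits tensor. A toy illustration of the same dispatch, with stand-in values rather than real HuggingFace return objects:

# Toy illustration (stand-in outputs, not real HuggingFace return values) of
# the dispatch above: keep the whole list for string predictions, keep only
# the logits for classification tuples.
import torch

def normalize_outputs(outputs):
    if isinstance(outputs[0], str):
        return outputs       # seq2seq: keep the list of string predictions
    else:
        return outputs[0]    # classification: keep only the logits

seq2seq_outputs = ["Das ist ein Test", "Guten Morgen"]
classification_outputs = (torch.tensor([[0.2, 0.8], [0.9, 0.1]]),)

print(normalize_outputs(seq2seq_outputs))
print(normalize_outputs(classification_outputs))
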
@@ -1,3 +1,4 @@
+import numpy as np
 import torch

@@ -11,9 +12,20 @@ def batch_model_predict(model_predict, inputs, batch_size=32):
     while i < len(inputs):
         batch = inputs[i : i + batch_size]
         batch_preds = model_predict(batch)
-        if not isinstance(batch_preds, torch.Tensor):
-            batch_preds = torch.Tensor(batch_preds)
+
+        # Some seq-to-seq models will return a single string as a prediction
+        # for a single-string list. Wrap these in a list.
+        if isinstance(batch_preds, str):
+            batch_preds = [batch_preds]
+
+        # Get PyTorch tensors off of other devices.
+        if isinstance(batch_preds, torch.Tensor):
+            batch_preds = batch_preds.cpu()
+
+        # Cast all predictions iterables to ``np.ndarray`` types.
+        if not isinstance(batch_preds, np.ndarray):
+            batch_preds = np.array(batch_preds)
         outputs.append(batch_preds)
         i += batch_size

-    return torch.cat(outputs, dim=0)
+    return np.concatenate(outputs, axis=0)

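Putting the pieces together: every batch, whether GPU tensors of logits or a (possibly single) string prediction, is normalized to an np.ndarray before the final np.concatenate. A small usage sketch; the function body mirrors the new version above, while the two toy predict functions are made up for illustration:

# Usage sketch with toy predict functions (not from the repo) showing why the
# normalization lets np.concatenate handle both model families.
import numpy as np
import torch

def batch_model_predict(model_predict, inputs, batch_size=32):
    outputs = []
    i = 0
    while i < len(inputs):
        batch = inputs[i : i + batch_size]
        batch_preds = model_predict(batch)
        if isinstance(batch_preds, str):      # single string -> list of one
            batch_preds = [batch_preds]
        if isinstance(batch_preds, torch.Tensor):
            batch_preds = batch_preds.cpu()   # move tensors off the GPU first
        if not isinstance(batch_preds, np.ndarray):
            batch_preds = np.array(batch_preds)
        outputs.append(batch_preds)
        i += batch_size
    return np.concatenate(outputs, axis=0)

def classifier(batch):
    return torch.ones(len(batch), 2)           # fake logits, one row per input

def translator(batch):
    return [f"translated: {x}" for x in batch]  # fake string predictions

print(batch_model_predict(classifier, ["a", "b", "c"], batch_size=2).shape)  # (3, 2)
print(batch_model_predict(translator, ["a", "b", "c"], batch_size=2))
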
@@ -35,7 +35,7 @@ class WordSwapInflections(WordSwap):

     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
-        sentence = Sentence(current_text.text)
+        sentence = Sentence(" ".join(words))
         self._flair_pos_tagger.predict(sentence)
         word_list, pos_list = zip_flair_result(sentence)
