mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00

Commit: merge and format

Makefile
@@ -1,12 +1,13 @@
+PEP_IGNORE_ERRORS="C901 E501 W503 E203 E231 E266 F403"
+
 format: FORCE ## Run black and isort (rewriting files)
 	black .
 	isort --atomic tests textattack

 lint: FORCE ## Run black, isort, flake8 (in check mode)
 	black . --check
 	isort --check-only tests textattack
-	flake8 . --count --ignore=C901,E501,W503,E203,E231,E266,F403 --show-source --statistics --exclude=./.*,build,dist
+	flake8 . --count --ignore=$(PEP_IGNORE_ERRORS) --show-source --statistics --exclude=./.*,build,dist

 test: FORCE ## Run tests using pytest
@@ -1,6 +1,6 @@
 bert-score
 editdistance
-flair>=0.5
+flair>=0.5.1
 filelock
 language_tool_python
 lru-dict
@@ -3,7 +3,7 @@ import random
 import tqdm

 from textattack.constraints import PreTransformationConstraint
-from textattack.shared import AttackedText
+from textattack.shared import AttackedText, utils


 class Augmenter:
@@ -70,7 +70,9 @@ class Augmenter:
         attacked_text = AttackedText(text)
         original_text = attacked_text
         all_transformed_texts = set()
-        num_words_to_swap = int(self.pct_words_to_swap * len(attacked_text.words))
+        num_words_to_swap = max(
+            int(self.pct_words_to_swap * len(attacked_text.words)), 1
+        )
         for _ in range(self.transformations_per_example):
             index_order = list(range(len(attacked_text.words)))
             random.shuffle(index_order)
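Why the new max(..., 1) floor matters: for short inputs the old int() truncation could round the swap budget down to zero, making augmentation a silent no-op. A minimal check in plain Python:

    pct_words_to_swap = 0.1
    words = "the movie was good".split()       # 4 words
    old = int(pct_words_to_swap * len(words))  # 0 -> nothing was ever swapped
    new = max(int(pct_words_to_swap * len(words)), 1)
    assert (old, new) == (0, 1)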
@@ -132,3 +134,22 @@ class Augmenter:
             all_text_list.extend([text] + augmented_texts)
             all_id_list.extend([_id] * (1 + len(augmented_texts)))
         return all_text_list, all_id_list
+
+    def __repr__(self):
+        main_str = "Augmenter" + "("
+        lines = []
+        # self.transformation
+        lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
+        # self.constraints
+        constraints_lines = []
+        constraints = self.constraints + self.pre_transformation_constraints
+        if len(constraints):
+            for i, constraint in enumerate(constraints):
+                constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
+            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
+        else:
+            constraints_str = "None"
+        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
+        main_str += "\n " + "\n ".join(lines) + "\n"
+        main_str += ")"
+        return main_str
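What the new __repr__ buys: printing an Augmenter now lists its transformation and constraint stack instead of a bare object address. Output sketched by hand for a WordNet-based augmenter with no constraints:

    from textattack.augmentation import WordNetAugmenter
    print(WordNetAugmenter())
    # Augmenter(
    #   (transformation): WordSwapWordNet
    #   (constraints): None
    # )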
@@ -61,6 +61,9 @@ class EasyDataAugmenter(Augmenter):
         random.shuffle(augmented_text)
         return augmented_text[: self.transformations_per_example]

+    def __repr__(self):
+        return "EasyDataAugmenter"
+

 class SwapAugmenter(Augmenter):
     def __init__(self, **kwargs):
@@ -308,6 +308,7 @@ BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
     "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
     "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
     "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
+    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
 }

 WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
@@ -353,6 +354,7 @@ SEARCH_METHOD_CLASS_NAMES = {
     "greedy": "textattack.search_methods.GreedySearch",
     "ga-word": "textattack.search_methods.GeneticAlgorithm",
     "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
+    "pso": "textattack.search_methods.ParticleSwarmOptimization",
 }

 GOAL_FUNCTION_CLASS_NAMES = {
@@ -360,9 +360,13 @@ def parse_dataset_from_args(args):
         )
         model_train_args = json.loads(open(model_args_json_path).read())
         try:
+            if ":" in model_train_args["dataset"]:
+                name, subset = model_train_args["dataset"].split(":")
+            else:
+                name, subset = model_train_args["dataset"], None
             args.dataset_from_nlp = (
-                model_train_args["dataset"],
-                None,
+                name,
+                subset,
                 model_train_args["dataset_dev_split"],
             )
         except KeyError:
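The added branch parses the dataset spec before building the tuple. In isolation (the helper name is invented for illustration; the `glue:sst2` convention comes from the --dataset help text later in this commit):

    def split_dataset_spec(spec):
        """'glue:sst2' -> ('glue', 'sst2'); 'rotten_tomatoes' -> ('rotten_tomatoes', None)."""
        if ":" in spec:
            name, subset = spec.split(":")
        else:
            name, subset = spec, None
        return name, subset

    assert split_dataset_spec("glue:sst2") == ("glue", "sst2")
    assert split_dataset_spec("rotten_tomatoes") == ("rotten_tomatoes", None)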
@@ -94,9 +94,18 @@ def run(args, checkpoint=None):
     in_queue = torch.multiprocessing.Queue()
     out_queue = torch.multiprocessing.Queue()
     # Add stuff to queue.
+    missing_datapoints = set()
     for i in worklist:
-        text, output = dataset[i]
-        in_queue.put((i, text, output))
+        try:
+            text, output = dataset[i]
+            in_queue.put((i, text, output))
+        except IndexError:
+            missing_datapoints.add(i)
+
+    # if our dataset is shorter than the number of samples chosen, remove the
+    # out-of-bounds indices from the dataset
+    for i in missing_datapoints:
+        worklist.remove(i)

     # Start workers.
     # pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
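The same pattern, reduced to a toy so the bounds handling is easy to see (assumes only a list-like dataset that raises IndexError past its end):

    import queue

    dataset = [("text a", 0), ("text b", 1)]
    worklist = [0, 1, 2, 3]  # indices 2 and 3 are out of bounds
    in_queue, missing_datapoints = queue.Queue(), set()

    for i in worklist:
        try:
            text, output = dataset[i]
            in_queue.put((i, text, output))
        except IndexError:
            missing_datapoints.add(i)

    for i in missing_datapoints:  # shrink the worklist to what actually exists
        worklist.remove(i)
    assert worklist == [0, 1]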
@@ -146,7 +155,7 @@ def run(args, checkpoint=None):
             in_queue.put((worklist_tail, text, output))
         except IndexError:
             raise IndexError(
-                "Out of bounds access of dataset. Size of data is {} but tried to access index {}".format(
+                "Tried adding to worklist, but ran out of datapoints. Size of data is {} but tried to access index {}".format(
                     len(dataset), worklist_tail
                 )
             )
@@ -1,18 +1,19 @@
 import json
 import logging
+import math
 import os
 import time

 import numpy as np
 import scipy
 import torch
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, RandomSampler
 import tqdm
 import transformers

 import textattack

 from .train_args_helpers import (
     attack_from_args,
+    augmenter_from_args,
     dataset_from_args,
     model_from_args,
@@ -23,9 +24,178 @@ device = textattack.shared.utils.device
 logger = textattack.shared.logger


-def make_directories(output_dir):
+def _save_args(args, save_path):
+    """
+    Dump args dictionary to a json.
+
+    :param args: Dictionary of arguments to save.
+    :param save_path: Path to json file to write args to.
+    """
+    final_args_dict = {k: v for k, v in vars(args).items() if _is_writable_type(v)}
+    with open(save_path, "w", encoding="utf-8") as f:
+        f.write(json.dumps(final_args_dict, indent=2) + "\n")
+def _get_sample_count(*lsts):
+    """
+    Get sample count of a dataset.
+
+    :param *lsts: variable number of lists.
+    :return: sample count of this dataset, if all lists match, else None.
+    """
+    if all(len(lst) == len(lsts[0]) for lst in lsts):
+        sample_count = len(lsts[0])
+    else:
+        sample_count = None
+    return sample_count
+
+
+def _random_shuffle(*lsts):
+    """
+    Randomly shuffle a dataset. Applies the same permutation to each list
+    (to preserve mapping between inputs and targets).
+
+    :param *lsts: variable number of lists to shuffle.
+    :return: shuffled lsts.
+    """
+    permutation = np.random.permutation(len(lsts[0]))
+    shuffled = []
+    for lst in lsts:
+        shuffled.append((np.array(lst)[permutation]).tolist())
+    return tuple(shuffled)
+
+
+def _train_val_split(*lsts, split_val=0.2):
+    """
+    Split dataset into training and validation sets.
+
+    :param *lsts: variable number of lists that make up a dataset (e.g. text, labels).
+    :param split_val: float in [0., 1.). Fraction of the dataset to reserve for evaluation.
+    :return: (train split of list for list in lsts), (val split of list for list in lsts)
+    """
+    sample_count = _get_sample_count(*lsts)
+    if not sample_count:
+        raise Exception(
+            "Batch axis inconsistent. All input arrays must have first axis of equal length."
+        )
+    lsts = _random_shuffle(*lsts)
+    split_idx = math.floor(sample_count * split_val)
+    train_set = [lst[split_idx:] for lst in lsts]
+    val_set = [lst[:split_idx] for lst in lsts]
+    if len(train_set) == 1 and len(val_set) == 1:
+        train_set = train_set[0]
+        val_set = val_set[0]
+    return train_set, val_set
+
+
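A minimal check of the two helpers above, runnable once they are defined: the shared permutation keeps text/label pairs aligned, and split_val=0.2 reserves 20% of the data (after shuffling) for evaluation.

    text = ["a", "b", "c", "d", "e"]
    labels = [0, 1, 0, 1, 0]
    (train_text, train_labels), (val_text, val_labels) = _train_val_split(
        text, labels, split_val=0.2
    )
    assert len(train_text) == 4 and len(val_text) == 1
    for t, l in zip(train_text, train_labels):  # pairs stay aligned
        assert labels[text.index(t)] == l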
+def _filter_labels(text, labels, allowed_labels):
+    """
+    Keep examples with approved labels.
+
+    :param text: list of text inputs.
+    :param labels: list of corresponding labels.
+    :param allowed_labels: list of approved label values.
+
+    :return: (final_text, final_labels). Filtered version of text and labels.
+    """
+    final_text, final_labels = [], []
+    for text, label in zip(text, labels):
+        if label in allowed_labels:
+            final_text.append(text)
+            final_labels.append(label)
+    return final_text, final_labels
+
+
+def _save_model_checkpoint(model, output_dir, global_step):
+    """
+    Save model checkpoint to disk.
+
+    :param model: Model to save (pytorch).
+    :param output_dir: Path to model save dir.
+    :param global_step: Current global training step #. Used in ckpt filename.
+    """
+    # Save model checkpoint
+    output_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step))
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    # Take care of distributed/parallel training
+    model_to_save = model.module if hasattr(model, "module") else model
+    model_to_save.save_pretrained(output_dir)
+
+
+def _save_model(model, output_dir, weights_name, config_name):
+    """
+    Save model to disk.
+
+    :param model: Model to save (pytorch).
+    :param output_dir: Path to model save dir.
+    :param weights_name: filename for model parameters.
+    :param config_name: filename for config.
+    """
+    model_to_save = model.module if hasattr(model, "module") else model
+
+    # If we save using the predefined names, we can load using `from_pretrained`
+    output_model_file = os.path.join(output_dir, weights_name)
+    output_config_file = os.path.join(output_dir, config_name)
+
+    torch.save(model_to_save.state_dict(), output_model_file)
+    try:
+        model_to_save.config.to_json_file(output_config_file)
+    except AttributeError:
+        # no config
+        pass
+
+
+def _get_eval_score(model, eval_dataloader, do_regression):
+    """
+    Measure performance of a model on the evaluation set.
+
+    :param model: Model to test.
+    :param eval_dataloader: a torch DataLoader that iterates through the eval set.
+    :param do_regression: bool. Whether we are doing regression (True) or classification (False).
+
+    :return: pearson correlation, if do_regression==True, else classification accuracy [0., 1.]
+    """
+    model.eval()
+    correct = 0
+    logits = []
+    labels = []
+    for input_ids, batch_labels in eval_dataloader:
+        if isinstance(input_ids, dict):
+            ## HACK: dataloader collates dict backwards. This is a temporary
+            # workaround to get ids in the right shape
+            input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
+        batch_labels = batch_labels.to(device)
+
+        with torch.no_grad():
+            batch_logits = textattack.shared.utils.model_predict(model, input_ids)
+
+        logits.extend(batch_logits.cpu().squeeze().tolist())
+        labels.extend(batch_labels)
+
+    model.train()
+    logits = torch.tensor(logits)
+    labels = torch.tensor(labels)
+
+    if do_regression:
+        pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
+        return pearson_correlation
+    else:
+        preds = logits.argmax(dim=1)
+        correct = (preds == labels).sum()
+        return float(correct) / len(labels)
+
+
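The "collates dict backwards" workaround above deserves a concrete look: for dict-valued examples, PyTorch's default collate_fn returns, per key, a list of per-position tensors of shape [batch], so stacking yields [seq_len, batch] and the transpose restores [batch, seq_len].

    import torch
    from torch.utils.data import DataLoader

    data = [({"input_ids": [1, 2, 3]}, 0), ({"input_ids": [4, 5, 6]}, 1)]
    batch, labels = next(iter(DataLoader(data, batch_size=2)))
    fixed = {k: torch.stack(v).T for k, v in batch.items()}
    print(fixed["input_ids"].shape)  # torch.Size([2, 3]) -> [batch, seq_len]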
+def _make_directories(output_dir):
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
+
+
+def _is_writable_type(obj):
+    for ok_type in [bool, int, str, float]:
+        if isinstance(obj, ok_type):
+            return True
+    return False


 def batch_encode(tokenizer, text_list):
@@ -35,12 +205,70 @@ def batch_encode(tokenizer, text_list):
         return [tokenizer.encode(text_input) for text_input in text_list]


+def _make_dataloader(tokenizer, text, labels, batch_size):
+    """
+    Create torch DataLoader from list of input text and labels.
+
+    :param tokenizer: Tokenizer to use for this text.
+    :param text: list of input text.
+    :param labels: list of corresponding labels.
+    :param batch_size: batch size (int).
+    :return: torch DataLoader for this training set.
+    """
+    text_ids = batch_encode(tokenizer, text)
+    input_ids = np.array(text_ids)
+    labels = np.array(labels)
+    data = list((ids, label) for ids, label in zip(input_ids, labels))
+    sampler = RandomSampler(data)
+    dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
+    return dataloader
+
+
+def _data_augmentation(text, labels, augmenter):
+    """
+    Use an augmentation method to expand a training set.
+
+    :param text: list of input text.
+    :param labels: list of corresponding labels.
+    :param augmenter: textattack.augmentation.Augmenter, augmentation scheme.
+
+    :return: augmented_text, augmented_labels. list of (augmented) input text and labels.
+    """
+    aug_text = augmenter.augment_many(text)
+    # flatten augmented examples and duplicate labels
+    flat_aug_text = []
+    flat_aug_labels = []
+    for i, examples in enumerate(aug_text):
+        for aug_ver in examples:
+            flat_aug_text.append(aug_ver)
+            flat_aug_labels.append(labels[i])
+    return flat_aug_text, flat_aug_labels
+
+
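A toy run of the flattening logic, with augment_many's output stubbed by hand: each input yields a list of rewrites, and every rewrite inherits its source label.

    labels = [1, 0]
    aug_text = [["great film", "fine film"], ["awful film"]]  # stand-in for augmenter.augment_many(...)

    flat_text, flat_labels = [], []
    for i, examples in enumerate(aug_text):
        for aug_ver in examples:
            flat_text.append(aug_ver)
            flat_labels.append(labels[i])

    assert flat_text == ["great film", "fine film", "awful film"]
    assert flat_labels == [1, 1, 0]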
+def _generate_adversarial_examples(model, attackCls, dataset):
+    """
+    Create a dataset of adversarial examples based on perturbations of the existing dataset.
+
+    :param model: Model to attack.
+    :param attackCls: class name of attack recipe to run.
+    :param dataset: iterable of (text, label) pairs.
+
+    :return: list of adversarial examples.
+    """
+    attack = attackCls(model)
+    adv_train_text = []
+    for adv_ex in tqdm.tqdm(
+        attack.attack_dataset(dataset), desc="Attack", total=len(dataset)
+    ):
+        adv_train_text.append(adv_ex.perturbed_text())
+    return adv_train_text
+
+
 def train_model(args):
     logger.warn(
         "WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues."
     )
     start_time = time.time()
-    make_directories(args.output_dir)
+    _make_directories(args.output_dir)

     num_gpus = torch.cuda.device_count()
@@ -63,52 +291,27 @@ def train_model(args):

     # Filter labels
     if args.allowed_labels:
         logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.")
-        final_train_text, final_train_labels = [], []
-        for text, label in zip(train_text, train_labels):
-            if label in args.allowed_labels:
-                final_train_text.append(text)
-                final_train_labels.append(label)
-        logger.info(
-            f"Filtered {len(train_text)} train samples to {len(final_train_text)} points."
-        )
-        train_text, train_labels = final_train_text, final_train_labels
-        final_eval_text, final_eval_labels = [], []
-        for text, label in zip(eval_text, eval_labels):
-            if label in args.allowed_labels:
-                final_eval_text.append(text)
-                final_eval_labels.append(label)
-        logger.info(
-            f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points."
-        )
-        eval_text, eval_labels = final_eval_text, final_eval_labels
+        train_text, train_labels = _filter_labels(
+            train_text, train_labels, args.allowed_labels
+        )
+        eval_text, eval_labels = _filter_labels(
+            eval_text, eval_labels, args.allowed_labels
+        )
+
+    if args.pct_dataset < 1.0:
+        logger.info(f"Using {args.pct_dataset*100}% of the training set")
+        (train_text, train_labels), _ = _train_val_split(
+            train_text, train_labels, split_val=1.0 - args.pct_dataset
+        )
+    train_examples_len = len(train_text)

     # data augmentation
     augmenter = augmenter_from_args(args)
     if augmenter:
-        # augment the training set
-        aug_train_text = augmenter.augment_many(train_text)
-        # flatten augmented examples and duplicate labels
-        flat_aug_train_text = []
-        flat_aug_train_labels = []
-        for i, examples in enumerate(aug_train_text):
-            for aug_ver in examples:
-                flat_aug_train_text.append(aug_ver)
-                flat_aug_train_labels.append(train_labels[i])
-        train_text = flat_aug_train_text
-        train_labels = flat_aug_train_labels
-
-        # augment the eval set
-        aug_eval_text = augmenter.augment_many(eval_text)
-        # flatten the augmented examples and duplicate labels
-        flat_aug_eval_text = []
-        flat_aug_eval_labels = []
-        for i, examples in enumerate(aug_eval_text):
-            for aug_ver in examples:
-                flat_aug_eval_text.append(aug_ver)
-                flat_aug_eval_labels.append(eval_labels[i])
-        eval_text = flat_aug_eval_text
-        eval_labels = flat_aug_eval_labels
+        logger.info(f"Augmenting {len(train_text)} samples with {augmenter}")
+        train_text, train_labels = _data_augmentation(
+            train_text, train_labels, augmenter
+        )

     # label_id_len = len(train_labels)
     label_set = set(train_labels)
@@ -125,11 +328,9 @@ def train_model(args):
     else:
         args.do_regression = False

-    train_examples_len = len(train_text)
-
-    if len(train_labels) != train_examples_len:
+    if len(train_labels) != len(train_text):
         raise ValueError(
-            f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})"
+            f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})"
         )
     if len(eval_labels) != len(eval_text):
         raise ValueError(
@@ -139,16 +340,13 @@ def train_model(args):
     model = model_from_args(args, args.num_labels)
     tokenizer = model.tokenizer

-    logger.info(f"Tokenizing training data. (len: {train_examples_len})")
-    train_text_ids = batch_encode(tokenizer, train_text)
-    logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
-    eval_text_ids = batch_encode(tokenizer, eval_text)
-    load_time = time.time()
-    logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
+    attackCls = attack_from_args(args)
+    adversarial_training = attackCls is not None

     # multi-gpu training
     if num_gpus > 1:
         model = torch.nn.DataParallel(model)
+        model.tokenizer = model.module.tokenizer
         logger.info("Using torch.nn.DataParallel.")
     logger.info(f"Training model across {num_gpus} GPUs")
@@ -199,110 +397,38 @@ def train_model(args):

     tb_writer = SummaryWriter(args.output_dir)

-    def is_writable_type(obj):
-        for ok_type in [bool, int, str, float]:
-            if isinstance(obj, ok_type):
-                return True
-        return False
-
-    args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
-
     # Save original args to file
     args_save_path = os.path.join(args.output_dir, "train_args.json")
-    with open(args_save_path, "w", encoding="utf-8") as f:
-        f.write(json.dumps(args_dict, indent=2) + "\n")
+    _save_args(args, args_save_path)
     logger.info(f"Wrote original training args to {args_save_path}.")

-    tb_writer.add_hparams(args_dict, {})
+    tb_writer.add_hparams(
+        {k: v for k, v in vars(args).items() if _is_writable_type(v)}, {}
+    )

     # Start training
     logger.info("***** Running training *****")
-    logger.info(f"\tNum examples = {train_examples_len}")
+    if augmenter:
+        logger.info(f"\tNum original examples = {train_examples_len}")
+        logger.info(f"\tNum examples after augmentation = {len(train_text)}")
+    else:
+        logger.info(f"\tNum examples = {train_examples_len}")
     logger.info(f"\tBatch size = {args.batch_size}")
     logger.info(f"\tMax sequence length = {args.max_length}")
     logger.info(f"\tNum steps = {num_train_optimization_steps}")
     logger.info(f"\tNum epochs = {args.num_train_epochs}")
     logger.info(f"\tLearning rate = {args.learning_rate}")

-    train_input_ids = np.array(train_text_ids)
-    train_labels = np.array(train_labels)
-    train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels))
-    train_sampler = RandomSampler(train_data)
-    train_dataloader = DataLoader(
-        train_data, sampler=train_sampler, batch_size=args.batch_size
+    eval_dataloader = _make_dataloader(
+        tokenizer, eval_text, eval_labels, args.batch_size
     )
-
-    eval_input_ids = np.array(eval_text_ids)
-    eval_labels = np.array(eval_labels)
-    eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels))
-    eval_sampler = SequentialSampler(eval_data)
-    eval_dataloader = DataLoader(
-        eval_data, sampler=eval_sampler, batch_size=args.batch_size
+    train_dataloader = _make_dataloader(
+        tokenizer, train_text, train_labels, args.batch_size
     )

-    def get_eval_score():
-        model.eval()
-        correct = 0
-        # total = 0
-        logits = []
-        labels = []
-        for input_ids, batch_labels in eval_dataloader:
-            if isinstance(input_ids, dict):
-                ## HACK: dataloader collates dict backwards. This is a temporary
-                # workaround to get ids in the right shape
-                input_ids = {
-                    k: torch.stack(v).T.to(device) for k, v in input_ids.items()
-                }
-            batch_labels = batch_labels.to(device)
-
-            with torch.no_grad():
-                batch_logits = textattack.shared.utils.model_predict(model, input_ids)
-
-            logits.extend(batch_logits.cpu().squeeze().tolist())
-            labels.extend(batch_labels)
-
-        model.train()
-        logits = torch.tensor(logits)
-        labels = torch.tensor(labels)
-
-        if args.do_regression:
-            pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
-            return pearson_correlation
-        else:
-            preds = logits.argmax(dim=1)
-            correct = (preds == labels).sum()
-            return float(correct) / len(labels)
-
-    def save_model():
-        model_to_save = (
-            model.module if hasattr(model, "module") else model
-        )  # Only save the model itself
-
-        # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(args.output_dir, args.weights_name)
-        output_config_file = os.path.join(args.output_dir, args.config_name)
-
-        torch.save(model_to_save.state_dict(), output_model_file)
-        try:
-            model_to_save.config.to_json_file(output_config_file)
-        except AttributeError:
-            # no config
-            pass
-
     global_step = 0
     tr_loss = 0

-    def save_model_checkpoint():
-        # Save model checkpoint
-        output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        # Take care of distributed/parallel training
-        model_to_save = model.module if hasattr(model, "module") else model
-        model_to_save.save_pretrained(output_dir)
-        torch.save(args, os.path.join(output_dir, "training_args.bin"))
-        logger.info(f"Checkpoint saved to {output_dir}.")
-
     model.train()
     args.best_eval_score = 0
     args.best_eval_score_epoch = 0
@@ -325,6 +451,21 @@ def train_model(args):
     for epoch in tqdm.trange(
         int(args.num_train_epochs), desc="Epoch", position=0, leave=False
     ):
+        if adversarial_training:
+            if epoch >= args.num_clean_epochs:
+                if (epoch - args.num_clean_epochs) % args.attack_period == 0:
+                    # only generate a new adversarial training set every args.attack_period epochs
+                    # after the clean epochs
+                    logger.info("Attacking model to generate new training set...")
+                    adv_train_text = _generate_adversarial_examples(
+                        model, attackCls, list(zip(train_text, train_labels))
+                    )
+                    train_dataloader = _make_dataloader(
+                        tokenizer, adv_train_text, train_labels, args.batch_size
+                    )
+            else:
+                logger.info(f"Running clean epoch {epoch+1}/{args.num_clean_epochs}")
+
         prog_bar = tqdm.tqdm(
             train_dataloader, desc="Iteration", position=1, leave=False
         )
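The scheduling rule in isolation: after num_clean_epochs clean epochs, the adversarial training set is rebuilt every attack_period epochs and reused in between.

    num_clean_epochs, attack_period = 1, 2
    for epoch in range(6):
        if epoch >= num_clean_epochs and (epoch - num_clean_epochs) % attack_period == 0:
            print(f"epoch {epoch}: attack model, rebuild adversarial training set")
        else:
            print(f"epoch {epoch}: reuse current training data")
    # epochs 1, 3, 5 regenerate; epochs 0, 2, 4 do not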
@@ -337,6 +478,7 @@ def train_model(args):
             input_ids = {
                 k: torch.stack(v).T.to(device) for k, v in input_ids.items()
             }
+
         logits = textattack.shared.utils.model_predict(model, input_ids)

         if args.do_regression:
@@ -366,42 +508,46 @@ def train_model(args):
                     and (args.checkpoint_steps > 0)
                     and (global_step % args.checkpoint_steps) == 0
                 ):
-                    save_model_checkpoint()
+                    _save_model_checkpoint(model, args.output_dir, global_step)

                 # Inc step counter.
                 global_step += 1

         # Check accuracy after each epoch.
-        eval_score = get_eval_score()
-        tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
-
-        if args.checkpoint_every_epoch:
-            save_model_checkpoint()
-
-        logger.info(
-            f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
-        )
-        if eval_score > args.best_eval_score:
-            args.best_eval_score = eval_score
-            args.best_eval_score_epoch = epoch
-            args.epochs_since_best_eval_score = 0
-            save_model()
-            logger.info(f"Best acc found. Saved model to {args.output_dir}.")
-        else:
-            args.epochs_since_best_eval_score += 1
-            if (args.early_stopping_epochs > 0) and (
-                args.epochs_since_best_eval_score > args.early_stopping_epochs
-            ):
-                logger.info(
-                    f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
-                )
-                break
+        # skip args.num_clean_epochs during adversarial training
+        if not adversarial_training or epoch >= args.num_clean_epochs:
+            eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
+            tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
+
+            if args.checkpoint_every_epoch:
+                _save_model_checkpoint(model, args.output_dir, args.global_step)
+
+            logger.info(
+                f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
+            )
+            if eval_score > args.best_eval_score:
+                args.best_eval_score = eval_score
+                args.best_eval_score_epoch = epoch
+                args.epochs_since_best_eval_score = 0
+                _save_model(model, args.output_dir, args.weights_name, args.config_name)
+                logger.info(f"Best acc found. Saved model to {args.output_dir}.")
+                _save_args(args, args_save_path)
+                logger.info(f"Saved updated args to {args_save_path}")
+            else:
+                args.epochs_since_best_eval_score += 1
+                if (args.early_stopping_epochs > 0) and (
+                    args.epochs_since_best_eval_score > args.early_stopping_epochs
+                ):
+                    logger.info(
+                        f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
+                    )
+                    break

     # read the saved model and report its eval performance
     logger.info("Finished training. Re-loading and evaluating model from disk.")
     model = model_from_args(args, args.num_labels)
     model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
-    eval_score = get_eval_score()
+    eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
     logger.info(
         f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
     )
@@ -418,8 +564,5 @@ def train_model(args):
     # Save a little readme with model info
     write_readme(args, args.best_eval_score, args.best_eval_score_epoch)

-    # Save args to file
-    final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
-    with open(args_save_path, "w", encoding="utf-8") as f:
-        f.write(json.dumps(final_args_dict, indent=2) + "\n")
+    _save_args(args, args_save_path)
     logger.info(f"Wrote final training args to {args_save_path}.")
@@ -1,6 +1,7 @@
 import os

 import textattack
+from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES
 from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES

 logger = textattack.shared.logger
@@ -133,6 +134,21 @@ def model_from_args(train_args, num_labels, model_path=None):
     return model


+def attack_from_args(args):
+    # note that this returns a recipe type, not an object
+    # (we need to wait to have access to the model to initialize)
+    attackCls = None
+    if args.attack:
+        if args.attack in ATTACK_RECIPE_NAMES:
+            attackCls = eval(ATTACK_RECIPE_NAMES[args.attack])
+        else:
+            raise ValueError(f"Unrecognized attack recipe: {args.attack}")
+
+    # check attack-related args
+    assert args.num_clean_epochs > 0, "--num-clean-epochs must be > 0"
+    return attackCls
+
+
 def augmenter_from_args(args):
     augmenter = None
     if args.augment:
@@ -13,7 +13,7 @@ class TrainModelCommand(TextAttackCommand):

     def run(self, args):

-        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
+        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
         current_dir = os.path.dirname(os.path.realpath(__file__))
         outputs_dir = os.path.join(
             current_dir, os.pardir, os.pardir, os.pardir, "outputs", "training"
@@ -47,6 +47,12 @@ class TrainModelCommand(TextAttackCommand):
             "`nlp` library. if dataset has a subset, separate with a colon. "
             " ex: `glue:sst2` or `rotten_tomatoes`",
         )
+        parser.add_argument(
+            "--pct-dataset",
+            type=float,
+            default=1.0,
+            help="Fraction of dataset to use during training ([0., 1.])",
+        )
         parser.add_argument(
             "--dataset-train-split",
             "--train-split",
@@ -77,7 +83,7 @@ class TrainModelCommand(TextAttackCommand):
             help="save model after this many steps (-1 for no checkpointing)",
         )
         parser.add_argument(
-            "--checkpoint-every_epoch",
+            "--checkpoint-every-epoch",
            action="store_true",
            default=False,
            help="save model checkpoint after each epoch",
@@ -89,6 +95,24 @@ class TrainModelCommand(TextAttackCommand):
             default=100,
             help="Total number of epochs to train for",
         )
+        parser.add_argument(
+            "--attack",
+            type=str,
+            default=None,
+            help="Attack recipe to use (enables adversarial training)",
+        )
+        parser.add_argument(
+            "--num-clean-epochs",
+            type=int,
+            default=1,
+            help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
+        )
+        parser.add_argument(
+            "--attack-period",
+            type=int,
+            default=1,
+            help="How often (in epochs) to generate a new adversarial training set (N/A if --attack unspecified)",
+        )
         parser.add_argument(
             "--augment", type=str, default=None, help="Augmentation recipe to use",
         )
@@ -28,7 +28,15 @@ class QueryHandler:
         except Exception:
             probs = []
             for s, w in zip(sentences, swapped_words):
-                probs.append(self.try_query([s], [w], batch_size=1)[0])
+                try:
+                    probs.append(self.try_query([s], [w], batch_size=1)[0])
+                except RuntimeError:
+                    print(
+                        "WARNING: got runtime error querying language model with sentence/word",
+                        s,
+                        w,
+                    )
+                    probs.append(float("-inf"))
             return probs

     def try_query(self, sentences, swapped_words, batch_size=32):
@@ -61,6 +69,8 @@ class QueryHandler:
         hidden = self.model.init_hidden(len(batch))
         source = word_idxs[:-1, :]
         target = word_idxs[1:, :]
+        if (not len(source)) or not len(hidden):
+            return [float("-inf")] * len(batch)
         decode, hidden = self.model(source, hidden)
         decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
         for i in range(len(batch)):
@@ -32,12 +32,13 @@ class GoalFunction(ABC):
         use_cache=True,
         query_budget=float("inf"),
         model_batch_size=32,
-        model_cache_size=2 ** 18,
+        model_cache_size=2 ** 20,
     ):
         validators.validate_model_goal_function_compatibility(
             self.__class__, model.__class__
         )
         self.model = model
+        self.model.eval()
         self.maximizable = maximizable
         self.tokenizer = tokenizer
         if not self.tokenizer:
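Scale of the cache bump: 2**20 quadruples capacity from 262,144 to 1,048,576 cached model calls. A sketch with the lru-dict package from requirements.txt (an assumption that it backs this cache, not something this hunk shows):

    from lru import LRU

    cache = LRU(2 ** 20)            # least-recently-used entries are evicted when full
    cache[("some text", 0)] = 0.87  # e.g. a goal-function score keyed by input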
@@ -2,14 +2,10 @@ import csv

 import pandas as pd

-# from textattack.attack_results import FailedAttackResult
-from textattack.shared import logger
+from textattack.shared import AttackedText, logger

 from .logger import Logger

-# import os
-# import sys
-

 class CSVLogger(Logger):
     """Logs attack results to a CSV."""
@@ -22,6 +18,8 @@ class CSVLogger(Logger):

     def log_attack_result(self, result):
         original_text, perturbed_text = result.diff_color(self.color_method)
+        original_text = original_text.replace("\n", AttackedText.SPLIT_TOKEN)
+        perturbed_text = perturbed_text.replace("\n", AttackedText.SPLIT_TOKEN)
         result_type = result.__class__.__name__.replace("AttackResult", "")
         row = {
             "original_text": original_text,
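Effect of the escaping: multi-line inputs (e.g. premise/hypothesis pairs) stay on a single CSV row. The token value below is an assumption about AttackedText.SPLIT_TOKEN, not taken from this diff.

    SPLIT_TOKEN = "<SPLIT>"  # assumed value of AttackedText.SPLIT_TOKEN
    original_text = "A man inspects a uniform.\nThe man is sleeping."
    print(original_text.replace("\n", SPLIT_TOKEN))
    # A man inspects a uniform.<SPLIT>The man is sleeping.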
@@ -31,6 +31,7 @@ class LSTMForClassification(nn.Module):
             # so if that's all we have, this will display a warning.
             dropout = 0
         self.drop = nn.Dropout(dropout)
+        self.emb_layer_trainable = emb_layer_trainable
         self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
         self.word2id = self.emb_layer.word2id
         self.encoder = nn.LSTM(
@@ -58,7 +58,9 @@ class GreedyWordSwapWIR(SearchMethod):
             initial_result, leave_one_texts
         )

-        softmax_saliency_scores = softmax(torch.Tensor(saliency_scores)).numpy()
+        softmax_saliency_scores = softmax(
+            torch.Tensor(saliency_scores), dim=0
+        ).numpy()

         # compute the largest change in score we can find by swapping each word
         delta_ps = []
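Why the explicit dim=0: saliency_scores collapses to a 1-D tensor (one score per word), and calling softmax on it without a dim argument is ambiguous and deprecated in recent PyTorch. A quick check:

    import torch

    scores = torch.Tensor([0.1, 0.5, 0.2])
    probs = torch.softmax(scores, dim=0)  # normalize over the word axis
    assert abs(probs.sum().item() - 1.0) < 1e-6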
@@ -72,9 +74,7 @@ class GreedyWordSwapWIR(SearchMethod):
                 # no valid synonym substitutions for this word
                 delta_ps.append(0.0)
                 continue
-            swap_results, _ = self.get_goal_results(
-                transformed_text_candidates, initial_result.output
-            )
+            swap_results, _ = self.get_goal_results(transformed_text_candidates)
             score_change = [result.score for result in swap_results]
             max_score_change = np.max(score_change)
             delta_ps.append(max_score_change)
@@ -38,7 +38,7 @@ class Attack:
         constraints=[],
         transformation=None,
         search_method=None,
-        constraint_cache_size=2 ** 18,
+        constraint_cache_size=2 ** 20,
     ):
         """Initialize an attack object.
@@ -150,7 +150,8 @@ class Attack:
         self, transformed_texts, current_text, original_text=None
     ):
         """Filters a list of potential transformed texts based on
-        ``self.constraints`` Checks cache first.
+        ``self.constraints`` Utilizes an LRU cache to attempt to avoid
+        recomputing common transformations.

         Args:
             transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
@@ -14,8 +14,7 @@ from . import logger
 # A list of goal functions and the corresponding available models.
 MODELS_BY_GOAL_FUNCTIONS = {
     (TargetedClassification, UntargetedClassification, InputReduction): [
-        r"^textattack.models.classification.*",
-        r"^textattack.models.entailment.*",
+        r"^textattack.models.lstm_for_classification.*",
         r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
     ],
     (NonOverlappingOutput,): [
@@ -52,6 +52,7 @@ class WordSwapGradientBased(Transformation):
             word_index (int): index of the word to replace
         """
+        self.model.train()
         self.model.emb_layer.embedding.weight.requires_grad = True

         lookup_table = self.model.lookup_table.to(utils.device)
         lookup_table_transpose = lookup_table.transpose(0, 1)
@@ -105,6 +106,9 @@ class WordSwapGradientBased(Transformation):
                 break

         self.model.eval()
+        self.model.emb_layer.embedding.weight.requires_grad = (
+            self.model.emb_layer_trainable
+        )
         return candidates

     def _call_model(self, text_ids):
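The pattern added here, in miniature: temporarily enable gradients on a frozen embedding for the gradient-based swap, then restore the layer's original trainability flag (generic PyTorch, not TextAttack's classes):

    import torch

    emb = torch.nn.Embedding(10, 4)
    emb.weight.requires_grad = False     # frozen, GloVe-style embedding
    emb_layer_trainable = False          # remembered setting, as in the diff

    emb.weight.requires_grad = True      # enable for the attack's backward pass
    loss = emb(torch.tensor([1, 2])).sum()
    loss.backward()                      # gradients now reach emb.weight
    emb.weight.requires_grad = emb_layer_trainable  # restore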
@@ -53,10 +53,11 @@ class WordSwapHowNet(WordSwap):

     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
-        words_str = " ".join(words)
-        word_list, pos_list = zip_flair_result(
-            self._flair_pos_tagger.predict(words_str)[0]
-        )
+        sentence = Sentence(" ".join(words))
+        # in-place POS tagging
+        self._flair_pos_tagger.predict(sentence)
+        word_list, pos_list = zip_flair_result(sentence)

         assert len(words) == len(
             word_list
         ), "Part-of-speech tagger returned incorrect number of tags"
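Context for this change: flair >= 0.5 (matching the requirements.txt bump above) tags a Sentence in place and returns nothing, so indexing the return value of predict() no longer works. A sketch of the new usage, assuming the standard flair API:

    from flair.data import Sentence
    from flair.models import SequenceTagger

    tagger = SequenceTagger.load("pos")       # downloads the POS model on first use
    sentence = Sentence("the movie was good")
    tagger.predict(sentence)                  # in-place; no useful return value
    for token in sentence:
        print(token.text, token.get_tag("pos").value)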
@@ -93,7 +93,7 @@ class WordSwapMaskedLM(WordSwap):
         replacement_words = []
         for id in top_ids:
             token = self._lm_tokenizer.convert_ids_to_tokens(id)
-            if utils.is_one_word(token):
+            if utils.is_one_word(token) and not check_if_subword(token):
                 replacement_words.append(token)

         return replacement_words
@@ -141,7 +141,7 @@ class WordSwapMaskedLM(WordSwap):
             replacement_words = []
             for id in top_preds:
                 token = self._lm_tokenizer.convert_ids_to_tokens(id)
-                if utils.is_one_word(token):
+                if utils.is_one_word(token) and not check_if_subword(token):
                     replacement_words.append(token)
             return replacement_words
         else:
@@ -231,3 +231,7 @@ def recover_word_case(word, reference_word):
     else:
         # if other, just do not alter the word's case
         return word
+
+
+def check_if_subword(text):
+    return True if "##" in text else False
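check_if_subword encodes the WordPiece convention: continuation pieces carry a "##" prefix, and the masked-LM swaps above now reject them as standalone replacement words.

    assert check_if_subword("##ing") is True
    assert check_if_subword("running") is False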