diff --git a/Makefile b/Makefile index 74bf201a..57330ddc 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,13 @@ +PEP_IGNORE_ERRORS="C901 E501 W503 E203 E231 E266 F403" + format: FORCE ## Run black and isort (rewriting files) black . isort --atomic tests textattack - lint: FORCE ## Run black, isort, flake8 (in check mode) black . --check isort --check-only tests textattack - flake8 . --count --ignore=C901,E501,W503,E203,E231,E266,F403 --show-source --statistics --exclude=./.*,build,dist + flake8 . --count --ignore=$(PEP_IGNORE_ERRORS) --show-source --statistics --exclude=./.*,build,dist test: FORCE ## Run tests using pytest diff --git a/requirements.txt b/requirements.txt index d368e754..1ecea865 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ bert-score editdistance -flair>=0.5 +flair>=0.5.1 filelock language_tool_python lru-dict diff --git a/textattack/augmentation/augmenter.py b/textattack/augmentation/augmenter.py index 9df573b5..6bf79767 100644 --- a/textattack/augmentation/augmenter.py +++ b/textattack/augmentation/augmenter.py @@ -3,7 +3,7 @@ import random import tqdm from textattack.constraints import PreTransformationConstraint -from textattack.shared import AttackedText +from textattack.shared import AttackedText, utils class Augmenter: @@ -70,7 +70,9 @@ class Augmenter: attacked_text = AttackedText(text) original_text = attacked_text all_transformed_texts = set() - num_words_to_swap = int(self.pct_words_to_swap * len(attacked_text.words)) + num_words_to_swap = max( + int(self.pct_words_to_swap * len(attacked_text.words)), 1 + ) for _ in range(self.transformations_per_example): index_order = list(range(len(attacked_text.words))) random.shuffle(index_order) @@ -132,3 +134,22 @@ class Augmenter: all_text_list.extend([text] + augmented_texts) all_id_list.extend([_id] * (1 + len(augmented_texts))) return all_text_list, all_id_list + + def __repr__(self): + main_str = "Augmenter" + "(" + lines = [] + # self.transformation + lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2)) + # self.constraints + constraints_lines = [] + constraints = self.constraints + self.pre_transformation_constraints + if len(constraints): + for i, constraint in enumerate(constraints): + constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2)) + constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2) + else: + constraints_str = "None" + lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2)) + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index ea6e0161..64b79249 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -61,6 +61,9 @@ class EasyDataAugmenter(Augmenter): random.shuffle(augmented_text) return augmented_text[: self.transformations_per_example] + def __repr__(self): + return "EasyDataAugmenter" + class SwapAugmenter(Augmenter): def __init__(self, **kwargs): diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py index 388917c4..91b3ac4f 100644 --- a/textattack/commands/attack/attack_args.py +++ b/textattack/commands/attack/attack_args.py @@ -308,6 +308,7 @@ BLACK_BOX_TRANSFORMATION_CLASS_NAMES = { "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution", "word-swap-wordnet": "textattack.transformations.WordSwapWordNet", "word-swap-masked-lm": 
"textattack.transformations.WordSwapMaskedLM", + "word-swap-hownet": "textattack.transformations.WordSwapHowNet", } WHITE_BOX_TRANSFORMATION_CLASS_NAMES = { @@ -353,6 +354,7 @@ SEARCH_METHOD_CLASS_NAMES = { "greedy": "textattack.search_methods.GreedySearch", "ga-word": "textattack.search_methods.GeneticAlgorithm", "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR", + "pso": "textattack.search_methods.ParticleSwarmOptimization", } GOAL_FUNCTION_CLASS_NAMES = { diff --git a/textattack/commands/attack/attack_args_helpers.py b/textattack/commands/attack/attack_args_helpers.py index 28d49c83..4f39fd56 100644 --- a/textattack/commands/attack/attack_args_helpers.py +++ b/textattack/commands/attack/attack_args_helpers.py @@ -360,9 +360,13 @@ def parse_dataset_from_args(args): ) model_train_args = json.loads(open(model_args_json_path).read()) try: + if ":" in model_train_args["dataset"]: + name, subset = model_train_args["dataset"].split(":") + else: + name, subset = model_train_args["dataset"], None args.dataset_from_nlp = ( - model_train_args["dataset"], - None, + name, + subset, model_train_args["dataset_dev_split"], ) except KeyError: diff --git a/textattack/commands/attack/run_attack_parallel.py b/textattack/commands/attack/run_attack_parallel.py index c96f0d7e..20d8e291 100644 --- a/textattack/commands/attack/run_attack_parallel.py +++ b/textattack/commands/attack/run_attack_parallel.py @@ -94,9 +94,18 @@ def run(args, checkpoint=None): in_queue = torch.multiprocessing.Queue() out_queue = torch.multiprocessing.Queue() # Add stuff to queue. + missing_datapoints = set() for i in worklist: - text, output = dataset[i] - in_queue.put((i, text, output)) + try: + text, output = dataset[i] + in_queue.put((i, text, output)) + except IndexError: + missing_datapoints.add(i) + + # if our dataset is shorter than the number of samples chosen, remove the + # out-of-bounds indices from the dataset + for i in missing_datapoints: + worklist.remove(i) # Start workers. # pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue)) @@ -146,7 +155,7 @@ def run(args, checkpoint=None): in_queue.put((worklist_tail, text, output)) except IndexError: raise IndexError( - "Out of bounds access of dataset. Size of data is {} but tried to access index {}".format( + "Tried adding to worklist, but ran out of datapoints. Size of data is {} but tried to access index {}".format( len(dataset), worklist_tail ) ) diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index b0a4fc00..df59843e 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -1,18 +1,19 @@ import json import logging +import math import os -import time import numpy as np import scipy import torch -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, RandomSampler import tqdm import transformers import textattack from .train_args_helpers import ( + attack_from_args, augmenter_from_args, dataset_from_args, model_from_args, @@ -23,9 +24,178 @@ device = textattack.shared.utils.device logger = textattack.shared.logger -def make_directories(output_dir): +def _save_args(args, save_path): + """ + Dump args dictionary to a json + + :param: args. Dictionary of arguments to save. + :save_path: Path to json file to write args to. 
+ """ + final_args_dict = {k: v for k, v in vars(args).items() if _is_writable_type(v)} + with open(save_path, "w", encoding="utf-8") as f: + f.write(json.dumps(final_args_dict, indent=2) + "\n") + + +def _get_sample_count(*lsts): + """ + Get sample count of a dataset. + + :param *lsts: variable number of lists. + :return: sample count of this dataset, if all lists match, else None. + """ + if all(len(lst) == len(lsts[0]) for lst in lsts): + sample_count = len(lsts[0]) + else: + sample_count = None + return sample_count + + +def _random_shuffle(*lsts): + """ + Randomly shuffle a dataset. Applies the same permutation + to each list (to preserve mapping between inputs and targets). + + :param *lsts: variable number of lists to shuffle. + :return: shuffled lsts. + """ + permutation = np.random.permutation(len(lsts[0])) + shuffled = [] + for lst in lsts: + shuffled.append((np.array(lst)[permutation]).tolist()) + return tuple(shuffled) + + +def _train_val_split(*lsts, split_val=0.2): + """ + Split dataset into training and validation sets. + + :param *lsts: variable number of lists that make up a dataset (e.g. text, labels) + :param split_val: float [0., 1.). Fraction of the dataset to reserve for evaluation. + :return: (train split of list for list in lsts), (val split of list for list in lsts) + """ + sample_count = _get_sample_count(*lsts) + if not sample_count: + raise Exception( + "Batch Axis inconsistent. All input arrays must have first axis of equal length." + ) + lsts = _random_shuffle(*lsts) + split_idx = math.floor(sample_count * split_val) + train_set = [lst[split_idx:] for lst in lsts] + val_set = [lst[:split_idx] for lst in lsts] + if len(train_set) == 1 and len(val_set) == 1: + train_set = train_set[0] + val_set = val_set[0] + return train_set, val_set + + +def _filter_labels(text, labels, allowed_labels): + """ + Keep examples with approved labels + + :param text: list of text inputs. + :param labels: list of corresponding labels. + :param allowed_labels: list of approved label values. + + :return: (final_text, final_labels). Filtered version of text and labels + """ + final_text, final_labels = [], [] + for text, label in zip(text, labels): + if label in allowed_labels: + final_text.append(text) + final_labels.append(label) + return final_text, final_labels + + +def _save_model_checkpoint(model, output_dir, global_step): + """ + Save model checkpoint to disk. + + :param model: Model to save (pytorch) + :param output_dir: Path to model save dir. + :param global_step: Current global training step #. Used in ckpt filename. + """ + # Save model checkpoint + output_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(output_dir) + + +def _save_model(model, output_dir, weights_name, config_name): + """ + Save model to disk. + + :param model: Model to save (pytorch) + :param output_dir: Path to model save dir. + :param weights_name: filename for model parameters. + :param config_name: filename for config. 
+ """ + model_to_save = model.module if hasattr(model, "module") else model + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(output_dir, weights_name) + output_config_file = os.path.join(output_dir, config_name) + + torch.save(model_to_save.state_dict(), output_model_file) + try: + model_to_save.config.to_json_file(output_config_file) + except AttributeError: + # no config + pass + + +def _get_eval_score(model, eval_dataloader, do_regression): + """ + Measure performance of a model on the evaluation set. + + :param model: Model to test. + :param eval_dataloader: a torch DataLoader that iterates through the eval set. + :param do_regression: bool. Whether we are doing regression (True) or classification (False) + + :return: pearson correlation, if do_regression==True, else classification accuracy [0., 1.] + """ + model.eval() + correct = 0 + logits = [] + labels = [] + for input_ids, batch_labels in eval_dataloader: + if isinstance(input_ids, dict): + ## HACK: dataloader collates dict backwards. This is a temporary + # workaround to get ids in the right shape + input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()} + batch_labels = batch_labels.to(device) + + with torch.no_grad(): + batch_logits = textattack.shared.utils.model_predict(model, input_ids) + + logits.extend(batch_logits.cpu().squeeze().tolist()) + labels.extend(batch_labels) + + model.train() + logits = torch.tensor(logits) + labels = torch.tensor(labels) + + if do_regression: + pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels) + return pearson_correlation + else: + preds = logits.argmax(dim=1) + correct = (preds == labels).sum() + return float(correct) / len(labels) + + +def _make_directories(output_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + +def _is_writable_type(obj): + for ok_type in [bool, int, str, float]: + if isinstance(obj, ok_type): + return True + return False def batch_encode(tokenizer, text_list): @@ -35,12 +205,70 @@ def batch_encode(tokenizer, text_list): return [tokenizer.encode(text_input) for text_input in text_list] +def _make_dataloader(tokenizer, text, labels, batch_size): + """ + Create torch DataLoader from list of input text and labels. + + :param tokenizer: Tokenizer to use for this text. + :param text: list of input text. + :param labels: list of corresponding labels. + :param batch_size: batch size (int). + :return: torch DataLoader for this training set. + """ + text_ids = batch_encode(tokenizer, text) + input_ids = np.array(text_ids) + labels = np.array(labels) + data = list((ids, label) for ids, label in zip(input_ids, labels)) + sampler = RandomSampler(data) + dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size) + return dataloader + + +def _data_augmentation(text, labels, augmenter): + """ + Use an augmentation method to expand a training set. + + :param text: list of input text. + :param labels: list of corresponding labels. + :param augmenter: textattack.augmentation.Augmenter, augmentation scheme. + + :return: augmented_text, augmented_labels. list of (augmented) input text and labels. 
+ """ + aug_text = augmenter.augment_many(text) + # flatten augmented examples and duplicate labels + flat_aug_text = [] + flat_aug_labels = [] + for i, examples in enumerate(aug_text): + for aug_ver in examples: + flat_aug_text.append(aug_ver) + flat_aug_labels.append(labels[i]) + return flat_aug_text, flat_aug_labels + + +def _generate_adversarial_examples(model, attackCls, dataset): + """ + Create a dataset of adversarial examples based on perturbations of the existing dataset. + + :param model: Model to attack. + :param attackCls: class name of attack recipe to run. + :param dataset: iterable of (text, label) pairs. + + :return: list of adversarial examples. + """ + attack = attackCls(model) + adv_train_text = [] + for adv_ex in tqdm.tqdm( + attack.attack_dataset(dataset), desc="Attack", total=len(dataset) + ): + adv_train_text.append(adv_ex.perturbed_text()) + return adv_train_text + + def train_model(args): logger.warn( "WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues." ) - start_time = time.time() - make_directories(args.output_dir) + _make_directories(args.output_dir) num_gpus = torch.cuda.device_count() @@ -63,52 +291,27 @@ def train_model(args): # Filter labels if args.allowed_labels: - logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.") - final_train_text, final_train_labels = [], [] - for text, label in zip(train_text, train_labels): - if label in args.allowed_labels: - final_train_text.append(text) - final_train_labels.append(label) - logger.info( - f"Filtered {len(train_text)} train samples to {len(final_train_text)} points." + train_text, train_labels = _filter_labels( + train_text, train_labels, args.allowed_labels ) - train_text, train_labels = final_train_text, final_train_labels - final_eval_text, final_eval_labels = [], [] - for text, label in zip(eval_text, eval_labels): - if label in args.allowed_labels: - final_eval_text.append(text) - final_eval_labels.append(label) - logger.info( - f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points." 
+ eval_text, eval_labels = _filter_labels( + eval_text, eval_labels, args.allowed_labels ) - eval_text, eval_labels = final_eval_text, final_eval_labels + + if args.pct_dataset < 1.0: + logger.info(f"Using {args.pct_dataset*100}% of the training set") + (train_text, train_labels), _ = _train_val_split( + train_text, train_labels, split_val=1.0 - args.pct_dataset + ) + train_examples_len = len(train_text) # data augmentation augmenter = augmenter_from_args(args) if augmenter: - # augment the training set - aug_train_text = augmenter.augment_many(train_text) - # flatten augmented examples and duplicate labels - flat_aug_train_text = [] - flat_aug_train_labels = [] - for i, examples in enumerate(aug_train_text): - for aug_ver in examples: - flat_aug_train_text.append(aug_ver) - flat_aug_train_labels.append(train_labels[i]) - train_text = flat_aug_train_text - train_labels = flat_aug_train_labels - - # augment the eval set - aug_eval_text = augmenter.augment_many(eval_text) - # flatten the augmented examples and duplicate labels - flat_aug_eval_text = [] - flat_aug_eval_labels = [] - for i, examples in enumerate(aug_eval_text): - for aug_ver in examples: - flat_aug_eval_text.append(aug_ver) - flat_aug_eval_labels.append(eval_labels[i]) - eval_text = flat_aug_eval_text - eval_labels = flat_aug_eval_labels + logger.info(f"Augmenting {len(train_text)} samples with {augmenter}") + train_text, train_labels = _data_augmentation( + train_text, train_labels, augmenter + ) # label_id_len = len(train_labels) label_set = set(train_labels) @@ -125,11 +328,9 @@ def train_model(args): else: args.do_regression = False - train_examples_len = len(train_text) - - if len(train_labels) != train_examples_len: + if len(train_labels) != len(train_text): raise ValueError( - f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})" + f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})" ) if len(eval_labels) != len(eval_text): raise ValueError( @@ -139,16 +340,13 @@ def train_model(args): model = model_from_args(args, args.num_labels) tokenizer = model.tokenizer - logger.info(f"Tokenizing training data. 
(len: {train_examples_len})") - train_text_ids = batch_encode(tokenizer, train_text) - logger.info(f"Tokenizing eval data (len: {len(eval_labels)})") - eval_text_ids = batch_encode(tokenizer, eval_text) - load_time = time.time() - logger.info(f"Loaded data and tokenized in {load_time-start_time}s") + attackCls = attack_from_args(args) + adversarial_training = attackCls is not None # multi-gpu training if num_gpus > 1: model = torch.nn.DataParallel(model) + model.tokenizer = model.module.tokenizer logger.info("Using torch.nn.DataParallel.") logger.info(f"Training model across {num_gpus} GPUs") @@ -199,110 +397,38 @@ def train_model(args): tb_writer = SummaryWriter(args.output_dir) - def is_writable_type(obj): - for ok_type in [bool, int, str, float]: - if isinstance(obj, ok_type): - return True - return False - - args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} - # Save original args to file args_save_path = os.path.join(args.output_dir, "train_args.json") - with open(args_save_path, "w", encoding="utf-8") as f: - f.write(json.dumps(args_dict, indent=2) + "\n") + _save_args(args, args_save_path) logger.info(f"Wrote original training args to {args_save_path}.") - tb_writer.add_hparams(args_dict, {}) + tb_writer.add_hparams( + {k: v for k, v in vars(args).items() if _is_writable_type(v)}, {} + ) # Start training logger.info("***** Running training *****") - logger.info(f"\tNum examples = {train_examples_len}") + if augmenter: + logger.info(f"\tNum original examples = {train_examples_len}") + logger.info(f"\tNum examples after augmentation = {len(train_text)}") + else: + logger.info(f"\tNum examples = {train_examples_len}") logger.info(f"\tBatch size = {args.batch_size}") logger.info(f"\tMax sequence length = {args.max_length}") logger.info(f"\tNum steps = {num_train_optimization_steps}") logger.info(f"\tNum epochs = {args.num_train_epochs}") logger.info(f"\tLearning rate = {args.learning_rate}") - train_input_ids = np.array(train_text_ids) - train_labels = np.array(train_labels) - train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels)) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.batch_size + eval_dataloader = _make_dataloader( + tokenizer, eval_text, eval_labels, args.batch_size ) - - eval_input_ids = np.array(eval_text_ids) - eval_labels = np.array(eval_labels) - eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels)) - eval_sampler = SequentialSampler(eval_data) - eval_dataloader = DataLoader( - eval_data, sampler=eval_sampler, batch_size=args.batch_size + train_dataloader = _make_dataloader( + tokenizer, train_text, train_labels, args.batch_size ) - def get_eval_score(): - model.eval() - correct = 0 - # total = 0 - logits = [] - labels = [] - for input_ids, batch_labels in eval_dataloader: - if isinstance(input_ids, dict): - ## HACK: dataloader collates dict backwards. 
This is a temporary - # workaround to get ids in the right shape - input_ids = { - k: torch.stack(v).T.to(device) for k, v in input_ids.items() - } - batch_labels = batch_labels.to(device) - - with torch.no_grad(): - batch_logits = textattack.shared.utils.model_predict(model, input_ids) - - logits.extend(batch_logits.cpu().squeeze().tolist()) - labels.extend(batch_labels) - - model.train() - logits = torch.tensor(logits) - labels = torch.tensor(labels) - - if args.do_regression: - pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels) - return pearson_correlation - else: - preds = logits.argmax(dim=1) - correct = (preds == labels).sum() - return float(correct) / len(labels) - - def save_model(): - model_to_save = ( - model.module if hasattr(model, "module") else model - ) # Only save the model itself - - # If we save using the predefined names, we can load using `from_pretrained` - output_model_file = os.path.join(args.output_dir, args.weights_name) - output_config_file = os.path.join(args.output_dir, args.config_name) - - torch.save(model_to_save.state_dict(), output_model_file) - try: - model_to_save.config.to_json_file(output_config_file) - except AttributeError: - # no config - pass - global_step = 0 tr_loss = 0 - def save_model_checkpoint(): - # Save model checkpoint - output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - # Take care of distributed/parallel training - model_to_save = model.module if hasattr(model, "module") else model - model_to_save.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) - logger.info(f"Checkpoint saved to {output_dir}.") - model.train() args.best_eval_score = 0 args.best_eval_score_epoch = 0 @@ -325,6 +451,21 @@ def train_model(args): for epoch in tqdm.trange( int(args.num_train_epochs), desc="Epoch", position=0, leave=False ): + if adversarial_training: + if epoch >= args.num_clean_epochs: + if (epoch - args.num_clean_epochs) % args.attack_period == 0: + # only generate a new adversarial training set every args.attack_period epochs + # after the clean epochs + logger.info("Attacking model to generate new training set...") + adv_train_text = _generate_adversarial_examples( + model, attackCls, list(zip(train_text, train_labels)) + ) + train_dataloader = _make_dataloader( + tokenizer, adv_train_text, train_labels, args.batch_size + ) + else: + logger.info(f"Running clean epoch {epoch+1}/{args.num_clean_epochs}") + prog_bar = tqdm.tqdm( train_dataloader, desc="Iteration", position=1, leave=False ) @@ -337,6 +478,7 @@ def train_model(args): input_ids = { k: torch.stack(v).T.to(device) for k, v in input_ids.items() } + logits = textattack.shared.utils.model_predict(model, input_ids) if args.do_regression: @@ -366,42 +508,46 @@ def train_model(args): and (args.checkpoint_steps > 0) and (global_step % args.checkpoint_steps) == 0 ): - save_model_checkpoint() + _save_model_checkpoint(model, args.output_dir, global_step) # Inc step counter. global_step += 1 # Check accuracy after each epoch. 
-        eval_score = get_eval_score()
-        tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
+        # during adversarial training, skip evaluation for the first args.num_clean_epochs
+        if not adversarial_training or epoch >= args.num_clean_epochs:
+            eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
+            tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
 
-        if args.checkpoint_every_epoch:
-            save_model_checkpoint()
+            if args.checkpoint_every_epoch:
+                _save_model_checkpoint(model, args.output_dir, global_step)
 
-        logger.info(
-            f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
-        )
-        if eval_score > args.best_eval_score:
-            args.best_eval_score = eval_score
-            args.best_eval_score_epoch = epoch
-            args.epochs_since_best_eval_score = 0
-            save_model()
-            logger.info(f"Best acc found. Saved model to {args.output_dir}.")
-        else:
-            args.epochs_since_best_eval_score += 1
-            if (args.early_stopping_epochs > 0) and (
-                args.epochs_since_best_eval_score > args.early_stopping_epochs
-            ):
-                logger.info(
-                    f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
-                )
-                break
+            logger.info(
+                f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
+            )
+            if eval_score > args.best_eval_score:
+                args.best_eval_score = eval_score
+                args.best_eval_score_epoch = epoch
+                args.epochs_since_best_eval_score = 0
+                _save_model(model, args.output_dir, args.weights_name, args.config_name)
+                logger.info(f"Best eval score found. Saved model to {args.output_dir}.")
+                _save_args(args, args_save_path)
+                logger.info(f"Saved updated args to {args_save_path}")
+            else:
+                args.epochs_since_best_eval_score += 1
+                if (args.early_stopping_epochs > 0) and (
+                    args.epochs_since_best_eval_score > args.early_stopping_epochs
+                ):
+                    logger.info(
+                        f"Stopping early since it's been {args.early_stopping_epochs} epochs since validation score increased"
+                    )
+                    break
 
     # read the saved model and report its eval performance
     logger.info("Finished training. Re-loading and evaluating model from disk.")
     model = model_from_args(args, args.num_labels)
     model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
-    eval_score = get_eval_score()
+    eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
     logger.info(
         f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
     )
@@ -418,8 +564,5 @@ def train_model(args):
     # Save a little readme with model info
     write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
 
-    # Save args to file
-    final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
-    with open(args_save_path, "w", encoding="utf-8") as f:
-        f.write(json.dumps(final_args_dict, indent=2) + "\n")
+    _save_args(args, args_save_path)
     logger.info(f"Wrote final training args to {args_save_path}.")
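Annotation: the training loop above interleaves clean and adversarial epochs. The first `--num-clean-epochs` epochs use the original data; after that, a fresh adversarial training set is generated every `--attack-period` epochs and reused in between. A standalone sketch of the schedule implied by the conditions `epoch >= args.num_clean_epochs` and `(epoch - args.num_clean_epochs) % args.attack_period == 0` (the concrete values here are hypothetical):

```python
num_train_epochs = 6  # hypothetical values, for illustration only
num_clean_epochs = 2  # --num-clean-epochs
attack_period = 2     # --attack-period

for epoch in range(num_train_epochs):
    if epoch < num_clean_epochs:
        print(f"epoch {epoch}: train on the clean dataset")
    elif (epoch - num_clean_epochs) % attack_period == 0:
        print(f"epoch {epoch}: regenerate adversarial examples and train on them")
    else:
        print(f"epoch {epoch}: reuse the most recent adversarial examples")
```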
diff --git a/textattack/commands/train_model/train_args_helpers.py b/textattack/commands/train_model/train_args_helpers.py
index 12c223eb..8626f40c 100644
--- a/textattack/commands/train_model/train_args_helpers.py
+++ b/textattack/commands/train_model/train_args_helpers.py
@@ -1,6 +1,7 @@
 import os
 
 import textattack
+from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES
 from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
 
 logger = textattack.shared.logger
@@ -133,6 +134,21 @@ def model_from_args(train_args, num_labels, model_path=None):
     return model
 
 
+def attack_from_args(args):
+    # note that this returns a recipe type, not an object
+    # (we need to wait to have access to the model to initialize)
+    attackCls = None
+    if args.attack:
+        if args.attack in ATTACK_RECIPE_NAMES:
+            attackCls = eval(ATTACK_RECIPE_NAMES[args.attack])
+        else:
+            raise ValueError(f"Unrecognized attack recipe: {args.attack}")
+
+        # check attack-related args
+        assert args.num_clean_epochs > 0, "--num-clean-epochs must be > 0"
+    return attackCls
+
+
 def augmenter_from_args(args):
     augmenter = None
     if args.augment:
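Annotation: `attack_from_args` deliberately hands back the recipe *class* rather than an instance, since the attack can only be constructed once a model exists. A minimal sketch of this deferred-construction pattern; `DeepWordBugGao2018` is just one existing entry in `ATTACK_RECIPE_NAMES`, used here for illustration:

```python
from textattack.attack_recipes import DeepWordBugGao2018

# hold on to the recipe type; we cannot build the attack without a model
attack_cls = DeepWordBugGao2018

# ... later, once a trained `model` is in hand:
# attack = attack_cls(model)
# adv_examples = attack.attack_dataset(dataset)
```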
" " ex: `glue:sst2` or `rotten_tomatoes`", ) + parser.add_argument( + "--pct-dataset", + type=float, + default=1.0, + help="Fraction of dataset to use during training ([0., 1.])", + ) parser.add_argument( "--dataset-train-split", "--train-split", @@ -77,7 +83,7 @@ class TrainModelCommand(TextAttackCommand): help="save model after this many steps (-1 for no checkpointing)", ) parser.add_argument( - "--checkpoint-every_epoch", + "--checkpoint-every-epoch", action="store_true", default=False, help="save model checkpoint after each epoch", @@ -89,6 +95,24 @@ class TrainModelCommand(TextAttackCommand): default=100, help="Total number of epochs to train for", ) + parser.add_argument( + "--attack", + type=str, + default=None, + help="Attack recipe to use (enables adversarial training)", + ) + parser.add_argument( + "--num-clean-epochs", + type=int, + default=1, + help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)", + ) + parser.add_argument( + "--attack-period", + type=int, + default=1, + help="How often (in epochs) to generate a new adversarial training set (N/A if --attack unspecified)", + ) parser.add_argument( "--augment", type=str, default=None, help="Augmentation recipe to use", ) diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py index 47fdf07b..1f4bd3c0 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py @@ -28,7 +28,15 @@ class QueryHandler: except Exception: probs = [] for s, w in zip(sentences, swapped_words): - probs.append(self.try_query([s], [w], batch_size=1)[0]) + try: + probs.append(self.try_query([s], [w], batch_size=1)[0]) + except RuntimeError: + print( + "WARNING: got runtime error trying languag emodel on language model w s/w", + s, + w, + ) + probs.append(float("-inf")) return probs def try_query(self, sentences, swapped_words, batch_size=32): @@ -61,6 +69,8 @@ class QueryHandler: hidden = self.model.init_hidden(len(batch)) source = word_idxs[:-1, :] target = word_idxs[1:, :] + if (not len(source)) or not len(hidden): + return [float("-inf")] * len(batch) decode, hidden = self.model(source, hidden) decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1) for i in range(len(batch)): diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index bc8418b4..010a1a46 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -32,12 +32,13 @@ class GoalFunction(ABC): use_cache=True, query_budget=float("inf"), model_batch_size=32, - model_cache_size=2 ** 18, + model_cache_size=2 ** 20, ): validators.validate_model_goal_function_compatibility( self.__class__, model.__class__ ) self.model = model + self.model.eval() self.maximizable = maximizable self.tokenizer = tokenizer if not self.tokenizer: diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py index e6347155..8a55fce8 100644 --- a/textattack/loggers/csv_logger.py +++ b/textattack/loggers/csv_logger.py @@ -2,14 +2,10 @@ import csv import pandas as pd -# from textattack.attack_results import FailedAttackResult -from textattack.shared import logger +from textattack.shared import AttackedText, logger from .logger import 
diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py
index e6347155..8a55fce8 100644
--- a/textattack/loggers/csv_logger.py
+++ b/textattack/loggers/csv_logger.py
@@ -2,14 +2,10 @@ import csv
 
 import pandas as pd
 
-# from textattack.attack_results import FailedAttackResult
-from textattack.shared import logger
+from textattack.shared import AttackedText, logger
 
 from .logger import Logger
 
-# import os
-# import sys
-
 
 class CSVLogger(Logger):
     """Logs attack results to a CSV."""
@@ -22,6 +18,8 @@ class CSVLogger(Logger):
 
     def log_attack_result(self, result):
         original_text, perturbed_text = result.diff_color(self.color_method)
+        original_text = original_text.replace("\n", AttackedText.SPLIT_TOKEN)
+        perturbed_text = perturbed_text.replace("\n", AttackedText.SPLIT_TOKEN)
         result_type = result.__class__.__name__.replace("AttackResult", "")
         row = {
             "original_text": original_text,
diff --git a/textattack/models/helpers/lstm_for_classification.py b/textattack/models/helpers/lstm_for_classification.py
index accff85a..0cef2149 100644
--- a/textattack/models/helpers/lstm_for_classification.py
+++ b/textattack/models/helpers/lstm_for_classification.py
@@ -31,6 +31,7 @@ class LSTMForClassification(nn.Module):
             # so if that's all we have, this will display a warning.
             dropout = 0
         self.drop = nn.Dropout(dropout)
+        self.emb_layer_trainable = emb_layer_trainable
         self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
         self.word2id = self.emb_layer.word2id
         self.encoder = nn.LSTM(
diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py
index 3794fb26..090317df 100644
--- a/textattack/search_methods/greedy_word_swap_wir.py
+++ b/textattack/search_methods/greedy_word_swap_wir.py
@@ -58,7 +58,9 @@ class GreedyWordSwapWIR(SearchMethod):
                 initial_result, leave_one_texts
             )
 
-            softmax_saliency_scores = softmax(torch.Tensor(saliency_scores)).numpy()
+            softmax_saliency_scores = softmax(
+                torch.Tensor(saliency_scores), dim=0
+            ).numpy()
 
             # compute the largest change in score we can find by swapping each word
             delta_ps = []
@@ -72,9 +74,7 @@ class GreedyWordSwapWIR(SearchMethod):
                     # no valid synonym substitutions for this word
                     delta_ps.append(0.0)
                     continue
-                swap_results, _ = self.get_goal_results(
-                    transformed_text_candidates, initial_result.output
-                )
+                swap_results, _ = self.get_goal_results(transformed_text_candidates)
                 score_change = [result.score for result in swap_results]
                 max_score_change = np.max(score_change)
                 delta_ps.append(max_score_change)
diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py
index 27ed6af9..d41849f1 100644
--- a/textattack/shared/attack.py
+++ b/textattack/shared/attack.py
@@ -38,7 +38,7 @@ class Attack:
         constraints=[],
         transformation=None,
         search_method=None,
-        constraint_cache_size=2 ** 18,
+        constraint_cache_size=2 ** 20,
     ):
         """Initialize an attack object.
 
@@ -150,7 +150,8 @@ class Attack:
         self, transformed_texts, current_text, original_text=None
     ):
         """Filters a list of potential transformed texts based on
-        ``self.constraints`` Checks cache first.
+        ``self.constraints``. Utilizes an LRU cache to avoid
+        recomputing common transformations.
 
         Args:
             transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
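Annotation: both cache bumps in this patch (`model_cache_size` and `constraint_cache_size`, each 2 ** 18 → 2 ** 20) quadruple the number of retained entries. The caches are least-recently-used mappings; a toy demonstration of the eviction order with the `lru-dict` package already pinned in requirements.txt (the capacity of 2 is purely illustrative):

```python
from lru import LRU  # provided by the lru-dict package

cache = LRU(2)       # stands in for LRU(2 ** 20)
cache["a"] = 1
cache["b"] = 2
_ = cache["a"]       # touching "a" makes it most recently used
cache["c"] = 3       # evicts "b", the least recently used entry
print(cache.keys())  # ['c', 'a']
```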
diff --git a/textattack/shared/validators.py b/textattack/shared/validators.py
index 556fb6f3..657ff7d1 100644
--- a/textattack/shared/validators.py
+++ b/textattack/shared/validators.py
@@ -14,8 +14,7 @@ from . import logger
 # A list of goal functions and the corresponding available models.
 MODELS_BY_GOAL_FUNCTIONS = {
     (TargetedClassification, UntargetedClassification, InputReduction): [
-        r"^textattack.models.classification.*",
-        r"^textattack.models.entailment.*",
+        r"^textattack.models.lstm_for_classification.*",
         r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
     ],
     (NonOverlappingOutput,): [
diff --git a/textattack/transformations/word_swap_gradient_based.py b/textattack/transformations/word_swap_gradient_based.py
index ba42410f..76deb344 100644
--- a/textattack/transformations/word_swap_gradient_based.py
+++ b/textattack/transformations/word_swap_gradient_based.py
@@ -52,6 +52,7 @@ class WordSwapGradientBased(Transformation):
             word_index (int): index of the word to replace
         """
         self.model.train()
+        self.model.emb_layer.embedding.weight.requires_grad = True
 
         lookup_table = self.model.lookup_table.to(utils.device)
         lookup_table_transpose = lookup_table.transpose(0, 1)
@@ -105,6 +106,9 @@ class WordSwapGradientBased(Transformation):
                 break
 
         self.model.eval()
+        self.model.emb_layer.embedding.weight.requires_grad = (
+            self.model.emb_layer_trainable
+        )
         return candidates
 
     def _call_model(self, text_ids):
diff --git a/textattack/transformations/word_swap_hownet.py b/textattack/transformations/word_swap_hownet.py
index c351e241..ff959ff9 100644
--- a/textattack/transformations/word_swap_hownet.py
+++ b/textattack/transformations/word_swap_hownet.py
@@ -53,10 +53,11 @@ class WordSwapHowNet(WordSwap):
 
     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
-        words_str = " ".join(words)
-        word_list, pos_list = zip_flair_result(
-            self._flair_pos_tagger.predict(words_str)[0]
-        )
+        sentence = Sentence(" ".join(words))
+        # flair 0.5.1+ tags the Sentence object in place
+        self._flair_pos_tagger.predict(sentence)
+        word_list, pos_list = zip_flair_result(sentence)
+
         assert len(words) == len(
             word_list
         ), "Part-of-speech tagger returned incorrect number of tags"
diff --git a/textattack/transformations/word_swap_masked_lm.py b/textattack/transformations/word_swap_masked_lm.py
index 9eecfd6a..85e501df 100644
--- a/textattack/transformations/word_swap_masked_lm.py
+++ b/textattack/transformations/word_swap_masked_lm.py
@@ -93,7 +93,7 @@ class WordSwapMaskedLM(WordSwap):
         replacement_words = []
         for id in top_ids:
             token = self._lm_tokenizer.convert_ids_to_tokens(id)
-            if utils.is_one_word(token):
+            if utils.is_one_word(token) and not check_if_subword(token):
                 replacement_words.append(token)
         return replacement_words
 
@@ -141,7 +141,7 @@ class WordSwapMaskedLM(WordSwap):
             replacement_words = []
             for id in top_preds:
                 token = self._lm_tokenizer.convert_ids_to_tokens(id)
-                if utils.is_one_word(token):
+                if utils.is_one_word(token) and not check_if_subword(token):
                     replacement_words.append(token)
             return replacement_words
         else:
@@ -231,3 +231,8 @@ def recover_word_case(word, reference_word):
     else:
         # if other, just do not alter the word's case
         return word
+
+
+def check_if_subword(text):
+    """Returns True for subword tokens, e.g. BERT-style "##"-prefixed word pieces."""
+    return "##" in text
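Annotation: the augmentation changes at the top of this patch are easiest to see end to end: `pct_words_to_swap` now always swaps at least one word (the new `max(..., 1)`), and the new `__repr__` prints the transformation and constraint stack. A quick sketch against the public API; the recipe choice and input sentence are arbitrary examples:

```python
from textattack.augmentation import EmbeddingAugmenter

augmenter = EmbeddingAugmenter(pct_words_to_swap=0.1, transformations_per_example=2)

# the new Augmenter.__repr__ lists the transformation and any constraints
print(augmenter)

# even though 0.1 * 4 words rounds down to 0, max(..., 1) guarantees one swap
print(augmenter.augment("The quick brown fox"))
```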