From 4173ab10aa1ae85e604be3612175acfc635e8d37 Mon Sep 17 00:00:00 2001 From: jakegrigsby Date: Wed, 8 Jul 2020 23:36:24 -0400 Subject: [PATCH 01/21] adv_train (WIP) --- .../commands/train_model/run_training.py | 206 ++++++++---------- .../train_model/train_args_helpers.py | 12 + .../train_model/train_model_command.py | 3 + textattack/shared/validators.py | 3 +- 4 files changed, 107 insertions(+), 117 deletions(-) diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index 27165003..474969aa 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -16,17 +16,64 @@ from .train_args_helpers import ( augmenter_from_args, dataset_from_args, model_from_args, + attack_from_args, write_readme, ) device = textattack.shared.utils.device logger = textattack.shared.logger +def filter_labels(text, labels, allowed_labels): + final_text, final_labels = [], [] + for text, label in zip(text, labels): + if label in allowed_labels: + final_text.append(text) + final_labels.append(label) + return final_text, final_labels + + +def get_eval_score(model, eval_dataloader, do_regression): + model.eval() + correct = 0 + total = 0 + logits = [] + labels = [] + for input_ids, batch_labels in eval_dataloader: + if isinstance(input_ids, dict): + ## HACK: dataloader collates dict backwards. This is a temporary + # workaround to get ids in the right shape + input_ids = { + k: torch.stack(v).T.to(device) for k, v in input_ids.items() + } + batch_labels = batch_labels.to(device) + + with torch.no_grad(): + batch_logits = textattack.shared.utils.model_predict(model, input_ids) + + logits.extend(batch_logits.cpu().squeeze().tolist()) + labels.extend(batch_labels) + + model.train() + logits = torch.tensor(logits) + labels = torch.tensor(labels) + + if do_regression: + pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels) + return pearson_correlation + else: + preds = logits.argmax(dim=1) + correct = (preds == labels).sum() + return float(correct) / len(labels) def make_directories(output_dir): if not os.path.exists(output_dir): os.makedirs(output_dir) +def is_writable_type(obj): + for ok_type in [bool, int, str, float]: + if isinstance(obj, ok_type): + return True + return False def batch_encode(tokenizer, text_list): if hasattr(tokenizer, "batch_encode"): @@ -34,6 +81,28 @@ def batch_encode(tokenizer, text_list): else: return [tokenizer.encode(text_input) for text_input in text_list] +def make_dataloader(tokenizer, text, labels, batch_size): + text_ids = batch_encode(tokenizer, text) + input_ids = np.array(text_ids) + labels = np.array(labels) + data = list((ids, label) for ids, label in zip(input_ids, labels)) + sampler = RandomSampler(data) + dataloader = DataLoader( + data, sampler=sampler, batch_size=batch_size, + ) + return dataloader + +def data_augmentation(text, labels, augmenter): + aug_text = augmenter.augment_many(text) + # flatten augmented examples and duplicate labels + flat_aug_text = [] + flat_aug_labels = [] + for i, examples in enumerate(aug_text): + for aug_ver in examples: + flat_aug_text.append(aug_ver) + flat_aug_labels.append(labels[i]) + return flat_aug_text, flat_aug_labels + def train_model(args): logger.warn( @@ -55,7 +124,6 @@ def train_model(args): if args.enable_wandb: global wandb import wandb - wandb.init(sync_tensorboard=True) # Get list of text and list of label (integers) from disk. 
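Note on the `## HACK` comment in `get_eval_score` above: with a dataset whose
inputs are dicts of token-id lists, PyTorch's default collation produces a dict
whose values are sequence-major -- a list of per-token-position tensors, each of
shape (batch,) -- which is why every value is stacked and transposed back to
(batch, seq_len). A minimal standalone sketch of the behavior (illustrative
only, not part of the patch):

    import torch
    from torch.utils.data import DataLoader

    data = [({"input_ids": [1, 2, 3]}, 0), ({"input_ids": [4, 5, 6]}, 1)]
    loader = DataLoader(data, batch_size=2)
    batch_inputs, batch_labels = next(iter(loader))
    # default collation yields {"input_ids": [tensor([1, 4]), tensor([2, 5]), tensor([3, 6])]}
    # -- a list over token positions, not over examples. Stack + transpose recovers
    # the (batch, seq_len) layout the model expects:
    input_ids = torch.stack(batch_inputs["input_ids"]).T  # tensor([[1, 2, 3], [4, 5, 6]])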
@@ -64,58 +132,18 @@ def train_model(args): # Filter labels if args.allowed_labels: logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.") - final_train_text, final_train_labels = [], [] - for text, label in zip(train_text, train_labels): - if label in args.allowed_labels: - final_train_text.append(text) - final_train_labels.append(label) - logger.info( - f"Filtered {len(train_text)} train samples to {len(final_train_text)} points." - ) - train_text, train_labels = final_train_text, final_train_labels - final_eval_text, final_eval_labels = [], [] - for text, label in zip(eval_text, eval_labels): - if label in args.allowed_labels: - final_eval_text.append(text) - final_eval_labels.append(label) - logger.info( - f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points." - ) - eval_text, eval_labels = final_eval_text, final_eval_labels - + train_text, train_labels = filter_labels(train_text, train_labels, args.allowed_labels) + eval_text, eval_labels = filter_labels(eval_text, eval_labels, args.allowed_labels) + # data augmentation augmenter = augmenter_from_args(args) if augmenter: - # augment the training set - aug_train_text = augmenter.augment_many(train_text) - # flatten augmented examples and duplicate labels - flat_aug_train_text = [] - flat_aug_train_labels = [] - for i, examples in enumerate(aug_train_text): - for aug_ver in examples: - flat_aug_train_text.append(aug_ver) - flat_aug_train_labels.append(train_labels[i]) - train_text = flat_aug_train_text - train_labels = flat_aug_train_labels - - # augment the eval set - aug_eval_text = augmenter.augment_many(eval_text) - # flatten the augmented examples and duplicate labels - flat_aug_eval_text = [] - flat_aug_eval_labels = [] - for i, examples in enumerate(aug_eval_text): - for aug_ver in examples: - flat_aug_eval_text.append(aug_ver) - flat_aug_eval_labels.append(eval_labels[i]) - eval_text = flat_aug_eval_text - eval_labels = flat_aug_eval_labels + train_text, train_labels = data_augmentation(train_text, train_labels, augmenter) label_id_len = len(train_labels) label_set = set(train_labels) args.num_labels = len(label_set) - logger.info( - f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})" - ) + logger.info(f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})") if isinstance(train_labels[0], float): # TODO come up with a more sophisticated scheme for knowing when to do regression @@ -128,27 +156,19 @@ def train_model(args): train_examples_len = len(train_text) if len(train_labels) != train_examples_len: - raise ValueError( - f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})" - ) + raise ValueError(f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})") if len(eval_labels) != len(eval_text): - raise ValueError( - f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})" - ) + raise ValueError(f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})") model = model_from_args(args, args.num_labels) tokenizer = model.tokenizer - logger.info(f"Tokenizing training data. 
(len: {train_examples_len})") - train_text_ids = batch_encode(tokenizer, train_text) - logger.info(f"Tokenizing eval data (len: {len(eval_labels)})") - eval_text_ids = batch_encode(tokenizer, eval_text) - load_time = time.time() - logger.info(f"Loaded data and tokenized in {load_time-start_time}s") + attack_t = attack_from_args(args) # multi-gpu training if num_gpus > 1: model = torch.nn.DataParallel(model) + model.tokenizer = model.module.tokenizer logger.info("Using torch.nn.DataParallel.") logger.info(f"Training model across {num_gpus} GPUs") @@ -196,12 +216,6 @@ def train_model(args): tb_writer = SummaryWriter(args.output_dir) - def is_writable_type(obj): - for ok_type in [bool, int, str, float]: - if isinstance(obj, ok_type): - return True - return False - args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} # Save original args to file @@ -221,55 +235,9 @@ def train_model(args): logger.info(f"\tNum epochs = {args.num_train_epochs}") logger.info(f"\tLearning rate = {args.learning_rate}") - train_input_ids = np.array(train_text_ids) - train_labels = np.array(train_labels) - train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels)) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.batch_size - ) - - eval_input_ids = np.array(eval_text_ids) - eval_labels = np.array(eval_labels) - eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels)) - eval_sampler = SequentialSampler(eval_data) - eval_dataloader = DataLoader( - eval_data, sampler=eval_sampler, batch_size=args.batch_size - ) - - def get_eval_score(): - model.eval() - correct = 0 - total = 0 - logits = [] - labels = [] - for input_ids, batch_labels in eval_dataloader: - if isinstance(input_ids, dict): - ## HACK: dataloader collates dict backwards. 
This is a temporary - # workaround to get ids in the right shape - input_ids = { - k: torch.stack(v).T.to(device) for k, v in input_ids.items() - } - batch_labels = batch_labels.to(device) - - with torch.no_grad(): - batch_logits = textattack.shared.utils.model_predict(model, input_ids) - - logits.extend(batch_logits.cpu().squeeze().tolist()) - labels.extend(batch_labels) - - model.train() - logits = torch.tensor(logits) - labels = torch.tensor(labels) - - if args.do_regression: - pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels) - return pearson_correlation - else: - preds = logits.argmax(dim=1) - correct = (preds == labels).sum() - return float(correct) / len(labels) - + eval_dataloader = make_dataloader(tokenizer, eval_text, eval_labels, args.batch_size) + train_dataloader = make_dataloader(tokenizer, train_text, train_labels, args.batch_size) + def save_model(): model_to_save = ( model.module if hasattr(model, "module") else model @@ -322,6 +290,14 @@ def train_model(args): for epoch in tqdm.trange( int(args.num_train_epochs), desc="Epoch", position=0, leave=False ): + if attack_t and epoch > 0: + logger.info("Attacking model to generating new training set...") + attack = attack_t(model) + adv_train_text = [] + for adv_ex in tqdm.tqdm(attack.attack_dataset(list(zip(train_text, train_labels))), desc="Attack", total=len(train_text)): + adv_train_text.append(adv_ex.perturbed_text()) + train_dataloader = make_dataloader(tokenizer, adv_train_text, train_labels, args.batch_size) + prog_bar = tqdm.tqdm( train_dataloader, desc="Iteration", position=1, leave=False ) @@ -369,7 +345,7 @@ def train_model(args): global_step += 1 # Check accuracy after each epoch. - eval_score = get_eval_score() + eval_score = get_eval_score(model, eval_dataloader, args.do_regression) tb_writer.add_scalar("epoch_eval_score", eval_score, global_step) if args.checkpoint_every_epoch: @@ -398,7 +374,7 @@ def train_model(args): logger.info("Finished training. 
Re-loading and evaluating model from disk.") model = model_from_args(args, args.num_labels) model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name))) - eval_score = get_eval_score() + eval_score = get_eval_score(model, eval_dataloader, args.do_regression) logger.info( f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%" ) diff --git a/textattack/commands/train_model/train_args_helpers.py b/textattack/commands/train_model/train_args_helpers.py index 7f40d6a4..6ffdb6f1 100644 --- a/textattack/commands/train_model/train_args_helpers.py +++ b/textattack/commands/train_model/train_args_helpers.py @@ -2,6 +2,7 @@ import os import textattack from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES +from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES logger = textattack.shared.logger @@ -131,6 +132,17 @@ def model_from_args(train_args, num_labels, model_path=None): return model +def attack_from_args(args): + # note that this returns a recipe type, not an object + # (we need to wait to have access to the model to initialize) + attack_t = None + if args.attack: + if args.attack in ATTACK_RECIPE_NAMES: + attack_t = eval(ATTACK_RECIPE_NAMES[args.attack]) + else: + raise ValueError(f"Unrecognized attack recipe: {args.attack}") + return attack_t + def augmenter_from_args(args): augmenter = None diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py index 58df323f..5379489b 100644 --- a/textattack/commands/train_model/train_model_command.py +++ b/textattack/commands/train_model/train_model_command.py @@ -90,6 +90,9 @@ class TrainModelCommand(TextAttackCommand): default=100, help="Total number of epochs to train for", ) + parser.add_argument( + "--attack", type=str, default=None, help="Attack recipe to use (enables adversarial training)" + ) parser.add_argument( "--augment", type=str, default=None, help="Augmentation recipe to use", ) diff --git a/textattack/shared/validators.py b/textattack/shared/validators.py index 612c3ff7..b76b994a 100644 --- a/textattack/shared/validators.py +++ b/textattack/shared/validators.py @@ -9,8 +9,7 @@ from . import logger # A list of goal functions and the corresponding available models. 
MODELS_BY_GOAL_FUNCTIONS = { (TargetedClassification, UntargetedClassification, InputReduction): [ - r"^textattack.models.classification.*", - r"^textattack.models.entailment.*", + r"^textattack.models.lstm_for_classification.*", r"^transformers.modeling_\w*\.\w*ForSequenceClassification$", ], (NonOverlappingOutput,): [ From a9db0ffda5a1137f53effa59057916248ed5179a Mon Sep 17 00:00:00 2001 From: jakegrigsby Date: Thu, 9 Jul 2020 12:18:48 -0400 Subject: [PATCH 02/21] updates to adv training --- textattack/augmentation/augmenter.py | 3 + .../commands/train_model/run_training.py | 162 ++++++++---------- 2 files changed, 78 insertions(+), 87 deletions(-) diff --git a/textattack/augmentation/augmenter.py b/textattack/augmentation/augmenter.py index 70e9b9f3..3f561562 100644 --- a/textattack/augmentation/augmenter.py +++ b/textattack/augmentation/augmenter.py @@ -137,3 +137,6 @@ class Augmenter: all_text_list.extend([text] + augmented_texts) all_id_list.extend([_id] * (1 + len(augmented_texts))) return all_text_list, all_id_list + + def __repr__(self): + return type(self).__name__ diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index 474969aa..95b0020f 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -23,6 +23,12 @@ from .train_args_helpers import ( device = textattack.shared.utils.device logger = textattack.shared.logger +def save_args(args, save_path): + # Save args to file + final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} + with open(save_path, "w", encoding="utf-8") as f: + f.write(json.dumps(final_args_dict, indent=2) + "\n") + def filter_labels(text, labels, allowed_labels): final_text, final_labels = [], [] for text, label in zip(text, labels): @@ -31,6 +37,29 @@ def filter_labels(text, labels, allowed_labels): final_labels.append(label) return final_text, final_labels +def save_model_checkpoint(model, output_dir, global_step): + # Save model checkpoint + output_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, "training_args.bin")) + +def save_model(model, output_dir, weights_name, config_name): + model_to_save = (model.module if hasattr(model, "module") else model) + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(output_dir, weights_name) + output_config_file = os.path.join(output_dir, config_name) + + torch.save(model_to_save.state_dict(), output_model_file) + try: + model_to_save.config.to_json_file(output_config_file) + except AttributeError: + # no config + pass def get_eval_score(model, eval_dataloader, do_regression): model.eval() @@ -42,9 +71,7 @@ def get_eval_score(model, eval_dataloader, do_regression): if isinstance(input_ids, dict): ## HACK: dataloader collates dict backwards. 
This is a temporary # workaround to get ids in the right shape - input_ids = { - k: torch.stack(v).T.to(device) for k, v in input_ids.items() - } + input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()} batch_labels = batch_labels.to(device) with torch.no_grad(): @@ -87,22 +114,26 @@ def make_dataloader(tokenizer, text, labels, batch_size): labels = np.array(labels) data = list((ids, label) for ids, label in zip(input_ids, labels)) sampler = RandomSampler(data) - dataloader = DataLoader( - data, sampler=sampler, batch_size=batch_size, - ) + dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size) return dataloader def data_augmentation(text, labels, augmenter): - aug_text = augmenter.augment_many(text) - # flatten augmented examples and duplicate labels - flat_aug_text = [] - flat_aug_labels = [] - for i, examples in enumerate(aug_text): - for aug_ver in examples: - flat_aug_text.append(aug_ver) - flat_aug_labels.append(labels[i]) - return flat_aug_text, flat_aug_labels + aug_text = augmenter.augment_many(text) + # flatten augmented examples and duplicate labels + flat_aug_text = [] + flat_aug_labels = [] + for i, examples in enumerate(aug_text): + for aug_ver in examples: + flat_aug_text.append(aug_ver) + flat_aug_labels.append(labels[i]) + return flat_aug_text, flat_aug_labels +def attack_model(model, attack_t, dataset): + attack = attack_t(model) + adv_train_text = [] + for adv_ex in tqdm.tqdm(attack.attack_dataset(dataset), desc="Attack", total=len(dataset)): + adv_train_text.append(adv_ex.perturbed_text()) + return adv_train_text def train_model(args): logger.warn( @@ -131,14 +162,17 @@ def train_model(args): # Filter labels if args.allowed_labels: - logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.") train_text, train_labels = filter_labels(train_text, train_labels, args.allowed_labels) eval_text, eval_labels = filter_labels(eval_text, eval_labels, args.allowed_labels) + train_examples_len = len(train_text) + # data augmentation augmenter = augmenter_from_args(args) if augmenter: + logger.info(f"Augmenting {len(train_text)} samples with {augmenter}") train_text, train_labels = data_augmentation(train_text, train_labels, augmenter) + logger.info(f"Using augmented training set of size {len(train_text)}") label_id_len = len(train_labels) label_set = set(train_labels) @@ -153,10 +187,9 @@ def train_model(args): else: args.do_regression = False - train_examples_len = len(train_text) - if len(train_labels) != train_examples_len: - raise ValueError(f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})") + if len(train_labels) != len(train_text): + raise ValueError(f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})") if len(eval_labels) != len(eval_text): raise ValueError(f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})") @@ -187,23 +220,11 @@ def train_model(args): param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ - { - "params": [ - p for n, p in param_optimizer if not any(nd in n for nd in no_decay) - ], - "weight_decay": 0.01, - }, - { - "params": [ - p for n, p in param_optimizer if any(nd in n for nd in no_decay) - ], - "weight_decay": 0.0, - }, - ] + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01,}, + {"params": [p for n, p in 
param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0,}, + ] - optimizer = transformers.optimization.AdamW( - optimizer_grouped_parameters, lr=args.learning_rate - ) + optimizer = transformers.optimization.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) scheduler = transformers.optimization.get_linear_schedule_with_warmup( optimizer, @@ -216,19 +237,20 @@ def train_model(args): tb_writer = SummaryWriter(args.output_dir) - args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} - # Save original args to file args_save_path = os.path.join(args.output_dir, "train_args.json") - with open(args_save_path, "w", encoding="utf-8") as f: - f.write(json.dumps(args_dict, indent=2) + "\n") + save_args(args, args_save_path) logger.info(f"Wrote original training args to {args_save_path}.") - tb_writer.add_hparams(args_dict, {}) + tb_writer.add_hparams({k: v for k, v in vars(args).items() if is_writable_type(v)}, {}) # Start training logger.info("***** Running training *****") - logger.info(f"\tNum examples = {train_examples_len}") + if augmenter: + logger.info(f"\tNum original examples = {train_examples_len}") + logger.info(f"\tNum examples after augmentation = {len(train_text)}") + else: + logger.info(f"\tNum examples = {train_examples_len}") logger.info(f"\tBatch size = {args.batch_size}") logger.info(f"\tMax sequence length = {args.max_length}") logger.info(f"\tNum steps = {num_train_optimization_steps}") @@ -238,36 +260,9 @@ def train_model(args): eval_dataloader = make_dataloader(tokenizer, eval_text, eval_labels, args.batch_size) train_dataloader = make_dataloader(tokenizer, train_text, train_labels, args.batch_size) - def save_model(): - model_to_save = ( - model.module if hasattr(model, "module") else model - ) # Only save the model itself - - # If we save using the predefined names, we can load using `from_pretrained` - output_model_file = os.path.join(args.output_dir, args.weights_name) - output_config_file = os.path.join(args.output_dir, args.config_name) - - torch.save(model_to_save.state_dict(), output_model_file) - try: - model_to_save.config.to_json_file(output_config_file) - except AttributeError: - # no config - pass - global_step = 0 tr_loss = 0 - - def save_model_checkpoint(): - # Save model checkpoint - output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - # Take care of distributed/parallel training - model_to_save = model.module if hasattr(model, "module") else model - model_to_save.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) - logger.info(f"Checkpoint saved to {output_dir}.") - + model.train() args.best_eval_score = 0 args.best_eval_score_epoch = 0 @@ -287,15 +282,10 @@ def train_model(args): else: loss_fct = torch.nn.CrossEntropyLoss() - for epoch in tqdm.trange( - int(args.num_train_epochs), desc="Epoch", position=0, leave=False - ): + for epoch in tqdm.trange(int(args.num_train_epochs), desc="Epoch", position=0, leave=False): if attack_t and epoch > 0: - logger.info("Attacking model to generating new training set...") - attack = attack_t(model) - adv_train_text = [] - for adv_ex in tqdm.tqdm(attack.attack_dataset(list(zip(train_text, train_labels))), desc="Attack", total=len(train_text)): - adv_train_text.append(adv_ex.perturbed_text()) + logger.info("Attacking model to generate new training set...") + adv_train_text = attack_model(model, attack_t, list(zip(train_text, train_labels))) 
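+            # The training dataloader is rebuilt below from the regenerated
+            # adversarial text so that each epoch trains on perturbations of the
+            # current model weights; the original labels are reused because the
+            # attack perturbs only the input text, never the targets.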
train_dataloader = make_dataloader(tokenizer, adv_train_text, train_labels, args.batch_size) prog_bar = tqdm.tqdm( @@ -307,9 +297,8 @@ def train_model(args): if isinstance(input_ids, dict): ## HACK: dataloader collates dict backwards. This is a temporary # workaround to get ids in the right shape - input_ids = { - k: torch.stack(v).T.to(device) for k, v in input_ids.items() - } + input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()} + logits = textattack.shared.utils.model_predict(model, input_ids) if args.do_regression: @@ -339,7 +328,7 @@ def train_model(args): and (args.checkpoint_steps > 0) and (global_step % args.checkpoint_steps) == 0 ): - save_model_checkpoint() + save_model_checkpoint(model, args.output_dir, global_step) # Inc step counter. global_step += 1 @@ -349,7 +338,7 @@ def train_model(args): tb_writer.add_scalar("epoch_eval_score", eval_score, global_step) if args.checkpoint_every_epoch: - save_model_checkpoint() + save_model_checkpoint(model, args.output_dir, args.global_step) logger.info( f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%" @@ -358,8 +347,10 @@ def train_model(args): args.best_eval_score = eval_score args.best_eval_score_epoch = epoch args.epochs_since_best_eval_score = 0 - save_model() + save_model(model, args.output_dir, args.weights_name, args.config_name) logger.info(f"Best acc found. Saved model to {args.output_dir}.") + save_args(args, args_save_path) + logger.info(f"Saved updated args to {args_save_path}") else: args.epochs_since_best_eval_score += 1 if (args.early_stopping_epochs > 0) and ( @@ -391,8 +382,5 @@ def train_model(args): # Save a little readme with model info write_readme(args, args.best_eval_score, args.best_eval_score_epoch) - # Save args to file - final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} - with open(args_save_path, "w", encoding="utf-8") as f: - f.write(json.dumps(final_args_dict, indent=2) + "\n") + save_args(args, args_save_path) logger.info(f"Wrote final training args to {args_save_path}.") From 1387478fb67af462d27b0905cb844bd146610d6a Mon Sep 17 00:00:00 2001 From: "Jin Yong (Jeffrey) Yoo" Date: Fri, 10 Jul 2020 16:55:57 +0900 Subject: [PATCH 03/21] Update greedy_word_swap_wir.py --- textattack/search_methods/greedy_word_swap_wir.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py index b6f6612a..ed46da7e 100644 --- a/textattack/search_methods/greedy_word_swap_wir.py +++ b/textattack/search_methods/greedy_word_swap_wir.py @@ -59,7 +59,7 @@ class GreedyWordSwapWIR(SearchMethod): initial_result, leave_one_texts ) - softmax_saliency_scores = softmax(torch.Tensor(saliency_scores)).numpy() + softmax_saliency_scores = softmax(torch.Tensor(saliency_scores), dim=0).numpy() # compute the largest change in score we can find by swapping each word delta_ps = [] @@ -73,9 +73,7 @@ class GreedyWordSwapWIR(SearchMethod): # no valid synonym substitutions for this word delta_ps.append(0.0) continue - swap_results, _ = self.get_goal_results( - transformed_text_candidates, initial_result.output - ) + swap_results, _ = self.get_goal_results(transformed_text_candidates) score_change = [result.score for result in swap_results] max_score_change = np.max(score_change) delta_ps.append(max_score_change) From f875f65eedef9082011fc6f5f785890172eb34b7 Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Fri, 10 Jul 2020 09:27:33 -0400 Subject: 
[PATCH 04/21] attempting to fix fast alzantot

---
 .../learning_to_write/language_model_helpers.py   | 3 +++
 textattack/shared/attack.py                       | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py
index 1d03c7e1..b25305d8 100644
--- a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py
+++ b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py
@@ -52,6 +52,9 @@ class QueryHandler:
             raw_idx_list[t].append(word_idxs[t])
         orig_num_idxs = len(raw_idx_list)
         raw_idx_list = [x for x in raw_idx_list if len(x)]
+        if not len(raw_idx_list):
+            # if no inputs are long enough to check, return inf for all
+            return [float("-inf")] * len(batch)
         num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
         all_raw_idxs = torch.tensor(
             raw_idx_list, device=self.device, dtype=torch.long
diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py
index 0ceb1be6..e8c19038 100644
--- a/textattack/shared/attack.py
+++ b/textattack/shared/attack.py
@@ -38,7 +38,7 @@ class Attack:
         constraints=[],
         transformation=None,
         search_method=None,
-        constraint_cache_size=2 ** 18,
+        constraint_cache_size=2 ** 20,
     ):
         """ Initialize an attack object. Attacks can be run multiple times. """
         self.goal_function = goal_function
@@ -150,7 +150,7 @@ class Attack:
     ):
         """ Filters a list of potential transformed texts based on ``self.constraints``\.
 
-        Checks cache first.
+        Utilizes an LRU cache to attempt to avoid recomputing common transformations.
 
         Args:
             transformed_texts: A list of candidate transformed ``AttackedText``\s to filter.

From 80b9a6e7e7602cc29b6efaa8ce2d8c5284968202 Mon Sep 17 00:00:00 2001
From: Jack Morris
Date: Fri, 10 Jul 2020 09:42:01 -0400
Subject: [PATCH 05/21] fix parallel worklist issue & increase model cache size

---
 textattack/commands/attack/run_attack_parallel.py | 15 ++++++++++++---
 textattack/goal_functions/goal_function.py        |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/textattack/commands/attack/run_attack_parallel.py b/textattack/commands/attack/run_attack_parallel.py
index e1fbc0fb..4cce3fdc 100644
--- a/textattack/commands/attack/run_attack_parallel.py
+++ b/textattack/commands/attack/run_attack_parallel.py
@@ -93,9 +93,18 @@ def run(args, checkpoint=None):
         in_queue = torch.multiprocessing.Queue()
         out_queue = torch.multiprocessing.Queue()
         # Add stuff to queue.
+        missing_datapoints = set()
         for i in worklist:
-            text, output = dataset[i]
-            in_queue.put((i, text, output))
+            try:
+                text, output = dataset[i]
+                in_queue.put((i, text, output))
+            except IndexError:
+                missing_datapoints.add(i)
+
+        # if our dataset is shorter than the number of samples chosen, remove the
+        # out-of-bounds indices from the dataset
+        for i in missing_datapoints:
+            worklist.remove(i)
 
         # Start workers.
         pool = torch.multiprocessing.Pool(
@@ -147,7 +156,7 @@ def run(args, checkpoint=None):
                     in_queue.put((worklist_tail, text, output))
                 except IndexError:
                     raise IndexError(
-                        "Out of bounds access of dataset. Size of data is {} but tried to access index {}".format(
+                        "Tried adding to worklist, but ran out of datapoints. 
Size of data is {} but tried to access index {}".format( len(dataset), worklist_tail ) ) diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index 20c2c334..72fda564 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -34,7 +34,7 @@ class GoalFunction(ABC): use_cache=True, query_budget=float("inf"), model_batch_size=32, - model_cache_size=2 ** 18, + model_cache_size=2 ** 20, ): validators.validate_model_goal_function_compatibility( self.__class__, model.__class__ From 1690c305e27d4bede3080a6f756df77ccfa2d3a9 Mon Sep 17 00:00:00 2001 From: jakegrigsby Date: Fri, 10 Jul 2020 13:31:19 -0400 Subject: [PATCH 06/21] --pct-dataset training option --- .../commands/train_model/run_training.py | 32 ++++++++++++++++++- .../train_model/train_model_command.py | 6 ++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index 95b0020f..48790848 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -2,6 +2,7 @@ import json import logging import os import time +import math import numpy as np import scipy @@ -29,6 +30,33 @@ def save_args(args, save_path): with open(save_path, "w", encoding="utf-8") as f: f.write(json.dumps(final_args_dict, indent=2) + "\n") +def get_sample_count(*lsts): + if all(len(lst) == len(lsts[0]) for lst in lsts): + sample_count = len(lsts[0]) + else: + sample_count = None + return sample_count + +def random_shuffle(*lsts): + permutation = np.random.permutation(len(lsts[0])) + shuffled = [] + for lst in lsts: + shuffled.append((np.array(lst)[permutation]).tolist()) + return tuple(shuffled) + +def train_val_split(*arrays, split_val=.2): + sample_count = get_sample_count(*arrays) + if not sample_count: + raise Exception("Batch Axis inconsistent. 
All input arrays must have first axis of equal length.") + arrays = random_shuffle(*arrays) + split_idx = math.floor(sample_count * split_val) + train_set = [array[split_idx:] for array in arrays] + val_set = [array[:split_idx] for array in arrays] + if len(train_set) == 1 and len(val_set) == 1: + train_set = train_set[0] + val_set = val_set[0] + return train_set, val_set + def filter_labels(text, labels, allowed_labels): final_text, final_labels = [], [] for text, label in zip(text, labels): @@ -165,6 +193,9 @@ def train_model(args): train_text, train_labels = filter_labels(train_text, train_labels, args.allowed_labels) eval_text, eval_labels = filter_labels(eval_text, eval_labels, args.allowed_labels) + if args.pct_dataset < 1.: + logger.info(f"Using {args.pct_dataset*100}% of the training set") + (train_text, train_labels), _ = train_val_split(train_text, train_labels, split_val=1.-args.pct_dataset) train_examples_len = len(train_text) # data augmentation @@ -172,7 +203,6 @@ def train_model(args): if augmenter: logger.info(f"Augmenting {len(train_text)} samples with {augmenter}") train_text, train_labels = data_augmentation(train_text, train_labels, augmenter) - logger.info(f"Using augmented training set of size {len(train_text)}") label_id_len = len(train_labels) label_set = set(train_labels) diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py index 5379489b..20859da3 100644 --- a/textattack/commands/train_model/train_model_command.py +++ b/textattack/commands/train_model/train_model_command.py @@ -48,6 +48,12 @@ class TrainModelCommand(TextAttackCommand): "`nlp` library. if dataset has a subset, separate with a colon. " " ex: `glue:sst2` or `rotten_tomatoes`", ) + parser.add_argument( + "--pct-dataset", + type=float, + default=1., + help="Fraction of dataset to use during training ([0., 1.])" + ) parser.add_argument( "--dataset-train-split", "--train-split", From 5541904af76c1324cb2c9f65cb63492c5cfd3321 Mon Sep 17 00:00:00 2001 From: jakegrigsby Date: Fri, 10 Jul 2020 16:57:19 -0400 Subject: [PATCH 07/21] fix augmenter num_words_to_swap rounding bug --- textattack/augmentation/augmenter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/textattack/augmentation/augmenter.py b/textattack/augmentation/augmenter.py index 3f561562..348dd867 100644 --- a/textattack/augmentation/augmenter.py +++ b/textattack/augmentation/augmenter.py @@ -75,7 +75,7 @@ class Augmenter: attacked_text = AttackedText(text) original_text = attacked_text all_transformed_texts = set() - num_words_to_swap = int(self.pct_words_to_swap * len(attacked_text.words)) + num_words_to_swap = max(int(self.pct_words_to_swap * len(attacked_text.words)), 1) for _ in range(self.transformations_per_example): index_order = list(range(len(attacked_text.words))) random.shuffle(index_order) From 34115e7c0347d283a0da5198f9d662dd2e088b3b Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Fri, 10 Jul 2020 17:03:31 -0400 Subject: [PATCH 08/21] catch LM err --- .../learning_to_write/language_model_helpers.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py index b25305d8..01661e6b 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py +++ 
b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py
@@ -27,7 +27,15 @@ class QueryHandler:
         except:
             probs = []
             for s, w in zip(sentences, swapped_words):
-                probs.append(self.try_query([s], [w], batch_size=1)[0])
+                try:
+                    probs.append(self.try_query([s], [w], batch_size=1)[0])
+                except RuntimeError:
+                    print(
+                        "WARNING: got runtime error trying language model on sentence / swapped word",
+                        s,
+                        w,
+                    )
+                    probs.append(float("-inf"))
             return probs
 
     def try_query(self, sentences, swapped_words, batch_size=32):
@@ -52,9 +60,6 @@ class QueryHandler:
             raw_idx_list[t].append(word_idxs[t])
         orig_num_idxs = len(raw_idx_list)
         raw_idx_list = [x for x in raw_idx_list if len(x)]
-        if not len(raw_idx_list):
-            # if no inputs are long enough to check, return inf for all
-            return [float("-inf")] * len(batch)
         num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
         all_raw_idxs = torch.tensor(
             raw_idx_list, device=self.device, dtype=torch.long
@@ -63,6 +68,8 @@ class QueryHandler:
         hidden = self.model.init_hidden(len(batch))
         source = word_idxs[:-1, :]
         target = word_idxs[1:, :]
+        if (not len(source)) or not len(hidden):
+            return [float("-inf")] * len(batch)
         decode, hidden = self.model(source, hidden)
         decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
         for i in range(len(batch)):

From 2498149d22ceae1724e82784be8a609d7ed7e007 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Fri, 10 Jul 2020 17:31:19 -0400
Subject: [PATCH 09/21] add second and microsecond to train output dir to
 prevent name conflicts

---
 textattack/commands/train_model/train_model_command.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py
index 20859da3..79da07df 100644
--- a/textattack/commands/train_model/train_model_command.py
+++ b/textattack/commands/train_model/train_model_command.py
@@ -14,7 +14,7 @@ class TrainModelCommand(TextAttackCommand):
 
     def run(self, args):
 
-        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
+        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
         current_dir = os.path.dirname(os.path.realpath(__file__))
         outputs_dir = os.path.join(
             current_dir, os.pardir, os.pardir, os.pardir, "outputs", "training"
         )

From aaf36912f07e281ff250f313b62d9446d8273499 Mon Sep 17 00:00:00 2001
From: Jin Yong Yoo
Date: Sat, 11 Jul 2020 04:06:15 -0400
Subject: [PATCH 10/21] various fixes

---
 textattack/commands/attack/attack_args.py         | 1 +
 textattack/loggers/csv_logger.py                  | 4 +++-
 textattack/search_methods/greedy_word_swap_wir.py | 4 +++-
 textattack/transformations/word_swap_masked_lm.py | 8 ++++++--
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py
index 9a20598d..cc38e99e 100644
--- a/textattack/commands/attack/attack_args.py
+++ b/textattack/commands/attack/attack_args.py
@@ -312,6 +312,7 @@ BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
     "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
     "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
     "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
+    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
 }
 
 WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py
index 1dc52097..78b1bc7f 100644
--- a/textattack/loggers/csv_logger.py
+++ 
b/textattack/loggers/csv_logger.py @@ -5,7 +5,7 @@ import sys import pandas as pd from textattack.attack_results import FailedAttackResult -from textattack.shared import logger +from textattack.shared import logger, AttackedText from .logger import Logger @@ -21,6 +21,8 @@ class CSVLogger(Logger): def log_attack_result(self, result): original_text, perturbed_text = result.diff_color(self.color_method) + original_text = original_text.replace("\n", AttackedText.SPLIT_TOKEN) + perturbed_text = perturbed_text.replace("\n", AttackedText.SPLIT_TOKEN) result_type = result.__class__.__name__.replace("AttackResult", "") row = { "original_text": original_text, diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py index ed46da7e..478c8143 100644 --- a/textattack/search_methods/greedy_word_swap_wir.py +++ b/textattack/search_methods/greedy_word_swap_wir.py @@ -59,7 +59,9 @@ class GreedyWordSwapWIR(SearchMethod): initial_result, leave_one_texts ) - softmax_saliency_scores = softmax(torch.Tensor(saliency_scores), dim=0).numpy() + softmax_saliency_scores = softmax( + torch.Tensor(saliency_scores), dim=0 + ).numpy() # compute the largest change in score we can find by swapping each word delta_ps = [] diff --git a/textattack/transformations/word_swap_masked_lm.py b/textattack/transformations/word_swap_masked_lm.py index d2c91047..533f7e6a 100644 --- a/textattack/transformations/word_swap_masked_lm.py +++ b/textattack/transformations/word_swap_masked_lm.py @@ -94,7 +94,7 @@ class WordSwapMaskedLM(WordSwap): replacement_words = [] for id in top_ids: token = self._lm_tokenizer.convert_ids_to_tokens(id) - if utils.is_one_word(token): + if utils.is_one_word(token) and not check_if_subword(token): replacement_words.append(token) return replacement_words @@ -141,7 +141,7 @@ class WordSwapMaskedLM(WordSwap): replacement_words = [] for id in top_preds: token = self._lm_tokenizer.convert_ids_to_tokens(id) - if utils.is_one_word(token): + if utils.is_one_word(token) and not check_if_subword(token): replacement_words.append(token) return replacement_words else: @@ -229,3 +229,7 @@ def recover_word_case(word, reference_word): else: # if other, just do not alter the word's case return word + + +def check_if_subword(text): + return True if "##" in text else False From cb3b13cd31a9a26bedfccae04185b6bab954ebe6 Mon Sep 17 00:00:00 2001 From: Jin Yong Yoo Date: Sat, 11 Jul 2020 04:19:57 -0400 Subject: [PATCH 11/21] fix lint --- Makefile | 4 ++-- tests/test_command_line/test_attack.py | 3 ++- tests/test_command_line/test_augment.py | 3 ++- tests/test_command_line/test_list.py | 3 ++- textattack/loggers/csv_logger.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 2a7e4ae9..c0fc3512 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,11 @@ format: FORCE ## Run black and isort (rewriting files) black . - isort --atomic tests textattack + isort --atomic --recursive tests textattack lint: FORCE ## Run black, isort, flake8 (in check mode) black . --check - isort --check-only tests textattack + isort --check-only --recursive tests textattack flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8 test: FORCE ## Run tests using pytest diff --git a/tests/test_command_line/test_attack.py b/tests/test_command_line/test_attack.py index 3199de8d..20e7d30e 100644 --- a/tests/test_command_line/test_attack.py +++ b/tests/test_command_line/test_attack.py @@ -1,9 +1,10 @@ import pdb import re -from helpers import run_command_and_get_result import pytest +from helpers import run_command_and_get_result + DEBUG = False """ diff --git a/tests/test_command_line/test_augment.py b/tests/test_command_line/test_augment.py index 00cf6176..0648e9f1 100644 --- a/tests/test_command_line/test_augment.py +++ b/tests/test_command_line/test_augment.py @@ -1,6 +1,7 @@ -from helpers import run_command_and_get_result import pytest +from helpers import run_command_and_get_result + augment_test_params = [ ( "simple_augment_test", diff --git a/tests/test_command_line/test_list.py b/tests/test_command_line/test_list.py index 488824a5..96e4dcd4 100644 --- a/tests/test_command_line/test_list.py +++ b/tests/test_command_line/test_list.py @@ -1,6 +1,7 @@ -from helpers import run_command_and_get_result import pytest +from helpers import run_command_and_get_result + list_test_params = [ ( "list_augmentation_recipes", diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py index 78b1bc7f..34e8e689 100644 --- a/textattack/loggers/csv_logger.py +++ b/textattack/loggers/csv_logger.py @@ -5,7 +5,7 @@ import sys import pandas as pd from textattack.attack_results import FailedAttackResult -from textattack.shared import logger, AttackedText +from textattack.shared import AttackedText, logger from .logger import Logger From 31b3b2fa4e69913279d20e4f58ca41c1bb026265 Mon Sep 17 00:00:00 2001 From: Jin Yong Yoo Date: Sat, 11 Jul 2020 04:27:56 -0400 Subject: [PATCH 12/21] fix linting again --- Makefile | 4 ++-- tests/test_command_line/test_attack.py | 3 +-- tests/test_command_line/test_augment.py | 3 +-- tests/test_command_line/test_list.py | 3 +-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index c0fc3512..2a7e4ae9 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,11 @@ format: FORCE ## Run black and isort (rewriting files) black . - isort --atomic --recursive tests textattack + isort --atomic tests textattack lint: FORCE ## Run black, isort, flake8 (in check mode) black . --check - isort --check-only --recursive tests textattack + isort --check-only tests textattack flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8 test: FORCE ## Run tests using pytest diff --git a/tests/test_command_line/test_attack.py b/tests/test_command_line/test_attack.py index 20e7d30e..3199de8d 100644 --- a/tests/test_command_line/test_attack.py +++ b/tests/test_command_line/test_attack.py @@ -1,9 +1,8 @@ import pdb import re -import pytest - from helpers import run_command_and_get_result +import pytest DEBUG = False diff --git a/tests/test_command_line/test_augment.py b/tests/test_command_line/test_augment.py index 0648e9f1..00cf6176 100644 --- a/tests/test_command_line/test_augment.py +++ b/tests/test_command_line/test_augment.py @@ -1,6 +1,5 @@ -import pytest - from helpers import run_command_and_get_result +import pytest augment_test_params = [ ( diff --git a/tests/test_command_line/test_list.py b/tests/test_command_line/test_list.py index 96e4dcd4..488824a5 100644 --- a/tests/test_command_line/test_list.py +++ b/tests/test_command_line/test_list.py @@ -1,6 +1,5 @@ -import pytest - from helpers import run_command_and_get_result +import pytest list_test_params = [ ( From cff37df82948adcb1159e5e8bc58bfd94612a548 Mon Sep 17 00:00:00 2001 From: jakegrigsby Date: Sat, 11 Jul 2020 19:13:01 -0400 Subject: [PATCH 13/21] train formatting, docstrings --- textattack/augmentation/augmenter.py | 24 +- textattack/augmentation/recipes.py | 3 + .../commands/train_model/run_training.py | 259 ++++++++++++++---- .../train_model/train_args_helpers.py | 3 +- .../train_model/train_model_command.py | 11 +- 5 files changed, 234 insertions(+), 66 deletions(-) diff --git a/textattack/augmentation/augmenter.py b/textattack/augmentation/augmenter.py index 348dd867..b2b64e42 100644 --- a/textattack/augmentation/augmenter.py +++ b/textattack/augmentation/augmenter.py @@ -3,7 +3,7 @@ import random import tqdm from textattack.constraints import PreTransformationConstraint -from textattack.shared import AttackedText +from textattack.shared import AttackedText, utils class Augmenter: @@ -75,7 +75,9 @@ class Augmenter: attacked_text = AttackedText(text) original_text = attacked_text all_transformed_texts = set() - num_words_to_swap = max(int(self.pct_words_to_swap * len(attacked_text.words)), 1) + num_words_to_swap = max( + int(self.pct_words_to_swap * len(attacked_text.words)), 1 + ) for _ in range(self.transformations_per_example): index_order = list(range(len(attacked_text.words))) random.shuffle(index_order) @@ -139,4 +141,20 @@ class Augmenter: return all_text_list, all_id_list def __repr__(self): - return type(self).__name__ + main_str = "Augmenter" + "(" + lines = [] + # self.transformation + lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2)) + # self.constraints + constraints_lines = [] + constraints = self.constraints + self.pre_transformation_constraints + if len(constraints): + for i, constraint in enumerate(constraints): + constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2)) + constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2) + else: + constraints_str = "None" + lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2)) + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 386bd778..47f0fe48 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -64,6 +64,9 @@ class 
EasyDataAugmenter(Augmenter): random.shuffle(augmented_text) return augmented_text[: self.transformations_per_example] + def __repr__(self): + return "EasyDataAugmenter" + class SwapAugmenter(Augmenter): def __init__(self, **kwargs): diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py index 48790848..5b3eb1f2 100644 --- a/textattack/commands/train_model/run_training.py +++ b/textattack/commands/train_model/run_training.py @@ -1,8 +1,8 @@ import json import logging +import math import os import time -import math import numpy as np import scipy @@ -14,50 +14,91 @@ import transformers import textattack from .train_args_helpers import ( + attack_from_args, augmenter_from_args, dataset_from_args, model_from_args, - attack_from_args, write_readme, ) device = textattack.shared.utils.device logger = textattack.shared.logger -def save_args(args, save_path): - # Save args to file - final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)} + +def _save_args(args, save_path): + """ + Dump args dictionary to a json + + :param: args. Dictionary of arguments to save. + :save_path: Path to json file to write args to. + """ + final_args_dict = {k: v for k, v in vars(args).items() if _is_writable_type(v)} with open(save_path, "w", encoding="utf-8") as f: f.write(json.dumps(final_args_dict, indent=2) + "\n") -def get_sample_count(*lsts): + +def _get_sample_count(*lsts): + """ + Get sample count of a dataset. + + :param *lsts: variable number of lists. + :return: sample count of this dataset, if all lists match, else None. + """ if all(len(lst) == len(lsts[0]) for lst in lsts): sample_count = len(lsts[0]) else: sample_count = None return sample_count -def random_shuffle(*lsts): + +def _random_shuffle(*lsts): + """ + Randomly shuffle a dataset. Applies the same permutation + to each list (to preserve mapping between inputs and targets). + + :param *lsts: variable number of lists to shuffle. + :return: shuffled lsts. + """ permutation = np.random.permutation(len(lsts[0])) shuffled = [] for lst in lsts: shuffled.append((np.array(lst)[permutation]).tolist()) return tuple(shuffled) -def train_val_split(*arrays, split_val=.2): - sample_count = get_sample_count(*arrays) + +def _train_val_split(*lsts, split_val=0.2): + """ + Split dataset into training and validation sets. + + :param *lsts: variable number of lists that make up a dataset (e.g. text, labels) + :param split_val: float [0., 1.). Fraction of the dataset to reserve for evaluation. + :return: (train split of list for list in lsts), (val split of list for list in lsts) + """ + sample_count = _get_sample_count(*lsts) if not sample_count: - raise Exception("Batch Axis inconsistent. All input arrays must have first axis of equal length.") - arrays = random_shuffle(*arrays) + raise Exception( + "Batch Axis inconsistent. All input arrays must have first axis of equal length." + ) + lsts = _random_shuffle(*lsts) split_idx = math.floor(sample_count * split_val) - train_set = [array[split_idx:] for array in arrays] - val_set = [array[:split_idx] for array in arrays] + train_set = [lst[split_idx:] for lst in lsts] + val_set = [lst[:split_idx] for lst in lsts] if len(train_set) == 1 and len(val_set) == 1: train_set = train_set[0] val_set = val_set[0] return train_set, val_set -def filter_labels(text, labels, allowed_labels): + +def _filter_labels(text, labels, allowed_labels): + """ + Keep examples with approved labels + + :param text: list of text inputs. 
+ :param labels: list of corresponding labels. + :param allowed_labels: list of approved label values. + + :return: (final_text, final_labels). Filtered version of text and labels + """ final_text, final_labels = [], [] for text, label in zip(text, labels): if label in allowed_labels: @@ -65,7 +106,15 @@ def filter_labels(text, labels, allowed_labels): final_labels.append(label) return final_text, final_labels -def save_model_checkpoint(model, output_dir, global_step): + +def _save_model_checkpoint(model, output_dir, global_step): + """ + Save model checkpoint to disk. + + :param model: Model to save (pytorch) + :param output_dir: Path to model save dir. + :param global_step: Current global training step #. Used in ckpt filename. + """ # Save model checkpoint output_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): @@ -73,10 +122,18 @@ def save_model_checkpoint(model, output_dir, global_step): # Take care of distributed/parallel training model_to_save = model.module if hasattr(model, "module") else model model_to_save.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) -def save_model(model, output_dir, weights_name, config_name): - model_to_save = (model.module if hasattr(model, "module") else model) + +def _save_model(model, output_dir, weights_name, config_name): + """ + Save model to disk. + + :param model: Model to save (pytorch) + :param output_dir: Path to model save dir. + :param weights_name: filename for model parameters. + :param config_name: filename for config. + """ + model_to_save = model.module if hasattr(model, "module") else model # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(output_dir, weights_name) @@ -89,7 +146,17 @@ def save_model(model, output_dir, weights_name, config_name): # no config pass -def get_eval_score(model, eval_dataloader, do_regression): + +def _get_eval_score(model, eval_dataloader, do_regression): + """ + Measure performance of a model on the evaluation set. + + :param model: Model to test. + :param eval_dataloader: a torch DataLoader that iterates through the eval set. + :param do_regression: bool. Whether we are doing regression (True) or classification (False) + + :return: pearson correlation, if do_regression==True, else classification accuracy [0., 1.] + """ model.eval() correct = 0 total = 0 @@ -120,23 +187,36 @@ def get_eval_score(model, eval_dataloader, do_regression): correct = (preds == labels).sum() return float(correct) / len(labels) -def make_directories(output_dir): + +def _make_directories(output_dir): if not os.path.exists(output_dir): os.makedirs(output_dir) -def is_writable_type(obj): + +def _is_writable_type(obj): for ok_type in [bool, int, str, float]: if isinstance(obj, ok_type): return True return False + def batch_encode(tokenizer, text_list): if hasattr(tokenizer, "batch_encode"): return tokenizer.batch_encode(text_list) else: return [tokenizer.encode(text_input) for text_input in text_list] -def make_dataloader(tokenizer, text, labels, batch_size): + +def _make_dataloader(tokenizer, text, labels, batch_size): + """ + Create torch DataLoader from list of input text and labels. + + :param tokenizer: Tokenizer to use for this text. + :param text: list of input text. + :param labels: list of corresponding labels. + :param batch_size: batch size (int). + :return: torch DataLoader for this training set. 
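+
+    A hypothetical call (illustrative only; any tokenizer exposing an ``encode``
+    or ``batch_encode`` method works here):
+        >>> loader = _make_dataloader(tokenizer, ["good movie", "bad movie"], [1, 0], 2)
+        >>> input_ids, labels = next(iter(loader))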
+ """ text_ids = batch_encode(tokenizer, text) input_ids = np.array(text_ids) labels = np.array(labels) @@ -145,7 +225,17 @@ def make_dataloader(tokenizer, text, labels, batch_size): dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size) return dataloader -def data_augmentation(text, labels, augmenter): + +def _data_augmentation(text, labels, augmenter): + """ + Use an augmentation method to expand a training set. + + :param text: list of input text. + :param labels: list of corresponding labels. + :param augmenter: textattack.augmentation.Augmenter, augmentation scheme. + + :return: augmented_text, augmented_labels. list of (augmented) input text and labels. + """ aug_text = augmenter.augment_many(text) # flatten augmented examples and duplicate labels flat_aug_text = [] @@ -156,19 +246,32 @@ def data_augmentation(text, labels, augmenter): flat_aug_labels.append(labels[i]) return flat_aug_text, flat_aug_labels -def attack_model(model, attack_t, dataset): - attack = attack_t(model) + +def _generate_adversarial_examples(model, attackCls, dataset): + """ + Create a dataset of adversarial examples based on perturbations of the existing dataset. + + :param model: Model to attack. + :param attackCls: class name of attack recipe to run. + :param dataset: iterable of (text, label) pairs. + + :return: list of adversarial examples. + """ + attack = attackCls(model) adv_train_text = [] - for adv_ex in tqdm.tqdm(attack.attack_dataset(dataset), desc="Attack", total=len(dataset)): + for adv_ex in tqdm.tqdm( + attack.attack_dataset(dataset), desc="Attack", total=len(dataset) + ): adv_train_text.append(adv_ex.perturbed_text()) return adv_train_text + def train_model(args): logger.warn( "WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues." ) start_time = time.time() - make_directories(args.output_dir) + _make_directories(args.output_dir) num_gpus = torch.cuda.device_count() @@ -183,6 +286,7 @@ def train_model(args): if args.enable_wandb: global wandb import wandb + wandb.init(sync_tensorboard=True) # Get list of text and list of label (integers) from disk. @@ -190,24 +294,34 @@ def train_model(args): # Filter labels if args.allowed_labels: - train_text, train_labels = filter_labels(train_text, train_labels, args.allowed_labels) - eval_text, eval_labels = filter_labels(eval_text, eval_labels, args.allowed_labels) - - if args.pct_dataset < 1.: + train_text, train_labels = _filter_labels( + train_text, train_labels, args.allowed_labels + ) + eval_text, eval_labels = _filter_labels( + eval_text, eval_labels, args.allowed_labels + ) + + if args.pct_dataset < 1.0: logger.info(f"Using {args.pct_dataset*100}% of the training set") - (train_text, train_labels), _ = train_val_split(train_text, train_labels, split_val=1.-args.pct_dataset) + (train_text, train_labels), _ = _train_val_split( + train_text, train_labels, split_val=1.0 - args.pct_dataset + ) train_examples_len = len(train_text) # data augmentation augmenter = augmenter_from_args(args) if augmenter: logger.info(f"Augmenting {len(train_text)} samples with {augmenter}") - train_text, train_labels = data_augmentation(train_text, train_labels, augmenter) + train_text, train_labels = _data_augmentation( + train_text, train_labels, augmenter + ) label_id_len = len(train_labels) label_set = set(train_labels) args.num_labels = len(label_set) - logger.info(f"Loaded dataset. 
Found: {args.num_labels} labels: ({sorted(label_set)})")
+ logger.info(
+ f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})"
+ )
 
 if isinstance(train_labels[0], float):
 # TODO come up with a more sophisticated scheme for knowing when to do regression
@@ -217,11 +331,15 @@ def train_model(args):
 else:
 args.do_regression = False
 
 if len(train_labels) != len(train_text):
- raise ValueError(f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})")
+ raise ValueError(
+ f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})"
+ )
 if len(eval_labels) != len(eval_text):
- raise ValueError(f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})")
+ raise ValueError(
+ f"Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
+ )
 
 model = model_from_args(args, args.num_labels)
 tokenizer = model.tokenizer
@@ -250,11 +367,23 @@ def train_model(args):
 param_optimizer = list(model.named_parameters())
 no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
 optimizer_grouped_parameters = [
- {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01,},
- {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0,},
- ]
+ {
+ "params": [
+ p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
+ ],
+ "weight_decay": 0.01,
+ },
+ {
+ "params": [
+ p for n, p in param_optimizer if any(nd in n for nd in no_decay)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
 
- optimizer = transformers.optimization.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = transformers.optimization.AdamW(
+ optimizer_grouped_parameters, lr=args.learning_rate
+ )
 
 scheduler = transformers.optimization.get_linear_schedule_with_warmup(
 optimizer,
@@ -269,10 +398,12 @@ def train_model(args):
 
 # Save original args to file
 args_save_path = os.path.join(args.output_dir, "train_args.json")
- save_args(args, args_save_path)
+ _save_args(args, args_save_path)
 logger.info(f"Wrote original training args to {args_save_path}.")
 
- tb_writer.add_hparams({k: v for k, v in vars(args).items() if is_writable_type(v)}, {})
+ tb_writer.add_hparams(
+ {k: v for k, v in vars(args).items() if _is_writable_type(v)}, {}
+ )
 
 # Start training
 logger.info("***** Running training *****")
@@ -287,12 +418,16 @@ def train_model(args):
 logger.info(f"\tNum epochs = {args.num_train_epochs}")
 logger.info(f"\tLearning rate = {args.learning_rate}")
 
- eval_dataloader = make_dataloader(tokenizer, eval_text, eval_labels, args.batch_size)
- train_dataloader = make_dataloader(tokenizer, train_text, train_labels, args.batch_size)
-
+ eval_dataloader = _make_dataloader(
+ tokenizer, eval_text, eval_labels, args.batch_size
+ )
+ train_dataloader = _make_dataloader(
+ tokenizer, train_text, train_labels, args.batch_size
+ )
+
 global_step = 0
 tr_loss = 0
-
+
 model.train()
 args.best_eval_score = 0
 args.best_eval_score_epoch = 0
@@ -312,11 +447,17 @@ def train_model(args):
 else:
 loss_fct = torch.nn.CrossEntropyLoss()
 
- for epoch in tqdm.trange(int(args.num_train_epochs), desc="Epoch", position=0, leave=False):
+ for epoch in tqdm.trange(
+ int(args.num_train_epochs), desc="Epoch", position=0, leave=False
+ ):
 if attack_t and epoch > 0:
 logger.info("Attacking model to generate new training set...")
- adv_train_text = attack_model(model, attack_t, list(zip(train_text, train_labels)))
- train_dataloader = make_dataloader(tokenizer, adv_train_text, train_labels, args.batch_size)
+ adv_train_text = _generate_adversarial_examples(
+ model, attack_t, list(zip(train_text, train_labels))
+ )
+ train_dataloader = _make_dataloader(
+ tokenizer, adv_train_text, train_labels, args.batch_size
+ )
 
 prog_bar = tqdm.tqdm(
 train_dataloader, desc="Iteration", position=1, leave=False
@@ -327,7 +468,9 @@ def train_model(args):
 if isinstance(input_ids, dict):
 ## HACK: dataloader collates dict backwards. This is a temporary
 # workaround to get ids in the right shape
- input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
+ input_ids = {
+ k: torch.stack(v).T.to(device) for k, v in input_ids.items()
+ }
 
 logits = textattack.shared.utils.model_predict(model, input_ids)
 
@@ -358,17 +501,17 @@ def train_model(args):
 and (args.checkpoint_steps > 0)
 and (global_step % args.checkpoint_steps) == 0
 ):
- save_model_checkpoint(model, args.output_dir, global_step)
+ _save_model_checkpoint(model, args.output_dir, global_step)
 
 # Inc step counter.
 global_step += 1
 
 # Check accuracy after each epoch.
- eval_score = get_eval_score(model, eval_dataloader, args.do_regression)
+ eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
 tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
 
 if args.checkpoint_every_epoch:
- save_model_checkpoint(model, args.output_dir, args.global_step)
+ _save_model_checkpoint(model, args.output_dir, global_step)
 
 logger.info(
 f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
@@ -377,9 +520,9 @@ def train_model(args):
 if eval_score > args.best_eval_score:
 args.best_eval_score = eval_score
 args.best_eval_score_epoch = epoch
 args.epochs_since_best_eval_score = 0
- save_model(model, args.output_dir, args.weights_name, args.config_name)
+ _save_model(model, args.output_dir, args.weights_name, args.config_name)
 logger.info(f"Best acc found. Saved model to {args.output_dir}.")
- save_args(args, args_save_path)
+ _save_args(args, args_save_path)
 logger.info(f"Saved updated args to {args_save_path}")
 else:
 args.epochs_since_best_eval_score += 1
@@ -395,7 +538,7 @@ def train_model(args):
 logger.info("Finished training. Re-loading and evaluating model from disk.")
 model = model_from_args(args, args.num_labels)
 model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
- eval_score = get_eval_score(model, eval_dataloader, args.do_regression)
+ eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
 logger.info(
 f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
 )
@@ -412,5 +555,5 @@ def train_model(args):
 
 # Save a little readme with model info
 write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
- save_args(args, args_save_path)
+ _save_args(args, args_save_path)
 logger.info(f"Wrote final training args to {args_save_path}.")
diff --git a/textattack/commands/train_model/train_args_helpers.py b/textattack/commands/train_model/train_args_helpers.py
index 6ffdb6f1..161980ff 100644
--- a/textattack/commands/train_model/train_args_helpers.py
+++ b/textattack/commands/train_model/train_args_helpers.py
@@ -1,8 +1,8 @@
 import os
 
 import textattack
-from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
 from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES
+from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
 
 logger = textattack.shared.logger
 
@@ -132,6 +132,7 @@ def model_from_args(train_args, num_labels, model_path=None):
 
 return model
 
+
 def attack_from_args(args):
 # note that this returns a recipe type, not an object
 # (we need to wait to have access to the model to initialize)
diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py
index 79da07df..19a6c283 100644
--- a/textattack/commands/train_model/train_model_command.py
+++ b/textattack/commands/train_model/train_model_command.py
@@ -51,8 +51,8 @@ class TrainModelCommand(TextAttackCommand):
 parser.add_argument(
 "--pct-dataset",
 type=float,
- default=1.,
- help="Fraction of dataset to use during training ([0., 1.])"
+ default=1.0,
+ help="Fraction of dataset to use during training ([0., 1.])",
 )
 parser.add_argument(
 "--dataset-train-split",
@@ -84,7 +84,7 @@ class TrainModelCommand(TextAttackCommand):
 help="save model after this many steps (-1 for no checkpointing)",
 )
 parser.add_argument(
- "--checkpoint-every_epoch",
+ "--checkpoint-every-epoch",
 action="store_true",
 default=False,
 help="save model checkpoint after each epoch",
@@ -97,7 +97,10 @@ class TrainModelCommand(TextAttackCommand):
 help="Total number of epochs to train for",
 )
 parser.add_argument(
- "--attack", type=str, default=None, help="Attack recipe to use (enables adversarial training)"
+ "--attack",
+ type=str,
+ default=None,
+ help="Attack recipe to use (enables adversarial training)",
 )
 parser.add_argument(
 "--augment", type=str, default=None, help="Augmentation recipe to use",
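With the pieces above in place, `--attack <recipe>` turns on adversarial training: `attack_from_args` resolves the recipe name to a class, and each epoch after the first re-attacks the current model and retrains on the perturbed text. A minimal sketch of that regeneration step, using only hooks that appear in this patch (the standalone helper and the "deepwordbug" recipe key are illustrative, not part of the diff):

    import textattack
    from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES

    def regenerate_adversarial_examples(model, train_text, train_labels, recipe="deepwordbug"):
        """Re-attack the current model, as the training loop above does each epoch."""
        attack_cls = eval(ATTACK_RECIPE_NAMES[recipe])  # same lookup attack_from_args() performs
        attack = attack_cls(model)
        results = attack.attack_dataset(list(zip(train_text, train_labels)))
        # one result per (text, label) pair; the original labels are reused unchanged
        return [result.perturbed_text() for result in results]

The returned list can be handed straight to `_make_dataloader` along with the original labels, which is exactly what the loop does before each adversarial epoch.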
 tokenizer = model.tokenizer
 
- attack_t = attack_from_args(args)
+ attackCls = attack_from_args(args)
+ adversarial_training = attackCls is not None
 
 # multi-gpu training
 if num_gpus > 1:
@@ -450,14 +451,17 @@ def train_model(args):
 for epoch in tqdm.trange(
 int(args.num_train_epochs), desc="Epoch", position=0, leave=False
 ):
- if attack_t and epoch > 0:
- logger.info("Attacking model to generate new training set...")
- adv_train_text = _generate_adversarial_examples(
- model, attack_t, list(zip(train_text, train_labels))
- )
- train_dataloader = _make_dataloader(
- tokenizer, adv_train_text, train_labels, args.batch_size
- )
+ if adversarial_training:
+ if epoch >= args.num_clean_epochs:
+ logger.info("Attacking model to generate new training set...")
+ adv_train_text = _generate_adversarial_examples(
+ model, attackCls, list(zip(train_text, train_labels))
+ )
+ train_dataloader = _make_dataloader(
+ tokenizer, adv_train_text, train_labels, args.batch_size
+ )
+ else:
+ logger.info(f"Running clean epoch {epoch+1}/{args.num_clean_epochs}")
 
 prog_bar = tqdm.tqdm(
 train_dataloader, desc="Iteration", position=1, leave=False
@@ -507,32 +511,34 @@ def train_model(args):
 global_step += 1
 
 # Check accuracy after each epoch.
- eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
- tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
+ # skip args.num_clean_epochs during adversarial training
+ if not adversarial_training or epoch >= args.num_clean_epochs:
+ eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
+ tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
 
- if args.checkpoint_every_epoch:
- _save_model_checkpoint(model, args.output_dir, args.global_step)
+ if args.checkpoint_every_epoch:
+ _save_model_checkpoint(model, args.output_dir, global_step)
 
- logger.info(
- f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
- )
- if eval_score > args.best_eval_score:
- args.best_eval_score = eval_score
- args.best_eval_score_epoch = epoch
- args.epochs_since_best_eval_score = 0
- _save_model(model, args.output_dir, args.weights_name, args.config_name)
- logger.info(f"Best acc found. Saved model to {args.output_dir}.")
- _save_args(args, args_save_path)
- logger.info(f"Saved updated args to {args_save_path}")
- else:
- args.epochs_since_best_eval_score += 1
- if (args.early_stopping_epochs > 0) and (
- args.epochs_since_best_eval_score > args.early_stopping_epochs
- ):
- logger.info(
- f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
- )
- break
+ logger.info(
+ f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
+ )
+ if eval_score > args.best_eval_score:
+ args.best_eval_score = eval_score
+ args.best_eval_score_epoch = epoch
+ args.epochs_since_best_eval_score = 0
+ _save_model(model, args.output_dir, args.weights_name, args.config_name)
+ logger.info(f"Best acc found. Saved model to {args.output_dir}.")
+ _save_args(args, args_save_path)
+ logger.info(f"Saved updated args to {args_save_path}")
+ else:
+ args.epochs_since_best_eval_score += 1
+ if (args.early_stopping_epochs > 0) and (
+ args.epochs_since_best_eval_score > args.early_stopping_epochs
+ ):
+ logger.info(
+ f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
+ )
+ break
 
 # read the saved model and report its eval performance
 logger.info("Finished training. Re-loading and evaluating model from disk.")
diff --git a/textattack/commands/train_model/train_args_helpers.py b/textattack/commands/train_model/train_args_helpers.py
index 161980ff..742e288a 100644
--- a/textattack/commands/train_model/train_args_helpers.py
+++ b/textattack/commands/train_model/train_args_helpers.py
@@ -136,13 +136,16 @@ def model_from_args(train_args, num_labels, model_path=None):
 def attack_from_args(args):
 # note that this returns a recipe type, not an object
 # (we need to wait to have access to the model to initialize)
- attack_t = None
+ attackCls = None
 if args.attack:
 if args.attack in ATTACK_RECIPE_NAMES:
- attack_t = eval(ATTACK_RECIPE_NAMES[args.attack])
+ attackCls = eval(ATTACK_RECIPE_NAMES[args.attack])
 else:
 raise ValueError(f"Unrecognized attack recipe: {args.attack}")
- return attack_t
+
+ # check attack-related args
+ assert args.num_clean_epochs > 0, "--num-clean-epochs must be > 0"
+ return attackCls
 
 
 def augmenter_from_args(args):
diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py
index 19a6c283..25db03fc 100644
--- a/textattack/commands/train_model/train_model_command.py
+++ b/textattack/commands/train_model/train_model_command.py
@@ -102,6 +102,12 @@ class TrainModelCommand(TextAttackCommand):
 default=None,
 help="Attack recipe to use (enables adversarial training)",
 )
+ parser.add_argument(
+ "--num-clean-epochs",
+ type=int,
+ default=1,
+ help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
+ )
 parser.add_argument(
 "--augment", type=str, default=None, help="Augmentation recipe to use",
 )
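Before this change, the attack ran from epoch 1 onward; `--num-clean-epochs` generalizes that so the model can first converge on clean data, and evaluation and checkpoint selection are skipped during the clean epochs of an adversarial run. A toy illustration of the resulting schedule (hypothetical values, mirroring the `epoch >= args.num_clean_epochs` check above):

    num_train_epochs, num_clean_epochs = 5, 2
    for epoch in range(num_train_epochs):
        # epochs 0-1 train on the original dataset; epochs 2-4 re-attack the
        # current model and train on the perturbed examples
        phase = "clean" if epoch < num_clean_epochs else "adversarial"
        print(f"epoch {epoch}: {phase}")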
From eb742affad0beac46d873a9dfd59a1da685c1e78 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Sat, 11 Jul 2020 22:10:45 -0400
Subject: [PATCH 15/21] --attack-period for adversarial training

---
 textattack/commands/train_model/run_training.py | 17 ++++++++++-------
 .../commands/train_model/train_model_command.py | 6 ++++++
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/textattack/commands/train_model/run_training.py b/textattack/commands/train_model/run_training.py
index 0df33fab..1496ad51 100644
--- a/textattack/commands/train_model/run_training.py
+++ b/textattack/commands/train_model/run_training.py
@@ -453,13 +453,16 @@ def train_model(args):
 ):
 if adversarial_training:
 if epoch >= args.num_clean_epochs:
- logger.info("Attacking model to generate new training set...")
- adv_train_text = _generate_adversarial_examples(
- model, attackCls, list(zip(train_text, train_labels))
- )
- train_dataloader = _make_dataloader(
- tokenizer, adv_train_text, train_labels, args.batch_size
- )
+ if (epoch - args.num_clean_epochs) % args.attack_period == 0:
+ # only generate a new adversarial training set every args.attack_period epochs
+ # after the clean epochs
+ logger.info("Attacking model to generate new training set...")
+ adv_train_text = _generate_adversarial_examples(
+ model, attackCls, list(zip(train_text, train_labels))
+ )
+ train_dataloader = _make_dataloader(
+ tokenizer, adv_train_text, train_labels, args.batch_size
+ )
 else:
 logger.info(f"Running clean epoch {epoch+1}/{args.num_clean_epochs}")
 
diff --git a/textattack/commands/train_model/train_model_command.py b/textattack/commands/train_model/train_model_command.py
index 25db03fc..6537ec38 100644
--- a/textattack/commands/train_model/train_model_command.py
+++ b/textattack/commands/train_model/train_model_command.py
@@ -108,6 +108,12 @@ class TrainModelCommand(TextAttackCommand):
 default=1,
 help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
 )
+ parser.add_argument(
+ "--attack-period",
+ type=int,
+ default=1,
+ help="How often (in epochs) to generate a new adversarial training set (N/A if --attack unspecified)",
+ )
 parser.add_argument(
 "--augment", type=str, default=None, help="Augmentation recipe to use",
 )
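Because generating adversarial examples requires a full attack pass over the training set, regenerating every epoch can dominate training time. `--attack-period` amortizes that cost: a fresh adversarial set is built only when `(epoch - num_clean_epochs) % attack_period == 0`, and the epochs in between reuse the previous one. For example (hypothetical values):

    num_clean_epochs, attack_period = 1, 2
    for epoch in range(6):
        if epoch < num_clean_epochs:
            print(f"epoch {epoch}: clean data")
        elif (epoch - num_clean_epochs) % attack_period == 0:
            print(f"epoch {epoch}: generate a fresh adversarial training set")
        else:
            print(f"epoch {epoch}: reuse the current adversarial training set")
    # -> clean, generate, reuse, generate, reuse, generate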
From da7273ac6e65770cad28a6ee2188c108125296f8 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Sat, 11 Jul 2020 23:51:27 -0400
Subject: [PATCH 16/21] fix eval bug with multiarg nlp datasets

---
 textattack/commands/attack/attack_args_helpers.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/textattack/commands/attack/attack_args_helpers.py b/textattack/commands/attack/attack_args_helpers.py
index 7d9dff03..1fe6e691 100644
--- a/textattack/commands/attack/attack_args_helpers.py
+++ b/textattack/commands/attack/attack_args_helpers.py
@@ -355,9 +355,13 @@ def parse_dataset_from_args(args):
 )
 model_train_args = json.loads(open(model_args_json_path).read())
 try:
+ if ":" in model_train_args["dataset"]:
+ name, subset = model_train_args["dataset"].split(":")
+ else:
+ name, subset = model_train_args["dataset"], None
 args.dataset_from_nlp = (
- model_train_args["dataset"],
- None,
+ name,
+ subset,
 model_train_args["dataset_dev_split"],
 )
 except KeyError:

From 139477056276ac7e95104534a6ba06cae2e281ea Mon Sep 17 00:00:00 2001
From: Jin Yong Yoo
Date: Sun, 12 Jul 2020 13:20:12 -0400
Subject: [PATCH 17/21] wip

---
 textattack/commands/attack/attack_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py
index cc38e99e..c20d4aeb 100644
--- a/textattack/commands/attack/attack_args.py
+++ b/textattack/commands/attack/attack_args.py
@@ -358,6 +358,7 @@ SEARCH_METHOD_CLASS_NAMES = {
 "greedy": "textattack.search_methods.GreedySearch",
 "ga-word": "textattack.search_methods.GeneticAlgorithm",
 "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
+ "pso": "textattack.search_methods.ParticleSwarmOptimization",
 }
 
 GOAL_FUNCTION_CLASS_NAMES = {

From 09ea97579604f4581fc2937e193f686a2358d1b2 Mon Sep 17 00:00:00 2001
From: Jin Yong Yoo
Date: Sun, 12 Jul 2020 13:36:19 -0400
Subject: [PATCH 18/21] fixed flair error

---
 requirements.txt | 2 +-
 textattack/transformations/word_swap_hownet.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index d368e754..1ecea865 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 bert-score
 editdistance
-flair>=0.5
+flair>=0.5.1
 filelock
 language_tool_python
 lru-dict
diff --git a/textattack/transformations/word_swap_hownet.py b/textattack/transformations/word_swap_hownet.py
index 7a0b4bdc..69825ae6 100644
--- a/textattack/transformations/word_swap_hownet.py
+++ b/textattack/transformations/word_swap_hownet.py
@@ -51,10 +51,11 @@ class WordSwapHowNet(WordSwap):
 
 def _get_transformations(self, current_text, indices_to_modify):
 words = current_text.words
- words_str = " ".join(words)
- word_list, pos_list = zip_flair_result(
- self._flair_pos_tagger.predict(words_str)[0]
- )
+ sentence = Sentence(" ".join(words))
+ # in-place POS tagging
+ self._flair_pos_tagger.predict(sentence)
+ word_list, pos_list = zip_flair_result(sentence)
+
 assert len(words) == len(
 word_list
 ), "Part-of-speech tagger returned incorrect number of tags"
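The flair fix above tracks an API change in flair 0.5.1: `SequenceTagger.predict` now annotates a `Sentence` object in place instead of returning tagged results, so the old `predict(words_str)[0]` pattern breaks. A minimal sketch of the new call pattern (the "upos-fast" model name is illustrative; any flair POS model should behave the same way):

    from flair.data import Sentence
    from flair.models import SequenceTagger

    tagger = SequenceTagger.load("upos-fast")
    sentence = Sentence("the quick brown fox")
    tagger.predict(sentence)  # tags the sentence in place; the return value is unused
    print(sentence.to_tagged_string())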
From d6a480abc9201266d00a53b2c6ede63acd134191 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Sun, 12 Jul 2020 16:00:15 -0400
Subject: [PATCH 19/21] patch lstm hotflip (fix #209)

---
 textattack/transformations/word_swap_gradient_based.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/textattack/transformations/word_swap_gradient_based.py b/textattack/transformations/word_swap_gradient_based.py
index af149819..974e0a82 100644
--- a/textattack/transformations/word_swap_gradient_based.py
+++ b/textattack/transformations/word_swap_gradient_based.py
@@ -52,6 +52,7 @@ class WordSwapGradientBased(Transformation):
 word_index (int): index of the word to replace
 """
 self.model.train()
+ self.model.emb_layer.embedding.weight.requires_grad = True
 
 lookup_table = self.model.lookup_table.to(utils.device)
 lookup_table_transpose = lookup_table.transpose(0, 1)
@@ -105,6 +106,7 @@ class WordSwapGradientBased(Transformation):
 break
 
 self.model.eval()
+ self.model.emb_layer.embedding.weight.requires_grad = False
 return candidates
 
 def _call_model(self, text_ids):
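HotFlip needs gradients with respect to the embedding matrix, but the LSTM is usually trained with frozen GloVe embeddings (issue #209), so the transformation now enables `requires_grad` for the attack and disables it afterwards. The underlying PyTorch pattern, shown on a generic `nn.Embedding` rather than TextAttack's model:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(100, 8)
    emb.weight.requires_grad = False  # frozen, as with pretrained vectors

    emb.weight.requires_grad = True   # temporarily enable for the attack
    loss = emb(torch.tensor([1, 2, 3])).sum()
    loss.backward()
    print(emb.weight.grad.shape)      # torch.Size([100, 8])
    emb.weight.requires_grad = False  # switch back off after the attack

Hard-coding `False` on the way out is only correct when the embeddings started out frozen; the next patch restores the layer's original `emb_layer_trainable` flag instead.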
From 185c17eb77eaa048cf3c5b97552ae71c9cdeb498 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Sun, 12 Jul 2020 16:05:09 -0400
Subject: [PATCH 20/21] set emb_layer grad back to original flag (fix #209)

---
 textattack/models/helpers/lstm_for_classification.py | 1 +
 textattack/transformations/word_swap_gradient_based.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/textattack/models/helpers/lstm_for_classification.py b/textattack/models/helpers/lstm_for_classification.py
index a0cdf949..52587710 100644
--- a/textattack/models/helpers/lstm_for_classification.py
+++ b/textattack/models/helpers/lstm_for_classification.py
@@ -31,6 +31,7 @@ class LSTMForClassification(nn.Module):
 # so if that's all we have, this will display a warning.
 dropout = 0
 self.drop = nn.Dropout(dropout)
+ self.emb_layer_trainable = emb_layer_trainable
 self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
 self.word2id = self.emb_layer.word2id
 self.encoder = nn.LSTM(
diff --git a/textattack/transformations/word_swap_gradient_based.py b/textattack/transformations/word_swap_gradient_based.py
index 974e0a82..e7404208 100644
--- a/textattack/transformations/word_swap_gradient_based.py
+++ b/textattack/transformations/word_swap_gradient_based.py
@@ -106,7 +106,7 @@ class WordSwapGradientBased(Transformation):
 break
 
 self.model.eval()
- self.model.emb_layer.embedding.weight.requires_grad = False
+ self.model.emb_layer.embedding.weight.requires_grad = self.model.emb_layer_trainable
 return candidates
 
 def _call_model(self, text_ids):

From e937e015319df9c5c234c6ddda650f24e2676e76 Mon Sep 17 00:00:00 2001
From: Jack Morris
Date: Sun, 12 Jul 2020 16:39:10 -0400
Subject: [PATCH 21/21] explicit eval mode + formatting

---
 textattack/goal_functions/goal_function.py | 1 +
 textattack/transformations/word_swap_gradient_based.py | 4 +++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py
index 72fda564..30b9163e 100644
--- a/textattack/goal_functions/goal_function.py
+++ b/textattack/goal_functions/goal_function.py
@@ -40,6 +40,7 @@ class GoalFunction(ABC):
 self.__class__, model.__class__
 )
 self.model = model
+ self.model.eval()
 self.maximizable = maximizable
 self.tokenizer = tokenizer
 if not self.tokenizer:
diff --git a/textattack/transformations/word_swap_gradient_based.py b/textattack/transformations/word_swap_gradient_based.py
index e7404208..6656990d 100644
--- a/textattack/transformations/word_swap_gradient_based.py
+++ b/textattack/transformations/word_swap_gradient_based.py
@@ -106,7 +106,9 @@ class WordSwapGradientBased(Transformation):
 break
 
 self.model.eval()
- self.model.emb_layer.embedding.weight.requires_grad = self.model.emb_layer_trainable
+ self.model.emb_layer.embedding.weight.requires_grad = (
+ self.model.emb_layer_trainable
+ )
 return candidates
 
 def _call_model(self, text_ids):
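The `self.model.eval()` call added to `GoalFunction` makes attack-time inference deterministic: dropout layers become identity functions and normalization layers use their running statistics, so repeated queries on the same text return the same scores. The effect is easy to see on a small module (illustrative snippet, not TextAttack code):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    model.train()
    print(model(x))  # dropout randomly zeroes activations; varies run to run

    model.eval()
    print(model(x))  # deterministic; dropout is a no-op in eval mode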