mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge and format

Jack Morris
2020-07-12 20:40:22 -04:00
20 changed files with 443 additions and 201 deletions

View File

@@ -1,12 +1,13 @@
PEP_IGNORE_ERRORS="C901 E501 W503 E203 E231 E266 F403"
format: FORCE ## Run black and isort (rewriting files)
black .
isort --atomic tests textattack
lint: FORCE ## Run black, isort, flake8 (in check mode)
black . --check
isort --check-only tests textattack
flake8 . --count --ignore=C901,E501,W503,E203,E231,E266,F403 --show-source --statistics --exclude=./.*,build,dist
flake8 . --count --ignore=$(PEP_IGNORE_ERRORS) --show-source --statistics --exclude=./.*,build,dist
test: FORCE ## Run tests using pytest

View File

@@ -1,6 +1,6 @@
bert-score
editdistance
flair>=0.5
flair>=0.5.1
filelock
language_tool_python
lru-dict

View File

@@ -3,7 +3,7 @@ import random
import tqdm
from textattack.constraints import PreTransformationConstraint
from textattack.shared import AttackedText
from textattack.shared import AttackedText, utils
class Augmenter:
@@ -70,7 +70,9 @@ class Augmenter:
attacked_text = AttackedText(text)
original_text = attacked_text
all_transformed_texts = set()
num_words_to_swap = int(self.pct_words_to_swap * len(attacked_text.words))
num_words_to_swap = max(
int(self.pct_words_to_swap * len(attacked_text.words)), 1
)
for _ in range(self.transformations_per_example):
index_order = list(range(len(attacked_text.words)))
random.shuffle(index_order)
@@ -132,3 +134,22 @@ class Augmenter:
all_text_list.extend([text] + augmented_texts)
all_id_list.extend([_id] * (1 + len(augmented_texts)))
return all_text_list, all_id_list
def __repr__(self):
main_str = "Augmenter" + "("
lines = []
# self.transformation
lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
# self.constraints
constraints_lines = []
constraints = self.constraints + self.pre_transformation_constraints
if len(constraints):
for i, constraint in enumerate(constraints):
constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
else:
constraints_str = "None"
lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
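
The new `__repr__` gives augmenters the same nested, indented printout that `Attack` uses. A minimal usage sketch (the exact transformation and constraint list printed depends on the recipe's defaults):

from textattack.augmentation import WordNetAugmenter

augmenter = WordNetAugmenter()
print(augmenter)
# Augmenter(
#   (transformation):  WordSwapWordNet
#   (constraints): ...
# )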

View File

@@ -61,6 +61,9 @@ class EasyDataAugmenter(Augmenter):
random.shuffle(augmented_text)
return augmented_text[: self.transformations_per_example]
def __repr__(self):
return "EasyDataAugmenter"
class SwapAugmenter(Augmenter):
def __init__(self, **kwargs):

View File

@@ -308,6 +308,7 @@ BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
"word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
"word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
"word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
"word-swap-hownet": "textattack.transformations.WordSwapHowNet",
}
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
@@ -353,6 +354,7 @@ SEARCH_METHOD_CLASS_NAMES = {
"greedy": "textattack.search_methods.GreedySearch",
"ga-word": "textattack.search_methods.GeneticAlgorithm",
"greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
"pso": "textattack.search_methods.ParticleSwarmOptimization",
}
GOAL_FUNCTION_CLASS_NAMES = {

View File

@@ -360,9 +360,13 @@ def parse_dataset_from_args(args):
)
model_train_args = json.loads(open(model_args_json_path).read())
try:
if ":" in model_train_args["dataset"]:
name, subset = model_train_args["dataset"].split(":")
else:
name, subset = model_train_args["dataset"], None
args.dataset_from_nlp = (
model_train_args["dataset"],
None,
name,
subset,
model_train_args["dataset_dev_split"],
)
except KeyError:
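
The split-on-colon convention matches the `--dataset` help text added below ("glue:sst2" vs. "rotten_tomatoes"). A quick illustration with hypothetical values:

name, subset = "glue:sst2".split(":")   # ("glue", "sst2")
name, subset = "rotten_tomatoes", None  # no colon, so no subset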

View File

@@ -94,9 +94,18 @@ def run(args, checkpoint=None):
in_queue = torch.multiprocessing.Queue()
out_queue = torch.multiprocessing.Queue()
# Add stuff to queue.
missing_datapoints = set()
for i in worklist:
text, output = dataset[i]
in_queue.put((i, text, output))
try:
text, output = dataset[i]
in_queue.put((i, text, output))
except IndexError:
missing_datapoints.add(i)
# if our dataset is shorter than the number of samples chosen, remove the
# out-of-bounds indices from the worklist
for i in missing_datapoints:
worklist.remove(i)
# Start workers.
# pool = torch.multiprocessing.Pool(num_gpus, attack_from_queue, (args, in_queue, out_queue))
@@ -146,7 +155,7 @@ def run(args, checkpoint=None):
in_queue.put((worklist_tail, text, output))
except IndexError:
raise IndexError(
"Out of bounds access of dataset. Size of data is {} but tried to access index {}".format(
"Tried adding to worklist, but ran out of datapoints. Size of data is {} but tried to access index {}".format(
len(dataset), worklist_tail
)
)

View File

@@ -1,18 +1,19 @@
import json
import logging
import math
import os
import time
import numpy as np
import scipy
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, RandomSampler
import tqdm
import transformers
import textattack
from .train_args_helpers import (
attack_from_args,
augmenter_from_args,
dataset_from_args,
model_from_args,
@@ -23,9 +24,178 @@ device = textattack.shared.utils.device
logger = textattack.shared.logger
def make_directories(output_dir):
def _save_args(args, save_path):
"""
Dump args dictionary to a JSON file.
:param args: Dictionary of arguments to save.
:param save_path: Path to JSON file to write args to.
"""
final_args_dict = {k: v for k, v in vars(args).items() if _is_writable_type(v)}
with open(save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(final_args_dict, indent=2) + "\n")
def _get_sample_count(*lsts):
"""
Get sample count of a dataset.
:param *lsts: variable number of lists.
:return: sample count of this dataset, if all lists match, else None.
"""
if all(len(lst) == len(lsts[0]) for lst in lsts):
sample_count = len(lsts[0])
else:
sample_count = None
return sample_count
def _random_shuffle(*lsts):
"""
Randomly shuffle a dataset. Applies the same permutation
to each list (to preserve mapping between inputs and targets).
:param *lsts: variable number of lists to shuffle.
:return: shuffled lsts.
"""
permutation = np.random.permutation(len(lsts[0]))
shuffled = []
for lst in lsts:
shuffled.append((np.array(lst)[permutation]).tolist())
return tuple(shuffled)
def _train_val_split(*lsts, split_val=0.2):
"""
Split dataset into training and validation sets.
:param *lsts: variable number of lists that make up a dataset (e.g. text, labels)
:param split_val: float [0., 1.). Fraction of the dataset to reserve for evaluation.
:return: (train split of each list in lsts), (val split of each list in lsts)
"""
sample_count = _get_sample_count(*lsts)
if not sample_count:
raise Exception(
"Batch Axis inconsistent. All input arrays must have first axis of equal length."
)
lsts = _random_shuffle(*lsts)
split_idx = math.floor(sample_count * split_val)
train_set = [lst[split_idx:] for lst in lsts]
val_set = [lst[:split_idx] for lst in lsts]
if len(train_set) == 1 and len(val_set) == 1:
train_set = train_set[0]
val_set = val_set[0]
return train_set, val_set
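
A toy example of the split, assuming a five-example dataset (which example lands in validation varies with the shuffle):

text = ["a", "b", "c", "d", "e"]
labels = [0, 1, 0, 1, 0]
(train_text, train_labels), (val_text, val_labels) = _train_val_split(
    text, labels, split_val=0.2
)
assert len(val_text) == 1    # math.floor(5 * 0.2) examples held out
assert len(train_text) == 4  # the rest, still aligned with train_labels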
def _filter_labels(text, labels, allowed_labels):
"""
Keep examples with approved labels
:param text: list of text inputs.
:param labels: list of corresponding labels.
:param allowed_labels: list of approved label values.
:return: (final_text, final_labels). Filtered version of text and labels
"""
final_text, final_labels = [], []
for text_example, label in zip(text, labels):
if label in allowed_labels:
final_text.append(text_example)
final_labels.append(label)
return final_text, final_labels
def _save_model_checkpoint(model, output_dir, global_step):
"""
Save model checkpoint to disk.
:param model: Model to save (pytorch)
:param output_dir: Path to model save dir.
:param global_step: Current global training step #. Used in ckpt filename.
"""
# Save model checkpoint
output_dir = os.path.join(output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Take care of distributed/parallel training
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir)
def _save_model(model, output_dir, weights_name, config_name):
"""
Save model to disk.
:param model: Model to save (pytorch)
:param output_dir: Path to model save dir.
:param weights_name: filename for model parameters.
:param config_name: filename for config.
"""
model_to_save = model.module if hasattr(model, "module") else model
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(output_dir, weights_name)
output_config_file = os.path.join(output_dir, config_name)
torch.save(model_to_save.state_dict(), output_model_file)
try:
model_to_save.config.to_json_file(output_config_file)
except AttributeError:
# no config
pass
def _get_eval_score(model, eval_dataloader, do_regression):
"""
Measure performance of a model on the evaluation set.
:param model: Model to test.
:param eval_dataloader: a torch DataLoader that iterates through the eval set.
:param do_regression: bool. Whether we are doing regression (True) or classification (False)
:return: pearson correlation, if do_regression==True, else classification accuracy [0., 1.]
"""
model.eval()
correct = 0
logits = []
labels = []
for input_ids, batch_labels in eval_dataloader:
if isinstance(input_ids, dict):
## HACK: dataloader collates dict backwards. This is a temporary
# workaround to get ids in the right shape
input_ids = {k: torch.stack(v).T.to(device) for k, v in input_ids.items()}
batch_labels = batch_labels.to(device)
with torch.no_grad():
batch_logits = textattack.shared.utils.model_predict(model, input_ids)
logits.extend(batch_logits.cpu().squeeze().tolist())
labels.extend(batch_labels)
model.train()
logits = torch.tensor(logits)
labels = torch.tensor(labels)
if do_regression:
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
return pearson_correlation
else:
preds = logits.argmax(dim=1)
correct = (preds == labels).sum()
return float(correct) / len(labels)
def _make_directories(output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
def _is_writable_type(obj):
for ok_type in [bool, int, str, float]:
if isinstance(obj, ok_type):
return True
return False
def batch_encode(tokenizer, text_list):
@@ -35,12 +205,70 @@ def batch_encode(tokenizer, text_list):
return [tokenizer.encode(text_input) for text_input in text_list]
def _make_dataloader(tokenizer, text, labels, batch_size):
"""
Create torch DataLoader from list of input text and labels.
:param tokenizer: Tokenizer to use for this text.
:param text: list of input text.
:param labels: list of corresponding labels.
:param batch_size: batch size (int).
:return: torch DataLoader for this training set.
"""
text_ids = batch_encode(tokenizer, text)
input_ids = np.array(text_ids)
labels = np.array(labels)
data = list((ids, label) for ids, label in zip(input_ids, labels))
sampler = RandomSampler(data)
dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
return dataloader
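
The `torch.stack(v).T` workaround in `_get_eval_score` (and in the training loop below) exists because `DataLoader`'s default collation turns a batch of dicts of per-token lists into a dict of per-position tensors. A self-contained illustration:

import torch
from torch.utils.data import DataLoader

# two hypothetical examples whose inputs are dicts of token-id lists
data = [({"input_ids": [1, 2, 3]}, 0), ({"input_ids": [4, 5, 6]}, 1)]
loader = DataLoader(data, batch_size=2)
batch_inputs, batch_labels = next(iter(loader))
# default collation yields a list of per-position tensors:
#   batch_inputs["input_ids"] == [tensor([1, 4]), tensor([2, 5]), tensor([3, 6])]
ids = torch.stack(batch_inputs["input_ids"]).T
# ids is now (batch, seq_len): tensor([[1, 2, 3], [4, 5, 6]])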
def _data_augmentation(text, labels, augmenter):
"""
Use an augmentation method to expand a training set.
:param text: list of input text.
:param labels: list of corresponding labels.
:param augmenter: textattack.augmentation.Augmenter, augmentation scheme.
:return: augmented_text, augmented_labels. list of (augmented) input text and labels.
"""
aug_text = augmenter.augment_many(text)
# flatten augmented examples and duplicate labels
flat_aug_text = []
flat_aug_labels = []
for i, examples in enumerate(aug_text):
for aug_ver in examples:
flat_aug_text.append(aug_ver)
flat_aug_labels.append(labels[i])
return flat_aug_text, flat_aug_labels
def _generate_adversarial_examples(model, attackCls, dataset):
"""
Create a dataset of adversarial examples based on perturbations of the existing dataset.
:param model: Model to attack.
:param attackCls: attack recipe class; it is instantiated with the model to build the attack.
:param dataset: iterable of (text, label) pairs.
:return: list of adversarial examples.
"""
attack = attackCls(model)
adv_train_text = []
for adv_ex in tqdm.tqdm(
attack.attack_dataset(dataset), desc="Attack", total=len(dataset)
):
adv_train_text.append(adv_ex.perturbed_text())
return adv_train_text
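
Standalone use would look roughly like this; the recipe is illustrative, and `attack_from_args` below resolves the class from `--attack` the same way:

from textattack.attack_recipes import DeepWordBugGao2018

# `model` is assumed to be a trained, TextAttack-compatible model
dataset = list(zip(train_text, train_labels))
adv_train_text = _generate_adversarial_examples(model, DeepWordBugGao2018, dataset)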
def train_model(args):
logger.warn(
"WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues."
)
start_time = time.time()
make_directories(args.output_dir)
_make_directories(args.output_dir)
num_gpus = torch.cuda.device_count()
@@ -63,52 +291,27 @@ def train_model(args):
# Filter labels
if args.allowed_labels:
logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.")
final_train_text, final_train_labels = [], []
for text, label in zip(train_text, train_labels):
if label in args.allowed_labels:
final_train_text.append(text)
final_train_labels.append(label)
logger.info(
f"Filtered {len(train_text)} train samples to {len(final_train_text)} points."
train_text, train_labels = _filter_labels(
train_text, train_labels, args.allowed_labels
)
train_text, train_labels = final_train_text, final_train_labels
final_eval_text, final_eval_labels = [], []
for text, label in zip(eval_text, eval_labels):
if label in args.allowed_labels:
final_eval_text.append(text)
final_eval_labels.append(label)
logger.info(
f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points."
eval_text, eval_labels = _filter_labels(
eval_text, eval_labels, args.allowed_labels
)
eval_text, eval_labels = final_eval_text, final_eval_labels
if args.pct_dataset < 1.0:
logger.info(f"Using {args.pct_dataset*100}% of the training set")
(train_text, train_labels), _ = _train_val_split(
train_text, train_labels, split_val=1.0 - args.pct_dataset
)
train_examples_len = len(train_text)
# data augmentation
augmenter = augmenter_from_args(args)
if augmenter:
# augment the training set
aug_train_text = augmenter.augment_many(train_text)
# flatten augmented examples and duplicate labels
flat_aug_train_text = []
flat_aug_train_labels = []
for i, examples in enumerate(aug_train_text):
for aug_ver in examples:
flat_aug_train_text.append(aug_ver)
flat_aug_train_labels.append(train_labels[i])
train_text = flat_aug_train_text
train_labels = flat_aug_train_labels
# augment the eval set
aug_eval_text = augmenter.augment_many(eval_text)
# flatten the augmented examples and duplicate labels
flat_aug_eval_text = []
flat_aug_eval_labels = []
for i, examples in enumerate(aug_eval_text):
for aug_ver in examples:
flat_aug_eval_text.append(aug_ver)
flat_aug_eval_labels.append(eval_labels[i])
eval_text = flat_aug_eval_text
eval_labels = flat_aug_eval_labels
logger.info(f"Augmenting {len(train_text)} samples with {augmenter}")
train_text, train_labels = _data_augmentation(
train_text, train_labels, augmenter
)
# label_id_len = len(train_labels)
label_set = set(train_labels)
@@ -125,11 +328,9 @@ def train_model(args):
else:
args.do_regression = False
train_examples_len = len(train_text)
if len(train_labels) != train_examples_len:
if len(train_labels) != len(train_text):
raise ValueError(
f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})"
f"Number of train examples ({len(train_text)}) does not match number of labels ({len(train_labels)})"
)
if len(eval_labels) != len(eval_text):
raise ValueError(
@@ -139,16 +340,13 @@ def train_model(args):
model = model_from_args(args, args.num_labels)
tokenizer = model.tokenizer
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
train_text_ids = batch_encode(tokenizer, train_text)
logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
eval_text_ids = batch_encode(tokenizer, eval_text)
load_time = time.time()
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
attackCls = attack_from_args(args)
adversarial_training = attackCls is not None
# multi-gpu training
if num_gpus > 1:
model = torch.nn.DataParallel(model)
model.tokenizer = model.module.tokenizer
logger.info("Using torch.nn.DataParallel.")
logger.info(f"Training model across {num_gpus} GPUs")
@@ -199,110 +397,38 @@ def train_model(args):
tb_writer = SummaryWriter(args.output_dir)
def is_writable_type(obj):
for ok_type in [bool, int, str, float]:
if isinstance(obj, ok_type):
return True
return False
args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
# Save original args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(args_dict, indent=2) + "\n")
_save_args(args, args_save_path)
logger.info(f"Wrote original training args to {args_save_path}.")
tb_writer.add_hparams(args_dict, {})
tb_writer.add_hparams(
{k: v for k, v in vars(args).items() if _is_writable_type(v)}, {}
)
# Start training
logger.info("***** Running training *****")
logger.info(f"\tNum examples = {train_examples_len}")
if augmenter:
logger.info(f"\tNum original examples = {train_examples_len}")
logger.info(f"\tNum examples after augmentation = {len(train_text)}")
else:
logger.info(f"\tNum examples = {train_examples_len}")
logger.info(f"\tBatch size = {args.batch_size}")
logger.info(f"\tMax sequence length = {args.max_length}")
logger.info(f"\tNum steps = {num_train_optimization_steps}")
logger.info(f"\tNum epochs = {args.num_train_epochs}")
logger.info(f"\tLearning rate = {args.learning_rate}")
train_input_ids = np.array(train_text_ids)
train_labels = np.array(train_labels)
train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels))
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
train_data, sampler=train_sampler, batch_size=args.batch_size
eval_dataloader = _make_dataloader(
tokenizer, eval_text, eval_labels, args.batch_size
)
eval_input_ids = np.array(eval_text_ids)
eval_labels = np.array(eval_labels)
eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels))
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(
eval_data, sampler=eval_sampler, batch_size=args.batch_size
train_dataloader = _make_dataloader(
tokenizer, train_text, train_labels, args.batch_size
)
def get_eval_score():
model.eval()
correct = 0
# total = 0
logits = []
labels = []
for input_ids, batch_labels in eval_dataloader:
if isinstance(input_ids, dict):
## HACK: dataloader collates dict backwards. This is a temporary
# workaround to get ids in the right shape
input_ids = {
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
}
batch_labels = batch_labels.to(device)
with torch.no_grad():
batch_logits = textattack.shared.utils.model_predict(model, input_ids)
logits.extend(batch_logits.cpu().squeeze().tolist())
labels.extend(batch_labels)
model.train()
logits = torch.tensor(logits)
labels = torch.tensor(labels)
if args.do_regression:
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
return pearson_correlation
else:
preds = logits.argmax(dim=1)
correct = (preds == labels).sum()
return float(correct) / len(labels)
def save_model():
model_to_save = (
model.module if hasattr(model, "module") else model
) # Only save the model itself
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, args.weights_name)
output_config_file = os.path.join(args.output_dir, args.config_name)
torch.save(model_to_save.state_dict(), output_model_file)
try:
model_to_save.config.to_json_file(output_config_file)
except AttributeError:
# no config
pass
global_step = 0
tr_loss = 0
def save_model_checkpoint():
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Take care of distributed/parallel training
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info(f"Checkpoint saved to {output_dir}.")
model.train()
args.best_eval_score = 0
args.best_eval_score_epoch = 0
@@ -325,6 +451,21 @@ def train_model(args):
for epoch in tqdm.trange(
int(args.num_train_epochs), desc="Epoch", position=0, leave=False
):
if adversarial_training:
if epoch >= args.num_clean_epochs:
if (epoch - args.num_clean_epochs) % args.attack_period == 0:
# only generate a new adversarial training set every args.attack_period epochs
# after the clean epochs
logger.info("Attacking model to generate new training set...")
adv_train_text = _generate_adversarial_examples(
model, attackCls, list(zip(train_text, train_labels))
)
train_dataloader = _make_dataloader(
tokenizer, adv_train_text, train_labels, args.batch_size
)
else:
logger.info(f"Running clean epoch {epoch+1}/{args.num_clean_epochs}")
prog_bar = tqdm.tqdm(
train_dataloader, desc="Iteration", position=1, leave=False
)
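
The net effect is a schedule: `num_clean_epochs` epochs on the original data, then a fresh adversarial training set every `attack_period` epochs. A toy trace with hypothetical settings:

num_clean_epochs, attack_period = 1, 2
for epoch in range(6):
    if epoch < num_clean_epochs:
        print(f"epoch {epoch}: clean data")
    elif (epoch - num_clean_epochs) % attack_period == 0:
        print(f"epoch {epoch}: attack model, regenerate adversarial set")
    else:
        print(f"epoch {epoch}: reuse previous adversarial set")
# epoch 0: clean; epochs 1, 3, 5: regenerate; epochs 2, 4: reuse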
@@ -337,6 +478,7 @@ def train_model(args):
input_ids = {
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
}
logits = textattack.shared.utils.model_predict(model, input_ids)
if args.do_regression:
@@ -366,42 +508,46 @@ def train_model(args):
and (args.checkpoint_steps > 0)
and (global_step % args.checkpoint_steps) == 0
):
save_model_checkpoint()
_save_model_checkpoint(model, args.output_dir, global_step)
# Inc step counter.
global_step += 1
# Check accuracy after each epoch.
eval_score = get_eval_score()
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
# skip args.num_clean_epochs during adversarial training
if not adversarial_training or epoch >= args.num_clean_epochs:
eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
if args.checkpoint_every_epoch:
save_model_checkpoint()
if args.checkpoint_every_epoch:
_save_model_checkpoint(model, args.output_dir, global_step)
logger.info(
f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
)
if eval_score > args.best_eval_score:
args.best_eval_score = eval_score
args.best_eval_score_epoch = epoch
args.epochs_since_best_eval_score = 0
save_model()
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
else:
args.epochs_since_best_eval_score += 1
if (args.early_stopping_epochs > 0) and (
args.epochs_since_best_eval_score > args.early_stopping_epochs
):
logger.info(
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
)
break
logger.info(
f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
)
if eval_score > args.best_eval_score:
args.best_eval_score = eval_score
args.best_eval_score_epoch = epoch
args.epochs_since_best_eval_score = 0
_save_model(model, args.output_dir, args.weights_name, args.config_name)
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
_save_args(args, args_save_path)
logger.info(f"Saved updated args to {args_save_path}")
else:
args.epochs_since_best_eval_score += 1
if (args.early_stopping_epochs > 0) and (
args.epochs_since_best_eval_score > args.early_stopping_epochs
):
logger.info(
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
)
break
# read the saved model and report its eval performance
logger.info("Finished training. Re-loading and evaluating model from disk.")
model = model_from_args(args, args.num_labels)
model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
eval_score = get_eval_score()
eval_score = _get_eval_score(model, eval_dataloader, args.do_regression)
logger.info(
f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
)
@@ -418,8 +564,5 @@ def train_model(args):
# Save a little readme with model info
write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
# Save args to file
final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(final_args_dict, indent=2) + "\n")
_save_args(args, args_save_path)
logger.info(f"Wrote final training args to {args_save_path}.")

View File

@@ -1,6 +1,7 @@
import os
import textattack
from textattack.commands.attack.attack_args import ATTACK_RECIPE_NAMES
from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
logger = textattack.shared.logger
@@ -133,6 +134,21 @@ def model_from_args(train_args, num_labels, model_path=None):
return model
def attack_from_args(args):
# note that this returns a recipe type, not an object
# (we need to wait to have access to the model to initialize)
attackCls = None
if args.attack:
if args.attack in ATTACK_RECIPE_NAMES:
attackCls = eval(ATTACK_RECIPE_NAMES[args.attack])
else:
raise ValueError(f"Unrecognized attack recipe: {args.attack}")
# check attack-related args
assert args.num_clean_epochs > 0, "--num-clean-epochs must be > 0"
return attackCls
def augmenter_from_args(args):
augmenter = None
if args.augment:

View File

@@ -13,7 +13,7 @@ class TrainModelCommand(TextAttackCommand):
def run(self, args):
date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(
current_dir, os.pardir, os.pardir, os.pardir, "outputs", "training"
@@ -47,6 +47,12 @@ class TrainModelCommand(TextAttackCommand):
"`nlp` library. if dataset has a subset, separate with a colon. "
" ex: `glue:sst2` or `rotten_tomatoes`",
)
parser.add_argument(
"--pct-dataset",
type=float,
default=1.0,
help="Fraction of dataset to use during training ([0., 1.])",
)
parser.add_argument(
"--dataset-train-split",
"--train-split",
@@ -77,7 +83,7 @@ class TrainModelCommand(TextAttackCommand):
help="save model after this many steps (-1 for no checkpointing)",
)
parser.add_argument(
"--checkpoint-every_epoch",
"--checkpoint-every-epoch",
action="store_true",
default=False,
help="save model checkpoint after each epoch",
@@ -89,6 +95,24 @@ class TrainModelCommand(TextAttackCommand):
default=100,
help="Total number of epochs to train for",
)
parser.add_argument(
"--attack",
type=str,
default=None,
help="Attack recipe to use (enables adversarial training)",
)
parser.add_argument(
"--num-clean-epochs",
type=int,
default=1,
help="Number of epochs to train on the clean dataset before adversarial training (N/A if --attack unspecified)",
)
parser.add_argument(
"--attack-period",
type=int,
default=1,
help="How often (in epochs) to generate a new adversarial training set (N/A if --attack unspecified)",
)
parser.add_argument(
"--augment", type=str, default=None, help="Augmentation recipe to use",
)

View File

@@ -28,7 +28,15 @@ class QueryHandler:
except Exception:
probs = []
for s, w in zip(sentences, swapped_words):
probs.append(self.try_query([s], [w], batch_size=1)[0])
try:
probs.append(self.try_query([s], [w], batch_size=1)[0])
except RuntimeError:
print(
"WARNING: got runtime error querying language model on sentence/word:",
s,
w,
)
probs.append(float("-inf"))
return probs
def try_query(self, sentences, swapped_words, batch_size=32):
@@ -61,6 +69,8 @@ class QueryHandler:
hidden = self.model.init_hidden(len(batch))
source = word_idxs[:-1, :]
target = word_idxs[1:, :]
if (not len(source)) or not len(hidden):
return [float("-inf")] * len(batch)
decode, hidden = self.model(source, hidden)
decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
for i in range(len(batch)):

View File

@@ -32,12 +32,13 @@ class GoalFunction(ABC):
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
model_cache_size=2 ** 18,
model_cache_size=2 ** 20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model.__class__
)
self.model = model
self.model.eval()
self.maximizable = maximizable
self.tokenizer = tokenizer
if not self.tokenizer:

View File

@@ -2,14 +2,10 @@ import csv
import pandas as pd
# from textattack.attack_results import FailedAttackResult
from textattack.shared import logger
from textattack.shared import AttackedText, logger
from .logger import Logger
# import os
# import sys
class CSVLogger(Logger):
"""Logs attack results to a CSV."""
@@ -22,6 +18,8 @@ class CSVLogger(Logger):
def log_attack_result(self, result):
original_text, perturbed_text = result.diff_color(self.color_method)
original_text = original_text.replace("\n", AttackedText.SPLIT_TOKEN)
perturbed_text = perturbed_text.replace("\n", AttackedText.SPLIT_TOKEN)
result_type = result.__class__.__name__.replace("AttackResult", "")
row = {
"original_text": original_text,

View File

@@ -31,6 +31,7 @@ class LSTMForClassification(nn.Module):
# so if that's all we have, this will display a warning.
dropout = 0
self.drop = nn.Dropout(dropout)
self.emb_layer_trainable = emb_layer_trainable
self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
self.word2id = self.emb_layer.word2id
self.encoder = nn.LSTM(

View File

@@ -58,7 +58,9 @@ class GreedyWordSwapWIR(SearchMethod):
initial_result, leave_one_texts
)
softmax_saliency_scores = softmax(torch.Tensor(saliency_scores)).numpy()
softmax_saliency_scores = softmax(
torch.Tensor(saliency_scores), dim=0
).numpy()
# compute the largest change in score we can find by swapping each word
delta_ps = []
@@ -72,9 +74,7 @@ class GreedyWordSwapWIR(SearchMethod):
# no valid synonym substitutions for this word
delta_ps.append(0.0)
continue
swap_results, _ = self.get_goal_results(
transformed_text_candidates, initial_result.output
)
swap_results, _ = self.get_goal_results(transformed_text_candidates)
score_change = [result.score for result in swap_results]
max_score_change = np.max(score_change)
delta_ps.append(max_score_change)
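
Making `dim` explicit pins the normalization axis and avoids torch's implicit-dimension deprecation warning; for a 1-D saliency vector the result is unchanged:

import torch
from torch.nn.functional import softmax

saliency_scores = torch.Tensor([0.1, 0.7, 0.2])
probs = softmax(saliency_scores, dim=0)
assert abs(probs.sum().item() - 1.0) < 1e-6  # a proper distribution over words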

View File

@@ -38,7 +38,7 @@ class Attack:
constraints=[],
transformation=None,
search_method=None,
constraint_cache_size=2 ** 18,
constraint_cache_size=2 ** 20,
):
"""Initialize an attack object.
@@ -150,7 +150,8 @@ class Attack:
self, transformed_texts, current_text, original_text=None
):
"""Filters a list of potential transformed texts based on
``self.constraints`` Checks cache first.
``self.constraints``. Uses an LRU cache to avoid recomputing
common transformations.
Args:
transformed_texts: A list of candidate transformed ``AttackedText`` to filter.

View File

@@ -14,8 +14,7 @@ from . import logger
# A list of goal functions and the corresponding available models.
MODELS_BY_GOAL_FUNCTIONS = {
(TargetedClassification, UntargetedClassification, InputReduction): [
r"^textattack.models.classification.*",
r"^textattack.models.entailment.*",
r"^textattack.models.lstm_for_classification.*",
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
],
(NonOverlappingOutput,): [

View File

@@ -52,6 +52,7 @@ class WordSwapGradientBased(Transformation):
word_index (int): index of the word to replace
"""
self.model.train()
self.model.emb_layer.embedding.weight.requires_grad = True
lookup_table = self.model.lookup_table.to(utils.device)
lookup_table_transpose = lookup_table.transpose(0, 1)
@@ -105,6 +106,9 @@ class WordSwapGradientBased(Transformation):
break
self.model.eval()
self.model.emb_layer.embedding.weight.requires_grad = (
self.model.emb_layer_trainable
)
return candidates
def _call_model(self, text_ids):

View File

@@ -53,10 +53,11 @@ class WordSwapHowNet(WordSwap):
def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
words_str = " ".join(words)
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(words_str)[0]
)
sentence = Sentence(" ".join(words))
# in-place POS tagging
self._flair_pos_tagger.predict(sentence)
word_list, pos_list = zip_flair_result(sentence)
assert len(words) == len(
word_list
), "Part-of-speech tagger returned incorrect number of tags"

View File

@@ -93,7 +93,7 @@ class WordSwapMaskedLM(WordSwap):
replacement_words = []
for id in top_ids:
token = self._lm_tokenizer.convert_ids_to_tokens(id)
if utils.is_one_word(token):
if utils.is_one_word(token) and not check_if_subword(token):
replacement_words.append(token)
return replacement_words
@@ -141,7 +141,7 @@ class WordSwapMaskedLM(WordSwap):
replacement_words = []
for id in top_preds:
token = self._lm_tokenizer.convert_ids_to_tokens(id)
if utils.is_one_word(token):
if utils.is_one_word(token) and not check_if_subword(token):
replacement_words.append(token)
return replacement_words
else:
@@ -231,3 +231,7 @@ def recover_word_case(word, reference_word):
else:
# if other, just do not alter the word's case
return word
def check_if_subword(text):
return "##" in text
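
WordPiece tokenizers mark word-internal pieces with a leading "##", so dropping them keeps only tokens that can stand alone as replacement words:

assert check_if_subword("##ing")        # continuation piece of a longer word
assert not check_if_subword("running")  # standalone token, kept as a candidate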