mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

entailment works; add peek-dataset command

Jack Morris
2020-06-23 20:22:03 -04:00
parent 69e114099b
commit 5991aa6d4e
15 changed files with 159 additions and 54 deletions

View File

@@ -333,6 +333,8 @@ def parse_dataset_from_args(args):
dataset_args = args.dataset_from_nlp
if ":" in dataset_args:
dataset_args = dataset_args.split(":")
else:
dataset_args = (dataset_args,)
dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, shuffle=args.shuffle
)
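
For context: the two added lines above let a bare dataset name (e.g. `rotten_tomatoes`) pass through the same star-unpacking as a colon-separated spec like `glue:sst2`. A minimal standalone sketch of that normalization, with `parse_spec` as a hypothetical helper name (not part of this commit):

def parse_spec(spec):
    # "glue:sst2" -> ("glue", "sst2"); "rotten_tomatoes" -> ("rotten_tomatoes",)
    return tuple(spec.split(":")) if ":" in spec else (spec,)

assert parse_spec("glue:sst2") == ("glue", "sst2")
assert parse_spec("rotten_tomatoes") == ("rotten_tomatoes",)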

View File

@@ -0,0 +1 @@
from .eval_model_command import EvalModelCommand

View File

@@ -0,0 +1,73 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
import numpy as np
import re
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args_helpers import add_dataset_args, parse_dataset_from_args
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="red", method="ansi")
logger = textattack.shared.logger
class PeekDatasetCommand(TextAttackCommand):
"""
The peek dataset module:
Takes a peek into a dataset in textattack.
"""
def run(self, args):
UPPERCASE_LETTERS_REGEX = re.compile('[A-Z]')
args.model = None # set model to None for parse_dataset_from_args to work
dataset = parse_dataset_from_args(args)
num_words = []
attacked_texts = []
data_all_lowercased = True
outputs = []
for inputs, output in dataset:
at = textattack.shared.AttackedText(inputs)
if data_all_lowercased:
# Check whether the string contains any uppercase letters.
if re.search(UPPERCASE_LETTERS_REGEX, at.text):
data_all_lowercased = False
attacked_texts.append(at)
num_words.append(len(at.words))
outputs.append(output)
logger.info(f'Number of samples: {_cb(len(attacked_texts))}')
logger.info(f'Number of words per input:')
num_words = np.array(num_words)
logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
mean_words = f'{num_words.mean():.2f}'
logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
std_words = f'{num_words.std():.2f}'
logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
logger.info(f'Dataset lowercased: {_cb(data_all_lowercased)}')
logger.info('First sample:')
print(attacked_texts[0].printable_text(), '\n')
logger.info('Last sample:')
print(attacked_texts[-1].printable_text(), '\n')
outputs = set(outputs)
logger.info(f'Found {len(outputs)} distinct outputs:')
print(sorted(outputs))
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"peek-dataset", help="show main statistics about a dataset",
formatter_class=ArgumentDefaultsHelpFormatter,
)
add_dataset_args(parser)
parser.set_defaults(func=PeekDatasetCommand())
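
The statistics the command reports are easy to reproduce outside TextAttack; here is a toy sketch of the same numbers on a made-up corpus (the real command counts words via `AttackedText.words` rather than `str.split`):

import re
import numpy as np

samples = ["the movie was great", "a dull , lifeless script", "I loved it"]  # made-up data
num_words = np.array([len(s.split()) for s in samples])
all_lowercased = not any(re.search("[A-Z]", s) for s in samples)
print(f"Number of samples: {len(samples)}")
print(f"total: {num_words.sum()}  mean: {num_words.mean():.2f}  std: {num_words.std():.2f}  "
      f"min: {num_words.min()}  max: {num_words.max()}")
print(f"Dataset lowercased: {all_lowercased}")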

View File

@@ -8,9 +8,9 @@ from textattack.commands.augment import AugmentCommand
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
from textattack.commands.eval_model import EvalModelCommand
from textattack.commands.list_things import ListThingsCommand
from textattack.commands.peek_dataset import PeekDatasetCommand
from textattack.commands.train_model import TrainModelCommand
def main():
parser = argparse.ArgumentParser(
"TextAttack CLI",
@@ -27,6 +27,7 @@ def main():
EvalModelCommand.register_subcommand(subparsers)
ListThingsCommand.register_subcommand(subparsers)
TrainModelCommand.register_subcommand(subparsers)
PeekDatasetCommand.register_subcommand(subparsers)
# Let's go
args = parser.parse_args()
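
The registration above relies on the usual argparse sub-command dispatch: `set_defaults(func=...)` attaches the command object to the parsed namespace so the CLI can invoke its `run` method. A self-contained toy version of the pattern (names here are illustrative, not TextAttack's):

from argparse import ArgumentParser

class HelloCommand:
    def run(self, args):
        print(f"hello, {args.name}")

    @staticmethod
    def register_subcommand(subparsers):
        parser = subparsers.add_parser("hello")
        parser.add_argument("--name", default="world")
        parser.set_defaults(func=HelloCommand())

main_parser = ArgumentParser("toy-cli")
subparsers = main_parser.add_subparsers()
HelloCommand.register_subcommand(subparsers)
args = main_parser.parse_args(["hello", "--name", "textattack"])
args.func.run(args)  # prints "hello, textattack"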

View File

@@ -7,6 +7,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
import tqdm
import textattack
import transformers
from .train_args_helpers import dataset_from_args, model_from_args
@@ -19,16 +20,15 @@ def make_directories(output_dir):
os.makedirs(output_dir)
-def encode_batch(tokenizer, text_list):
-try:
-return tokenizer.encode_batch(text_list)
-except AttributeError:
-return [tokenizer.encode(text) for text in text_list]
+def batch_encode(tokenizer, text_list):
+return [tokenizer.encode(text_input) for text_input in text_list]
+# try:
+# return tokenizer.batch_encode(text_list) # TODO get batch encoding to work with fast tokenizer
+# except AttributeError:
+# return [tokenizer.encode(text_input) for text_input in text_list]
def train_model(args):
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
start_time = time.time()
make_directories(args.output_dir)
@@ -38,8 +38,12 @@ def train_model(args):
from tensorboardX import SummaryWriter
tb_writer = SummaryWriter(args.output_dir)
-args_dict = vars(args)
-del args_dict["func"]
+def is_writable_type(obj):
+for ok_type in [bool, int, str, float]:
+if isinstance(obj, ok_type): return True
+return False
+args_dict = {k: v for k,v in vars(args).items() if is_writable_type(v)}
tb_writer.add_hparams(args_dict, {})
# Use Weights & Biases, if enabled.
@@ -88,9 +92,9 @@ def train_model(args):
model = model_from_args(args, num_labels)
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
-train_text_ids = encode_batch(model.tokenizer, train_text)
-logger.info(f"Tokenizing test data (len: {len(eval_labels)})")
-eval_text_ids = encode_batch(model.tokenizer, eval_text)
+train_text_ids = batch_encode(model.tokenizer, train_text)
+logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
+eval_text_ids = batch_encode(model.tokenizer, eval_text)
load_time = time.time()
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
# print(model)
@@ -122,9 +126,9 @@ def train_model(args):
},
]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+optimizer = transformers.optimization.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
-scheduler = get_linear_schedule_with_warmup(
+scheduler = transformers.optimization.get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_proportion,
num_training_steps=num_train_optimization_steps,
@@ -133,12 +137,12 @@ def train_model(args):
global_step = 0
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_examples_len)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Max sequence length = %d", args.max_length)
logger.info(" Num steps = %d", num_train_optimization_steps)
logger.info(" Num epochs = %d", args.num_train_epochs)
logger.info(" Learning rate = %d", args.learning_rate)
logger.info(f"\tNum examples = {train_examples_len}")
logger.info(f"\tBatch size = {args.batch_size}")
logger.info(f"\tMax sequence length = {args.max_length}")
logger.info(f"\tNum steps = {num_train_optimization_steps}")
logger.info(f"\tNum epochs = {args.num_train_epochs}")
logger.info(f"\tLearning rate = {args.learning_rate}")
train_input_ids = np.array(train_text_ids)
train_labels = np.array(train_labels)
@@ -219,6 +223,7 @@ def train_model(args):
if args.grad_accum_steps > 1:
loss = loss / args.grad_accum_steps
loss.backward()
return loss
for _ in tqdm.trange(int(args.num_train_epochs), desc="Epoch"):
prog_bar = tqdm.tqdm(train_dataloader, desc="Iteration")
@@ -231,20 +236,20 @@ def train_model(args):
input_ids = {
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
}
-logits = textattack.shared.utils.model_predict(model, input_ids)
-loss_fct = torch.nn.CrossEntropyLoss()
-loss = torch.nn.CrossEntropyLoss()(logits, labels)
+input_ids['labels'] = labels # @TODO change this back to calculate loss in this body
+loss = model(**input_ids)[0]
+loss = loss_backward(loss)
if global_step % args.tb_writer_step == 0:
-tb_writer.add_scalar("loss", loss, global_step)
-tb_writer.add_scalar("lr", loss, global_step)
-loss_backward(loss)
+tb_writer.add_scalar("loss", loss.item(), global_step)
+tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
prog_bar.set_description(f"Loss {loss.item()}")
if (step + 1) % args.grad_accum_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Save model checkpoint to file.
-if global_step > 0 and global_step % args.checkpoint_steps == 0:
+if global_step > 0 and (args.checkpoint_steps > 0) and (global_step % args.checkpoint_steps) == 0:
save_model_checkpoint()
model.zero_grad()
@@ -273,3 +278,10 @@ def train_model(args):
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
)
break
# end of training, save tokenizer
try:
tokenizer.save_pretrained(args.output_dir)
logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
except AttributeError:
logger.warn(f"Error: could not save tokenizer {tokenizer} to {args.output_dir}.")

View File

@@ -27,6 +27,11 @@ def dataset_from_args(args):
"""
dataset_args = args.dataset.split(":")
# TODO `HuggingFaceNLPDataset` -> `HuggingFaceDataset`
if args.dataset_train_split:
train_dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, split=args.dataset_train_split
)
else:
try:
train_dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, split="train"
@@ -35,9 +40,9 @@ def dataset_from_args(args):
raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
train_text, train_labels = prepare_dataset_for_training(train_dataset)
-if args.dataset_split:
+if args.dataset_dev_split:
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
-*dataset_args, split=args.dataset_split
+*dataset_args, split=args.dataset_dev_split
)
else:
# try common dev split names
@@ -61,7 +66,6 @@ def dataset_from_args(args):
)
eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
return train_text, train_labels, eval_text, eval_labels
@@ -94,7 +98,7 @@ def model_from_args(args, num_labels):
config=config,
)
tokenizer = textattack.models.tokenizers.AutoTokenizer(
-args.model, use_fast=False, max_length=args.max_length
+args.model, use_fast=True, max_length=args.max_length
)
setattr(model, "tokenizer", tokenizer)
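
The `# try common dev split names` comment above refers to a fallback that probes several split names in turn. A hypothetical standalone sketch of that idea, assuming a missing split surfaces as a `KeyError` (as in the train-split branch):

def first_available_split(load_split, candidates=("dev", "validation", "eval")):
    # `load_split` is a stand-in for a callable like
    # lambda s: textattack.datasets.HuggingFaceNLPDataset(*dataset_args, split=s)
    for split_name in candidates:
        try:
            return load_split(split_name)
        except KeyError:
            continue
    raise KeyError(f"no dev/validation/eval split found (tried {candidates})")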

View File

@@ -49,10 +49,17 @@ class TrainModelCommand(TextAttackCommand):
" ex: `glue:sst2` or `rotten_tomatoes`",
)
parser.add_argument(
"--dataset-split",
"--dataset-train-split", '--train-split',
type=str,
default='',
help="dataset split, if non-standard "
help="train dataset split, if non-standard "
"(can automatically detect 'train'",
)
parser.add_argument(
"--dataset-dev-split", '--dataset-val-split', '--dev-split',
type=str,
default='',
help="val dataset split, if non-standard "
"(can automatically detect 'dev', 'validation', 'eval')",
)
parser.add_argument(
@@ -64,10 +71,9 @@ class TrainModelCommand(TextAttackCommand):
parser.add_argument(
"--checkpoint-steps",
type=int,
-default=5000,
-help="save model after this many steps",
+default=-1,
+help="save model after this many steps (-1 for no checkpointing)",
)
parser.add_argument(
"--checkpoint-every_epoch",
action="store_true",
@@ -102,7 +108,6 @@ class TrainModelCommand(TextAttackCommand):
help="Maximum length of a sequence (anything beyond this will "
"be truncated)",
)
parser.add_argument(
"--learning-rate",
"--lr",
@@ -110,7 +115,6 @@ class TrainModelCommand(TextAttackCommand):
default=2e-5,
help="Learning rate for Adam Optimization",
)
parser.add_argument(
"--grad-accum-steps",
type=int,
@@ -118,33 +122,28 @@ class TrainModelCommand(TextAttackCommand):
help="Number of steps to accumulate gradients before optimizing, "
"advancing scheduler, etc.",
)
parser.add_argument(
"--warmup-proportion",
type=float,
default=0.1,
help="Warmup proportion for linear scheduling",
)
parser.add_argument(
"--config-name",
type=str,
default="config.json",
help="Filename to save BERT config as",
)
parser.add_argument(
"--weights-name",
type=str,
default="pytorch_model.bin",
help="Filename to save model weights as",
)
parser.add_argument(
"--enable-wandb",
default=False,
action="store_true",
help="log metrics to Weights & Biases",
)
parser.set_defaults(func=TrainModelCommand())
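
Note that the new aliased flags all resolve to a single destination derived from the first long option string, so `--dataset-dev-split`, `--dataset-val-split`, and `--dev-split` are interchangeable on the command line. A quick standalone check of that argparse behavior (toy parser, not the real one):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--dataset-dev-split", "--dataset-val-split", "--dev-split",
                    type=str, default="")
args = parser.parse_args(["--dev-split", "validation"])
print(args.dataset_dev_split)  # -> "validation"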

View File

@@ -2,6 +2,7 @@ import collections
import random
import nlp
import textattack
from textattack.datasets import TextAttackDataset
from textattack.shared import AttackedText
@@ -67,6 +68,9 @@ class HuggingFaceNLPDataset(TextAttackDataset):
dataset_columns=None,
shuffle=False,
):
subset_print_str = f", subset `{subset}`" if subset else ""
textattack.shared.logger.info(f"Loading `nlp` dataset `{name}`{subset_print_str}, split `{split}`.")
dataset = nlp.load_dataset(name, subset)
(
self.input_columns,

View File

@@ -28,6 +28,8 @@ class AutoTokenizer(Tokenizer):
name, use_fast=use_fast
)
self.max_length = max_length
print(f'AutoTokenizer using {"fast" if use_fast else "slow"} tokenizer')
self.save_pretrained = self.tokenizer.save_pretrained
def encode(self, input_text):
""" Encodes ``input_text``.
@@ -36,19 +38,24 @@ class AutoTokenizer(Tokenizer):
model takes 1 or multiple inputs. The ``transformers.AutoTokenizer``
will automatically handle either case.
"""
if isinstance(input_text, str):
input_text = (input_text, )
encoded_text = self.tokenizer.encode_plus(
-input_text,
+*input_text,
max_length=self.max_length,
add_special_tokens=True,
pad_to_max_length=True,
truncation=True
)
return dict(encoded_text)
-def encode_batch(self, input_text_list):
+def batch_encode(self, input_text_list):
""" The batch equivalent of ``encode``."""
-if hasattr(self.tokenizer, "encode_batch"):
-return self.tokenizer.encode_batch(
+if hasattr(self.tokenizer, "batch_encode_plus"):
+print('utilizing batch encode')
+return self.tokenizer.batch_encode_plus(
input_text_list,
truncation=True,
max_length=self.max_length,
add_special_tokens=True,
pad_to_max_length=True,
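
The tuple handling in `encode` above is what makes entailment-style inputs work: a `(premise, hypothesis)` pair is unpacked into `encode_plus`, which encodes it as a text pair with the proper separator tokens, while a single string is first wrapped into a one-element tuple. A standalone sketch of the difference (model name and texts are just examples):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
single = tokenizer.encode_plus("a sunny day", max_length=16, truncation=True)
pair = tokenizer.encode_plus("a sunny day", "it is raining",  # premise, hypothesis
                             max_length=16, truncation=True)
print(tokenizer.convert_ids_to_tokens(single["input_ids"]))  # [CLS] ... [SEP]
print(tokenizer.convert_ids_to_tokens(pair["input_ids"]))    # [CLS] ... [SEP] ... [SEP]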

View File

@@ -21,6 +21,8 @@ class SpacyTokenizer(Tokenizer):
def convert_text_to_tokens(self, text):
if isinstance(text, tuple):
if len(text) > 1:
raise TypeError('Cannot train LSTM/CNN models with multi-sequence inputs.')
text = text[0]
if not isinstance(text, str):
raise TypeError(