Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
entailment works; add peek-dataset command
4 binary files changed in this commit (contents not shown).
@@ -333,6 +333,8 @@ def parse_dataset_from_args(args):
        dataset_args = args.dataset_from_nlp
        if ":" in dataset_args:
            dataset_args = dataset_args.split(":")
        else:
            dataset_args = (dataset_args,)
        dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, shuffle=args.shuffle
        )
textattack/commands/eval_model/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .eval_model_command import EvalModelCommand
textattack/commands/peek_dataset.py (Normal file, 73 lines)
@@ -0,0 +1,73 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import textattack
import numpy as np
import re

from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args_helpers import add_dataset_args, parse_dataset_from_args


def _cb(s):
    return textattack.shared.utils.color_text(str(s), color="red", method="ansi")


logger = textattack.shared.logger


class PeekDatasetCommand(TextAttackCommand):
    """
    The peek dataset module:

    Takes a peek into a dataset in TextAttack and prints summary statistics about it.
    """

    def run(self, args):
        UPPERCASE_LETTERS_REGEX = re.compile('[A-Z]')

        args.model = None  # set model to None for parse_dataset_from_args to work
        dataset = parse_dataset_from_args(args)

        num_words = []
        attacked_texts = []
        data_all_lowercased = True
        outputs = []
        for inputs, output in dataset:
            at = textattack.shared.AttackedText(inputs)
            if data_all_lowercased:
                # Test if any of the letters in the string are uppercase.
                if re.search(UPPERCASE_LETTERS_REGEX, at.text):
                    data_all_lowercased = False
            attacked_texts.append(at)
            num_words.append(len(at.words))
            outputs.append(output)

        logger.info(f'Number of samples: {_cb(len(attacked_texts))}')
        logger.info('Number of words per input:')
        num_words = np.array(num_words)
        logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
        mean_words = f'{num_words.mean():.2f}'
        logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
        std_words = f'{num_words.std():.2f}'
        logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
        logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
        logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
        logger.info(f'Dataset lowercased: {_cb(data_all_lowercased)}')

        logger.info('First sample:')
        print(attacked_texts[0].printable_text(), '\n')
        logger.info('Last sample:')
        print(attacked_texts[-1].printable_text(), '\n')

        outputs = set(outputs)
        logger.info(f'Found {len(outputs)} distinct outputs:')
        print(sorted(outputs))

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "peek-dataset", help="show main statistics about a dataset",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        add_dataset_args(parser)
        parser.set_defaults(func=PeekDatasetCommand())
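The command registered above is invoked through the TextAttack CLI as `peek-dataset`, with the usual dataset flags added by `add_dataset_args`. For reference only, a minimal hand-written Python sketch of roughly what the command computes; the dataset name "rotten_tomatoes" is just an example (it appears later in this commit's help text) and not something the command requires:

import numpy as np
import textattack

# Load a dataset the same way parse_dataset_from_args does for --dataset-from-nlp.
dataset = textattack.datasets.HuggingFaceNLPDataset("rotten_tomatoes", split="train")

num_words = []
labels = []
for inputs, output in dataset:
    at = textattack.shared.AttackedText(inputs)
    num_words.append(len(at.words))
    labels.append(output)

num_words = np.array(num_words)
print(f"samples: {len(num_words)}, mean words per input: {num_words.mean():.2f}")
print(f"distinct outputs: {sorted(set(labels))}")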
@@ -8,9 +8,9 @@ from textattack.commands.augment import AugmentCommand
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
from textattack.commands.eval_model import EvalModelCommand
from textattack.commands.list_things import ListThingsCommand
from textattack.commands.peek_dataset import PeekDatasetCommand
from textattack.commands.train_model import TrainModelCommand


def main():
    parser = argparse.ArgumentParser(
        "TextAttack CLI",
@@ -27,6 +27,7 @@ def main():
    EvalModelCommand.register_subcommand(subparsers)
    ListThingsCommand.register_subcommand(subparsers)
    TrainModelCommand.register_subcommand(subparsers)
    PeekDatasetCommand.register_subcommand(subparsers)

    # Let's go
    args = parser.parse_args()
@@ -7,6 +7,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
import tqdm

import textattack
import transformers

from .train_args_helpers import dataset_from_args, model_from_args

@@ -19,16 +20,15 @@ def make_directories(output_dir):
        os.makedirs(output_dir)


def encode_batch(tokenizer, text_list):
    try:
        return tokenizer.encode_batch(text_list)
    except AttributeError:
        return [tokenizer.encode(text) for text in text_list]
def batch_encode(tokenizer, text_list):
    return [tokenizer.encode(text_input) for text_input in text_list]
    # try:
    #     return tokenizer.batch_encode(text_list)  # TODO get batch encoding to work with fast tokenizer
    # except AttributeError:
    #     return [tokenizer.encode(text_input) for text_input in text_list]


def train_model(args):
    from transformers.optimization import AdamW, get_linear_schedule_with_warmup

    start_time = time.time()
    make_directories(args.output_dir)
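As the commented-out TODO above notes, the intent is to switch this helper back to true batch encoding once it works with the fast tokenizer. A hedged sketch of what that could look like, assuming the `batch_encode` method that this same commit adds to `AutoTokenizer` (this is not what the commit currently ships):

def batch_encode(tokenizer, text_list):
    # Use a real batch call when the wrapped tokenizer provides one ...
    if hasattr(tokenizer, "batch_encode"):
        return tokenizer.batch_encode(text_list)
    # ... otherwise fall back to encoding one example at a time.
    return [tokenizer.encode(text_input) for text_input in text_list]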
@@ -38,8 +38,12 @@ def train_model(args):
        from tensorboardX import SummaryWriter

        tb_writer = SummaryWriter(args.output_dir)
        args_dict = vars(args)
        del args_dict["func"]
        def is_writable_type(obj):
            for ok_type in [bool, int, str, float]:
                if isinstance(obj, ok_type): return True
            return False
        args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}

        tb_writer.add_hparams(args_dict, {})

    # Use Weights & Biases, if enabled.
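The filtering above replaces the old `del args_dict["func"]` because TensorBoard hparams logging only handles scalar-like values (bools, ints, floats, strings); anything else in `vars(args)`, such as the `func` callback set by argparse, must be dropped first. A standalone sketch of the same idea (the argument values here are made up for illustration):

from tensorboardX import SummaryWriter

def is_writable_type(obj):
    return isinstance(obj, (bool, int, str, float))

fake_args = {"learning_rate": 2e-5, "batch_size": 32, "func": print}  # "func" is not writable
hparams = {k: v for k, v in fake_args.items() if is_writable_type(v)}

tb_writer = SummaryWriter("example_output_dir")
tb_writer.add_hparams(hparams, {})  # only the scalar/string values are logged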
@@ -88,9 +92,9 @@ def train_model(args):
    model = model_from_args(args, num_labels)

    logger.info(f"Tokenizing training data. (len: {train_examples_len})")
    train_text_ids = encode_batch(model.tokenizer, train_text)
    logger.info(f"Tokenizing test data (len: {len(eval_labels)})")
    eval_text_ids = encode_batch(model.tokenizer, eval_text)
    train_text_ids = batch_encode(model.tokenizer, train_text)
    logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
    eval_text_ids = batch_encode(model.tokenizer, eval_text)
    load_time = time.time()
    logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
    # print(model)
@@ -122,9 +126,9 @@ def train_model(args):
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    optimizer = transformers.optimization.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    scheduler = get_linear_schedule_with_warmup(
    scheduler = transformers.optimization.get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_proportion,
        num_training_steps=num_train_optimization_steps,
@@ -133,12 +137,12 @@ def train_model(args):
    global_step = 0

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", train_examples_len)
    logger.info(" Batch size = %d", args.batch_size)
    logger.info(" Max sequence length = %d", args.max_length)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    logger.info(" Num epochs = %d", args.num_train_epochs)
    logger.info(" Learning rate = %d", args.learning_rate)
    logger.info(f"\tNum examples = {train_examples_len}")
    logger.info(f"\tBatch size = {args.batch_size}")
    logger.info(f"\tMax sequence length = {args.max_length}")
    logger.info(f"\tNum steps = {num_train_optimization_steps}")
    logger.info(f"\tNum epochs = {args.num_train_epochs}")
    logger.info(f"\tLearning rate = {args.learning_rate}")

    train_input_ids = np.array(train_text_ids)
    train_labels = np.array(train_labels)
@@ -219,6 +223,7 @@ def train_model(args):
        if args.grad_accum_steps > 1:
            loss = loss / args.grad_accum_steps
        loss.backward()
        return loss

    for _ in tqdm.trange(int(args.num_train_epochs), desc="Epoch"):
        prog_bar = tqdm.tqdm(train_dataloader, desc="Iteration")
@@ -231,20 +236,20 @@ def train_model(args):
            input_ids = {
                k: torch.stack(v).T.to(device) for k, v in input_ids.items()
            }
            logits = textattack.shared.utils.model_predict(model, input_ids)
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = torch.nn.CrossEntropyLoss()(logits, labels)
            input_ids['labels'] = labels  # @TODO change this back to calculate loss in this body
            loss = model(**input_ids)[0]
            loss = loss_backward(loss)

            if global_step % args.tb_writer_step == 0:
                tb_writer.add_scalar("loss", loss, global_step)
                tb_writer.add_scalar("lr", loss, global_step)
            loss_backward(loss)
                tb_writer.add_scalar("loss", loss.item(), global_step)
                tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
            prog_bar.set_description(f"Loss {loss.item()}")
            if (step + 1) % args.grad_accum_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            # Save model checkpoint to file.
            if global_step > 0 and global_step % args.checkpoint_steps == 0:
            if global_step > 0 and (args.checkpoint_steps > 0) and (global_step % args.checkpoint_steps) == 0:
                save_model_checkpoint()

        model.zero_grad()
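The switch from computing `CrossEntropyLoss` on the logits to `loss = model(**input_ids)[0]` relies on the `transformers` convention that, when `labels` are included in the inputs, a sequence-classification model's forward pass returns the loss as its first output. A minimal standalone sketch of that pattern (model name and inputs are illustrative only, and it assumes a recent `transformers` release where tokenizers are callable):

import torch
import transformers

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["a great movie", "a terrible movie"], padding=True, return_tensors="pt")
inputs["labels"] = torch.tensor([1, 0])

loss = model(**inputs)[0]  # the loss comes first when labels are supplied
loss.backward()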
@@ -273,3 +278,10 @@ def train_model(args):
                f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
            )
            break

    # end of training, save tokenizer
    try:
        tokenizer.save_pretrained(args.output_dir)
        logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
    except AttributeError:
        logger.warn(f"Error: could not save tokenizer {tokenizer} to {args.output_dir}.")
@@ -27,17 +27,22 @@ def dataset_from_args(args):
    """
    dataset_args = args.dataset.split(":")
    # TODO `HuggingFaceNLPDataset` -> `HuggingFaceDataset`
    try:
        train_dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, split="train"
        )
    except KeyError:
        raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
    if args.dataset_train_split:
        train_dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, split=args.dataset_train_split
        )
    else:
        try:
            train_dataset = textattack.datasets.HuggingFaceNLPDataset(
                *dataset_args, split="train"
            )
        except KeyError:
            raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
    train_text, train_labels = prepare_dataset_for_training(train_dataset)

    if args.dataset_split:
    if args.dataset_dev_split:
        eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, split=args.dataset_split
            *dataset_args, split=args.dataset_dev_split
        )
    else:
        # try common dev split names
@@ -60,7 +65,6 @@ def dataset_from_args(args):
                    f"Could not find `dev` or `test` split in dataset {args.dataset}."
                )
    eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)

    return train_text, train_labels, eval_text, eval_labels
@@ -94,7 +98,7 @@ def model_from_args(args, num_labels):
        config=config,
    )
    tokenizer = textattack.models.tokenizers.AutoTokenizer(
        args.model, use_fast=False, max_length=args.max_length
        args.model, use_fast=True, max_length=args.max_length
    )
    setattr(model, "tokenizer", tokenizer)
@@ -49,10 +49,17 @@ class TrainModelCommand(TextAttackCommand):
            " ex: `glue:sst2` or `rotten_tomatoes`",
        )
        parser.add_argument(
            "--dataset-split",
            "--dataset-train-split", '--train-split',
            type=str,
            default='',
            help="dataset split, if non-standard "
            help="train dataset split, if non-standard "
            "(can automatically detect 'train')",
        )
        parser.add_argument(
            "--dataset-dev-split", '--dataset-val-split', '--dev-split',
            type=str,
            default='',
            help="val dataset split, if non-standard "
            "(can automatically detect 'dev', 'validation', 'eval')",
        )
        parser.add_argument(
@@ -64,10 +71,9 @@ class TrainModelCommand(TextAttackCommand):
        parser.add_argument(
            "--checkpoint-steps",
            type=int,
            default=5000,
            help="save model after this many steps",
            default=-1,
            help="save model after this many steps (-1 for no checkpointing)",
        )

        parser.add_argument(
            "--checkpoint-every_epoch",
            action="store_true",
@@ -102,7 +108,6 @@ class TrainModelCommand(TextAttackCommand):
            help="Maximum length of a sequence (anything beyond this will "
            "be truncated)",
        )

        parser.add_argument(
            "--learning-rate",
            "--lr",
@@ -110,7 +115,6 @@ class TrainModelCommand(TextAttackCommand):
            default=2e-5,
            help="Learning rate for Adam Optimization",
        )

        parser.add_argument(
            "--grad-accum-steps",
            type=int,
@@ -118,33 +122,28 @@ class TrainModelCommand(TextAttackCommand):
            help="Number of steps to accumulate gradients before optimizing, "
            "advancing scheduler, etc.",
        )

        parser.add_argument(
            "--warmup-proportion",
            type=float,
            default=0.1,
            help="Warmup proportion for linear scheduling",
        )

        parser.add_argument(
            "--config-name",
            type=str,
            default="config.json",
            help="Filename to save BERT config as",
        )

        parser.add_argument(
            "--weights-name",
            type=str,
            default="pytorch_model.bin",
            help="Filename to save model weights as",
        )

        parser.add_argument(
            "--enable-wandb",
            default=False,
            action="store_true",
            help="log metrics to Weights & Biases",
        )

        parser.set_defaults(func=TrainModelCommand())
@@ -2,6 +2,7 @@ import collections
import random

import nlp
import textattack

from textattack.datasets import TextAttackDataset
from textattack.shared import AttackedText
@@ -67,6 +68,9 @@ class HuggingFaceNLPDataset(TextAttackDataset):
        dataset_columns=None,
        shuffle=False,
    ):
        subset_print_str = f", subset `{subset}`" if subset else ""
        textattack.shared.logger.info(f"Loading `nlp` dataset `{name}`{subset_print_str}, split `{split}`.")
        dataset = nlp.load_dataset(name, subset)
        (
            self.input_columns,
@@ -28,6 +28,8 @@ class AutoTokenizer(Tokenizer):
            name, use_fast=use_fast
        )
        self.max_length = max_length
        print(f'AutoTokenizer using {"fast" if use_fast else "slow"} tokenizer')
        self.save_pretrained = self.tokenizer.save_pretrained

    def encode(self, input_text):
        """ Encodes ``input_text``.
@@ -36,19 +38,24 @@ class AutoTokenizer(Tokenizer):
        model takes 1 or multiple inputs. The ``transformers.AutoTokenizer``
        will automatically handle either case.
        """
        if isinstance(input_text, str):
            input_text = (input_text, )
        encoded_text = self.tokenizer.encode_plus(
            input_text,
            *input_text,
            max_length=self.max_length,
            add_special_tokens=True,
            pad_to_max_length=True,
            truncation=True
        )
        return dict(encoded_text)

    def encode_batch(self, input_text_list):
    def batch_encode(self, input_text_list):
        """ The batch equivalent of ``encode``."""
        if hasattr(self.tokenizer, "encode_batch"):
            return self.tokenizer.encode_batch(
        if hasattr(self.tokenizer, "batch_encode_plus"):
            print('utilizing batch encode')
            return self.tokenizer.batch_encode_plus(
                input_text_list,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=True,
                pad_to_max_length=True,
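The `*input_text` unpacking above is what makes entailment inputs work (per the commit message): a `(premise, hypothesis)` tuple becomes the `(text, text_pair)` arguments of `encode_plus`, so the pair is encoded as one joint sequence. A minimal sketch of the idea with a plain `transformers` tokenizer (model name and sentences are illustrative only):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

premise = "A soccer game with multiple males playing."
hypothesis = "Some men are playing a sport."

# A 1-tuple unpacks to a single text argument.
single = tokenizer.encode_plus(*("a great movie",), add_special_tokens=True)

# A 2-tuple unpacks to (text, text_pair): one joint sequence such as
# [CLS] premise [SEP] hypothesis [SEP] for BERT-style models.
pair = tokenizer.encode_plus(*(premise, hypothesis), add_special_tokens=True)

print(tokenizer.decode(pair["input_ids"]))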
@@ -21,6 +21,8 @@ class SpacyTokenizer(Tokenizer):

    def convert_text_to_tokens(self, text):
        if isinstance(text, tuple):
            if len(text) > 1:
                raise TypeError('Cannot train LSTM/CNN models with multi-sequence inputs.')
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(