1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

support regression; write a little readme

This commit is contained in:
Jack Morris
2020-06-24 00:39:01 -04:00
parent 0d91781e88
commit 4014273fb0
6 changed files with 117 additions and 52 deletions

View File

@@ -16,8 +16,9 @@ torch
transformers>=2.5.1
tensorflow>=2
tensorflow_hub
tensorboardX
terminaltables
tqdm
visdom
wandb
flair
flair

View File

@@ -1,5 +1,5 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import collections
import textattack
import numpy as np
import re
@@ -58,9 +58,13 @@ class PeekDatasetCommand(TextAttackCommand):
logger.info('Last sample:')
print(attacked_texts[-1].printable_text(), '\n')
outputs = set(outputs)
logger.info(f'Found {len(outputs)} distinct outputs:')
print(sorted(outputs))
logger.info(f'Found {len(set(outputs))} distinct outputs.')
if len(outputs) < 20:
print(sorted(set(outputs)))
logger.info(f'Most common outputs:')
for i, (key, value) in enumerate(collections.Counter(outputs).most_common(20)):
print('\t', str(key)[:5].ljust(5), f' ({value})')
@staticmethod

View File

@@ -1,6 +1,7 @@
import json
import logging
import os
import scipy
import time
import numpy as np
@@ -11,7 +12,7 @@ import tqdm
import textattack
import transformers
from .train_args_helpers import dataset_from_args, model_from_args
from .train_args_helpers import dataset_from_args, model_from_args, write_readme
device = textattack.shared.utils.device
logger = textattack.shared.logger
@@ -31,34 +32,17 @@ def batch_encode(tokenizer, text_list):
def train_model(args):
logger.warn("WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues.")
start_time = time.time()
make_directories(args.output_dir)
num_gpus = torch.cuda.device_count()
# Start Tensorboard and log hyperparams.
from tensorboardX import SummaryWriter
tb_writer = SummaryWriter(args.output_dir)
def is_writable_type(obj):
    """Return True when ``obj`` is a bool, int, str or float, else False."""
    return isinstance(obj, (bool, int, str, float))
args_dict = {k: v for k,v in vars(args).items() if is_writable_type(v)}
tb_writer.add_hparams(args_dict, {})
# Save logger writes to file
log_txt_path = os.path.join(args.output_dir, "log.txt")
fh = logging.FileHandler(log_txt_path)
fh.setLevel(logging.DEBUG)
logger.info(f"Writing logs to {log_txt_path}.")
# Save args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(args_dict, indent=2) + "\n")
logger.info(f"Wrote training args to {args_save_path}.")
# Use Weights & Biases, if enabled.
if args.enable_wandb:
@@ -89,8 +73,16 @@ def train_model(args):
label_id_len = len(train_labels)
label_set = set(train_labels)
num_labels = len(label_set)
logger.info(f"Loaded dataset. Found: {num_labels} labels: ({sorted(label_set)})")
args.num_labels = len(label_set)
logger.info(f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})")
if isinstance(train_labels[0], float):
# TODO come up with a more sophisticated scheme for when to do regression
logger.warn(f"Detected float labels. Doing regression.")
args.num_labels = 1
args.do_regression = True
else:
args.do_regression = False
train_examples_len = len(train_text)
@@ -103,7 +95,7 @@ def train_model(args):
f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
)
model = model_from_args(args, num_labels)
model = model_from_args(args, args.num_labels)
tokenizer = model.tokenizer
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
@@ -112,7 +104,6 @@ def train_model(args):
eval_text_ids = batch_encode(tokenizer, eval_text)
load_time = time.time()
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
# print(model)
# multi-gpu training
if num_gpus > 1:
@@ -151,6 +142,26 @@ def train_model(args):
global_step = 0
# Start Tensorboard and log hyperparams.
from tensorboardX import SummaryWriter
tb_writer = SummaryWriter(args.output_dir)
def is_writable_type(obj):
    """Return True when ``obj`` is a bool, int, str or float, else False."""
    writable_types = (bool, int, str, float)
    return isinstance(obj, writable_types)
args_dict = {k: v for k,v in vars(args).items() if is_writable_type(v)}
tb_writer.add_hparams(args_dict, {})
# Save args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(args_dict, indent=2) + "\n")
logger.info(f"Wrote training args to {args_save_path}.")
# Start training
logger.info("***** Running training *****")
logger.info(f"\tNum examples = {train_examples_len}")
logger.info(f"\tBatch size = {args.batch_size}")
@@ -175,27 +186,38 @@ def train_model(args):
eval_data, sampler=eval_sampler, batch_size=args.batch_size
)
def get_eval_score():
    """Evaluate the model on the eval set and return a single score.

    Reads ``model``, ``eval_dataloader``, ``device`` and ``args`` from the
    enclosing scope. The model is switched to eval mode for the pass and
    back to train mode afterwards.

    Returns:
        The Pearson correlation between predictions and labels when
        ``args.do_regression`` is set; otherwise the accuracy, i.e. the
        fraction of eval examples whose argmax prediction equals the label.
    """
    model.eval()
    all_logits = []
    all_labels = []
    for input_ids, batch_labels in tqdm.tqdm(eval_dataloader, desc="Evaluating accuracy"):
        if isinstance(input_ids, dict):
            ## HACK: dataloader collates dict backwards. This is a temporary
            # workaround to get ids in the right shape
            input_ids = {
                k: torch.stack(v).T.to(device) for k, v in input_ids.items()
            }
        with torch.no_grad():
            batch_logits = textattack.shared.utils.model_predict(model, input_ids)
        # Keep whole per-batch tensors on the CPU. Calling `.squeeze()` per batch
        # would collapse a final batch of size 1 and corrupt the row layout.
        all_logits.append(batch_logits.detach().cpu())
        all_labels.append(batch_labels.detach().cpu())
    model.train()

    logits = torch.cat(all_logits)
    labels = torch.cat(all_labels)

    if args.do_regression:
        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.stats` is loaded.
        import scipy.stats

        pearson_correlation, _ = scipy.stats.pearsonr(
            logits.reshape(-1).numpy(), labels.reshape(-1).numpy()
        )
        return pearson_correlation

    preds = logits.argmax(dim=1)
    correct = (preds == labels).sum()
    # Divide by the number of examples; `len(eval_dataloader)` is the number
    # of batches and would inflate the accuracy by roughly the batch size.
    return float(correct) / len(labels)
def save_model():
model_to_save = (
@@ -213,7 +235,7 @@ def train_model(args):
# no config
pass
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
tqdm.tqdm.write(f"Best acc found. Saved model to {args.output_dir}.")
global_step = 0
@@ -226,11 +248,11 @@ def train_model(args):
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info(f"Checkpoint saved to {output_dir}.")
tqdm.tqdm.write(f"Checkpoint saved to {output_dir}.")
model.train()
best_eval_acc = 0
steps_since_best_eval_acc = 0
best_eval_score = 0
steps_since_best_eval_score = 0
def loss_backward(loss):
if num_gpus > 1:
@@ -253,8 +275,13 @@ def train_model(args):
}
logits = textattack.shared.utils.model_predict(model, input_ids)
loss_fct = torch.nn.CrossEntropyLoss()
loss = torch.nn.CrossEntropyLoss()(logits, labels)
if args.do_regression:
# TODO integrate with textattack `metrics` package
loss_fct = torch.nn.MSELoss()
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)
loss = loss_backward(loss)
if global_step % args.tb_writer_step == 0:
@@ -275,21 +302,21 @@ def train_model(args):
global_step += 1
# Check accuracy after each epoch.
eval_acc = get_eval_acc()
tb_writer.add_scalar("epoch_eval_acc", eval_acc, global_step)
eval_score = get_eval_score()
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
if args.checkpoint_every_epoch:
save_model_checkpoint()
logger.info(f"Eval acc: {eval_acc*100}%")
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
steps_since_best_eval_acc = 0
tqdm.tqdm.write(f"Eval acc: {eval_score*100}%")
if eval_score > best_eval_score:
best_eval_score = eval_score
steps_since_best_eval_score = 0
save_model()
else:
steps_since_best_eval_acc += 1
steps_since_best_eval_score += 1
if (args.early_stopping_epochs > 0) and (
steps_since_best_eval_acc > args.early_stopping_epochs
steps_since_best_eval_score > args.early_stopping_epochs
):
logger.info(
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
@@ -302,3 +329,7 @@ def train_model(args):
logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
except AttributeError:
logger.warn(f"Error: could not save tokenizer {tokenizer} to {args.output_dir}.")
# Save a little readme with model info
write_readme(args, best_eval_score)

View File

@@ -1,5 +1,7 @@
import os
import textattack
logger = textattack.shared.logger
def prepare_dataset_for_training(nlp_dataset):
""" Changes an `nlp` dataset into the proper format for tokenization. """
@@ -105,3 +107,30 @@ def model_from_args(args, num_labels):
model = model.to(textattack.shared.utils.device)
return model
def write_readme(args, best_eval_score):
    """Write a short ``README.md`` describing the trained model.

    The file is written to ``args.output_dir`` and summarizes the model,
    dataset, main hyperparameters and the best eval score.

    Args:
        args: Training arguments. Reads ``output_dir``, ``dataset``,
            ``model``, ``do_regression``, ``num_train_epochs``,
            ``batch_size``, ``learning_rate`` and ``max_length``.
        best_eval_score: Best eval-set score; described in the README as
            pearson correlation for regression and accuracy otherwise.
    """
    readme_save_path = os.path.join(args.output_dir, "README.md")
    # Keep only the part of the dataset spec before the first ':'.
    # `split` returns the whole string when no ':' is present.
    dataset_name = args.dataset.split(":")[0]
    if args.do_regression:
        task_name = "regression"
        loss_func = "mean squared error"
        metric_name = "pearson correlation"
    else:
        task_name = "classification"
        loss_func = "cross-entropy"
        metric_name = "accuracy"

    readme_text = f"""
## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
This `{args.model}` model was fine-tuned for sequence classification using TextAttack
and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
The best score the model achieved on this task was {best_eval_score}, as measured by the
eval set {metric_name}.
For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
"""

    with open(readme_save_path, "w", encoding="utf-8") as f:
        f.write(readme_text.strip() + "\n")
    logger.info(f"Wrote README to {readme_save_path}.")

View File

@@ -56,7 +56,7 @@ class TrainModelCommand(TextAttackCommand):
"(can automatically detect 'train'",
)
parser.add_argument(
"--dataset-dev-split", '--dataset-val-split', '--dev-split',
"--dataset-dev-split", '--dataset-eval-split', '--dev-split',
type=str,
default='',
help="val dataset split, if non-standard "

View File

@@ -72,7 +72,7 @@ class HuggingFaceNLPDataset(TextAttackDataset):
shuffle=False,
):
subset_print_str = f", subset `{_cb(subset)}`" if subset else ""
subset_print_str = f", subset {_cb(subset)}" if subset else ""
textattack.shared.logger.info(f"Loading {_cb('nlp')} dataset {_cb(name)}{subset_print_str}, split {_cb(split)}.")
dataset = nlp.load_dataset(name, subset)
(