mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
support regression; write a little readme
This commit is contained in:
@@ -16,8 +16,9 @@ torch
|
||||
transformers>=2.5.1
|
||||
tensorflow>=2
|
||||
tensorflow_hub
|
||||
tensorboardX
|
||||
terminaltables
|
||||
tqdm
|
||||
visdom
|
||||
wandb
|
||||
flair
|
||||
flair
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
|
||||
import collections
|
||||
import textattack
|
||||
import numpy as np
|
||||
import re
|
||||
@@ -58,9 +58,13 @@ class PeekDatasetCommand(TextAttackCommand):
|
||||
logger.info('Last sample:')
|
||||
print(attacked_texts[-1].printable_text(), '\n')
|
||||
|
||||
outputs = set(outputs)
|
||||
logger.info(f'Found {len(outputs)} distinct outputs:')
|
||||
print(sorted(outputs))
|
||||
logger.info(f'Found {len(set(outputs))} distinct outputs.')
|
||||
if len(outputs) < 20:
|
||||
print(sorted(set(outputs)))
|
||||
|
||||
logger.info(f'Most common outputs:')
|
||||
for i, (key, value) in enumerate(collections.Counter(outputs).most_common(20)):
|
||||
print('\t', str(key)[:5].ljust(5), f' ({value})')
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import scipy
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
@@ -11,7 +12,7 @@ import tqdm
|
||||
import textattack
|
||||
import transformers
|
||||
|
||||
from .train_args_helpers import dataset_from_args, model_from_args
|
||||
from .train_args_helpers import dataset_from_args, model_from_args, write_readme
|
||||
|
||||
device = textattack.shared.utils.device
|
||||
logger = textattack.shared.logger
|
||||
@@ -31,34 +32,17 @@ def batch_encode(tokenizer, text_list):
|
||||
|
||||
|
||||
def train_model(args):
|
||||
logger.warn("WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues.")
|
||||
start_time = time.time()
|
||||
make_directories(args.output_dir)
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
|
||||
# Start Tensorboard and log hyperparams.
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
tb_writer = SummaryWriter(args.output_dir)
|
||||
def is_writable_type(obj):
|
||||
for ok_type in [bool, int, str, float]:
|
||||
if isinstance(obj, ok_type): return True
|
||||
return False
|
||||
args_dict = {k: v for k,v in vars(args).items() if is_writable_type(v)}
|
||||
|
||||
tb_writer.add_hparams(args_dict, {})
|
||||
|
||||
# Save logger writes to file
|
||||
log_txt_path = os.path.join(args.output_dir, "log.txt")
|
||||
fh = logging.FileHandler(log_txt_path)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
logger.info(f"Writing logs to {log_txt_path}.")
|
||||
|
||||
# Save args to file
|
||||
args_save_path = os.path.join(args.output_dir, "train_args.json")
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(args_dict, indent=2) + "\n")
|
||||
logger.info(f"Wrote training args to {args_save_path}.")
|
||||
|
||||
# Use Weights & Biases, if enabled.
|
||||
if args.enable_wandb:
|
||||
@@ -89,8 +73,16 @@ def train_model(args):
|
||||
|
||||
label_id_len = len(train_labels)
|
||||
label_set = set(train_labels)
|
||||
num_labels = len(label_set)
|
||||
logger.info(f"Loaded dataset. Found: {num_labels} labels: ({sorted(label_set)})")
|
||||
args.num_labels = len(label_set)
|
||||
logger.info(f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})")
|
||||
|
||||
if isinstance(train_labels[0], float):
|
||||
# TODO come up with a more sophisticated scheme for when to do regression
|
||||
logger.warn(f"Detected float labels. Doing regression.")
|
||||
args.num_labels = 1
|
||||
args.do_regression = True
|
||||
else:
|
||||
args.do_regression = False
|
||||
|
||||
train_examples_len = len(train_text)
|
||||
|
||||
@@ -103,7 +95,7 @@ def train_model(args):
|
||||
f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
|
||||
)
|
||||
|
||||
model = model_from_args(args, num_labels)
|
||||
model = model_from_args(args, args.num_labels)
|
||||
tokenizer = model.tokenizer
|
||||
|
||||
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
|
||||
@@ -112,7 +104,6 @@ def train_model(args):
|
||||
eval_text_ids = batch_encode(tokenizer, eval_text)
|
||||
load_time = time.time()
|
||||
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
|
||||
# print(model)
|
||||
|
||||
# multi-gpu training
|
||||
if num_gpus > 1:
|
||||
@@ -151,6 +142,26 @@ def train_model(args):
|
||||
|
||||
global_step = 0
|
||||
|
||||
# Start Tensorboard and log hyperparams.
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
tb_writer = SummaryWriter(args.output_dir)
|
||||
def is_writable_type(obj):
|
||||
for ok_type in [bool, int, str, float]:
|
||||
if isinstance(obj, ok_type): return True
|
||||
return False
|
||||
args_dict = {k: v for k,v in vars(args).items() if is_writable_type(v)}
|
||||
|
||||
tb_writer.add_hparams(args_dict, {})
|
||||
|
||||
# Save args to file
|
||||
args_save_path = os.path.join(args.output_dir, "train_args.json")
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(args_dict, indent=2) + "\n")
|
||||
logger.info(f"Wrote training args to {args_save_path}.")
|
||||
|
||||
|
||||
# Start training
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f"\tNum examples = {train_examples_len}")
|
||||
logger.info(f"\tBatch size = {args.batch_size}")
|
||||
@@ -175,27 +186,38 @@ def train_model(args):
|
||||
eval_data, sampler=eval_sampler, batch_size=args.batch_size
|
||||
)
|
||||
|
||||
def get_eval_acc():
|
||||
def get_eval_score():
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for input_ids, labels in tqdm.tqdm(eval_dataloader, desc="Evaluating accuracy"):
|
||||
logits = []
|
||||
labels = []
|
||||
for input_ids, batch_labels in tqdm.tqdm(eval_dataloader, desc="Evaluating accuracy"):
|
||||
if isinstance(input_ids, dict):
|
||||
## HACK: dataloader collates dict backwards. This is a temporary
|
||||
# workaround to get ids in the right shape
|
||||
input_ids = {
|
||||
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
|
||||
}
|
||||
labels = labels.to(device)
|
||||
batch_labels = batch_labels.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = textattack.shared.utils.model_predict(model, input_ids)
|
||||
|
||||
correct += (logits.argmax(dim=1) == labels).sum()
|
||||
total += len(labels)
|
||||
batch_logits = textattack.shared.utils.model_predict(model, input_ids)
|
||||
|
||||
logits.extend(batch_logits.cpu().squeeze().tolist())
|
||||
labels.extend(batch_labels)
|
||||
|
||||
model.train()
|
||||
return float(correct) / total
|
||||
logits = torch.tensor(logits)
|
||||
labels = torch.tensor(labels)
|
||||
|
||||
if args.do_regression:
|
||||
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
|
||||
return pearson_correlation
|
||||
else:
|
||||
preds = logits.argmax(dim=1)
|
||||
correct = (preds == labels).sum()
|
||||
return float(correct) / len(eval_dataloader)
|
||||
|
||||
def save_model():
|
||||
model_to_save = (
|
||||
@@ -213,7 +235,7 @@ def train_model(args):
|
||||
# no config
|
||||
pass
|
||||
|
||||
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
|
||||
tqdm.tqdm.write(f"Best acc found. Saved model to {args.output_dir}.")
|
||||
|
||||
global_step = 0
|
||||
|
||||
@@ -226,11 +248,11 @@ def train_model(args):
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info(f"Checkpoint saved to {output_dir}.")
|
||||
tqdm.tqdm.write(f"Checkpoint saved to {output_dir}.")
|
||||
|
||||
model.train()
|
||||
best_eval_acc = 0
|
||||
steps_since_best_eval_acc = 0
|
||||
best_eval_score = 0
|
||||
steps_since_best_eval_score = 0
|
||||
|
||||
def loss_backward(loss):
|
||||
if num_gpus > 1:
|
||||
@@ -253,8 +275,13 @@ def train_model(args):
|
||||
}
|
||||
logits = textattack.shared.utils.model_predict(model, input_ids)
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
loss = torch.nn.CrossEntropyLoss()(logits, labels)
|
||||
if args.do_regression:
|
||||
# TODO integrate with textattack `metrics` package
|
||||
loss_fct = torch.nn.MSELoss()
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
loss = loss_backward(loss)
|
||||
|
||||
if global_step % args.tb_writer_step == 0:
|
||||
@@ -275,21 +302,21 @@ def train_model(args):
|
||||
global_step += 1
|
||||
|
||||
# Check accuracy after each epoch.
|
||||
eval_acc = get_eval_acc()
|
||||
tb_writer.add_scalar("epoch_eval_acc", eval_acc, global_step)
|
||||
eval_score = get_eval_score()
|
||||
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
|
||||
|
||||
if args.checkpoint_every_epoch:
|
||||
save_model_checkpoint()
|
||||
|
||||
logger.info(f"Eval acc: {eval_acc*100}%")
|
||||
if eval_acc > best_eval_acc:
|
||||
best_eval_acc = eval_acc
|
||||
steps_since_best_eval_acc = 0
|
||||
tqdm.tqdm.write(f"Eval acc: {eval_score*100}%")
|
||||
if eval_score > best_eval_score:
|
||||
best_eval_score = eval_score
|
||||
steps_since_best_eval_score = 0
|
||||
save_model()
|
||||
else:
|
||||
steps_since_best_eval_acc += 1
|
||||
steps_since_best_eval_score += 1
|
||||
if (args.early_stopping_epochs > 0) and (
|
||||
steps_since_best_eval_acc > args.early_stopping_epochs
|
||||
steps_since_best_eval_score > args.early_stopping_epochs
|
||||
):
|
||||
logger.info(
|
||||
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
|
||||
@@ -302,3 +329,7 @@ def train_model(args):
|
||||
logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
|
||||
except AttributeError:
|
||||
logger.warn(f"Error: could not save tokenizer {tokenizer} to {args.output_dir}.")
|
||||
|
||||
# Save a little readme with model info
|
||||
write_readme(args, best_eval_score)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import textattack
|
||||
|
||||
logger = textattack.shared.logger
|
||||
|
||||
def prepare_dataset_for_training(nlp_dataset):
|
||||
""" Changes an `nlp` dataset into the proper format for tokenization. """
|
||||
@@ -105,3 +107,30 @@ def model_from_args(args, num_labels):
|
||||
model = model.to(textattack.shared.utils.device)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def write_readme(args, best_eval_score):
|
||||
# Save args to file
|
||||
readme_save_path = os.path.join(args.output_dir, "README.md")
|
||||
dataset_name = args.dataset.split(':')[0] if ':' in args.dataset else args.dataset
|
||||
task_name = "regression" if args.do_regression else "classification"
|
||||
loss_func = "mean squared error" if args.do_regression else "cross-entropy"
|
||||
metric_name = "pearson correlation" if args.do_regression else "accuracy"
|
||||
readme_text = f"""
|
||||
## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
|
||||
|
||||
This `{args.model}` model was fine-tuned for sequence classificationusing TextAttack
|
||||
and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
|
||||
for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
|
||||
rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
|
||||
Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
|
||||
The best score the model achieved on this task was {best_eval_score}, as measured by the
|
||||
eval set {metric_name}.
|
||||
|
||||
For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
|
||||
|
||||
"""
|
||||
with open(readme_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(readme_text.strip() + "\n")
|
||||
logger.info(f"Wrote README to {readme_save_path}.")
|
||||
|
||||
@@ -56,7 +56,7 @@ class TrainModelCommand(TextAttackCommand):
|
||||
"(can automatically detect 'train'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-dev-split", '--dataset-val-split', '--dev-split',
|
||||
"--dataset-dev-split", '--dataset-eval-split', '--dev-split',
|
||||
type=str,
|
||||
default='',
|
||||
help="val dataset split, if non-standard "
|
||||
|
||||
@@ -72,7 +72,7 @@ class HuggingFaceNLPDataset(TextAttackDataset):
|
||||
shuffle=False,
|
||||
):
|
||||
|
||||
subset_print_str = f", subset `{_cb(subset)}`" if subset else ""
|
||||
subset_print_str = f", subset {_cb(subset)}" if subset else ""
|
||||
textattack.shared.logger.info(f"Loading {_cb('nlp')} dataset {_cb(name)}{subset_print_str}, split {_cb(split)}.")
|
||||
dataset = nlp.load_dataset(name, subset)
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user