1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

adding AG News models

This commit is contained in:
Jack Morris
2020-07-07 20:45:51 -04:00
parent ad2dba067e
commit c0a3a734d8
7 changed files with 109 additions and 19 deletions

View File

@@ -24,6 +24,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
#
# bert-base-uncased
#
"bert-base-uncased-ag-news": (
"textattack/bert-base-uncased-ag-news",
("ag_news", None, "test"),
),
"bert-base-uncased-cola": (
"textattack/bert-base-uncased-CoLA",
("glue", "cola", "validation"),
@@ -106,6 +110,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
#
# distilbert-base-uncased
#
"distilbert-base-uncased-ag-news": (
"textattack/distilbert-base-uncased-ag-news",
("ag_news", None, "test"),
),
"distilbert-base-uncased-cola": (
"textattack/distilbert-base-cased-CoLA",
("glue", "cola", "validation"),
@@ -137,11 +145,18 @@ HUGGINGFACE_DATASET_BY_MODEL = {
#
# roberta-base (RoBERTa is cased by default)
#
"roberta-base-ag-news": (
"textattack/roberta-base-ag-news",
("ag_news", None, "test"),
),
"roberta-base-cola": (
"textattack/roberta-base-CoLA",
("glue", "cola", "validation"),
),
"roberta-base-imdb": ("textattack/roberta-base-imdb", ("imdb", None, "test"),),
"roberta-base-imdb":
("textattack/roberta-base-imdb",
("imdb", None, "test"),
),
"roberta-base-mr": (
"textattack/textattack/roberta-base-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
@@ -154,7 +169,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
"textattack/roberta-base-QNLI",
("glue", "qnli", "validation"),
),
"roberta-base-rte": ("textattack/roberta-base-RTE", ("glue", "rte", "validation")),
"roberta-base-rte":
("textattack/roberta-base-RTE",
("glue", "rte", "validation")
),
"roberta-base-sst2": (
"textattack/roberta-base-SST-2",
("glue", "sst2", "validation"),
@@ -174,11 +192,18 @@ HUGGINGFACE_DATASET_BY_MODEL = {
#
# albert-base-v2 (ALBERT is cased by default)
#
"albert-base-v2-ag-news": (
"textattack/albert-base-v2-ag-news",
("ag_news", None, "test"),
),
"albert-base-v2-cola": (
"textattack/albert-base-v2-CoLA",
("glue", "cola", "validation"),
),
"albert-base-v2-imdb": ("textattack/albert-base-v2-imdb", ("imdb", None, "test"),),
"albert-base-v2-imdb":
("textattack/albert-base-v2-imdb",
("imdb", None, "test"),
),
"albert-base-v2-mr": (
"textattack/albert-base-v2-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
@@ -191,7 +216,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
"textattack/albert-base-v2-QQP",
("glue", "qqp", "validation"),
),
"albert-base-v2-snli": ("textattack/albert-base-v2-snli", ("snli", None, "test"),),
"albert-base-v2-snli":
("textattack/albert-base-v2-snli",
("snli", None, "test"),
),
"albert-base-v2-sst2": (
"textattack/albert-base-v2-SST-2",
("glue", "sst2", "validation"),
@@ -246,24 +274,52 @@ HUGGINGFACE_DATASET_BY_MODEL = {
# Models hosted by textattack.
#
TEXTATTACK_DATASET_BY_MODEL = {
# @TODO restore ag-news models after agnews joins `nlp` as a dataset.
#
# LSTMs
#
"lstm-sst2": ("models/classification/lstm/sst2", ("glue", "sst2", "validation")),
"lstm-yelp": ("models/classification/lstm/yelp", ("yelp_polarity", None, "test"),),
"lstm-imdb": ("models/classification/lstm/imdb", ("imdb", None, "test")),
"lstm-mr": ("models/classification/lstm/mr", ("rotten_tomatoes", None, "test"),),
"lstm-ag-news": (
"models/classification/lstm/ag-news",
("ag_news", None, "test"),
),
"lstm-imdb": (
"models/classification/lstm/imdb",
("imdb", None, "test")
),
"lstm-mr": (
"models/classification/lstm/mr",
("rotten_tomatoes", None, "test"),
),
"lstm-sst2": (
"models/classification/lstm/sst2",
("glue", "sst2", "validation")
),
"lstm-yelp": (
"models/classification/lstm/yelp",
("yelp_polarity", None, "test"),
),
#
# CNNs
#
"cnn-sst2": ("models/classification/cnn/sst", ("glue", "sst2", "validation")),
"cnn-imdb": ("models/classification/cnn/imdb", ("imdb", None, "test")),
"cnn-yelp": ("models/classification/cnn/yelp", ("yelp_polarity", None, "test"),),
"cnn-ag-news": (
"models/classification/cnn/ag-news",
("ag_news", None, "test"),
),
"cnn-imdb": (
"models/classification/cnn/imdb",
("imdb", None, "test")
),
"cnn-mr": (
"models/classification/cnn/rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
"cnn-sst2": (
"models/classification/cnn/sst",
("glue", "sst2", "validation")
),
"cnn-yelp": (
"models/classification/cnn/yelp",
("yelp_polarity", None, "test"),
),
#
# T5 for translation
#
@@ -282,7 +338,10 @@ TEXTATTACK_DATASET_BY_MODEL = {
#
# T5 for summarization
#
"t5-summarization": ("summarization", ("gigaword", None, "test")),
"t5-summarization":
("summarization",
("gigaword", None, "test")
),
}
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {

View File

@@ -314,6 +314,8 @@ def parse_model_from_args(args):
f"Tried to load model from path {args.model} - could not find train_args.json."
)
model_train_args = json.loads(open(model_args_json_path).read())
if model_train_args["model"] not in {"cnn", "lstm"}:
# for huggingface models, set args.model to the path of the model
model_train_args["model"] = args.model
num_labels = model_train_args["num_labels"]
from textattack.commands.train_model.train_args_helpers import (

View File

@@ -25,6 +25,7 @@ class EvalModelCommand(TextAttackCommand):
def get_preds(self, model, inputs):
    """Return `model`'s predictions for `inputs`.

    Runs under ``torch.no_grad()`` since evaluation needs no gradient
    tracking (saves memory and compute).

    Args:
        model: the model to evaluate.
        inputs: batch of tokenized inputs accepted by
            ``textattack.shared.utils.model_predict``.

    Returns:
        The raw prediction tensor from ``model_predict``.
    """
    # NOTE(review): removed a leftover `breakpoint()` debug call that
    # dropped into the debugger on every prediction batch.
    with torch.no_grad():
        preds = textattack.shared.utils.model_predict(model, inputs)
    return preds
def test_model_on_dataset(self, args):
@@ -65,6 +66,9 @@ class EvalModelCommand(TextAttackCommand):
perc_accuracy = "{:.2f}%".format(perc_accuracy)
logger.info(f"Successes {successes}/{len(preds)} ({_cb(perc_accuracy)})")
import collections
print(collections.Counter(guess_labels.tolist()))
def run(self, args):
# Default to 'all' if no model chosen.
if not (args.model or args.model_from_huggingface or args.model_from_file):

View File

@@ -86,7 +86,7 @@ def train_model(args):
)
if isinstance(train_labels[0], float):
# TODO come up with a more sophisticated scheme for when to do regression
# TODO come up with a more sophisticated scheme for knowing when to do regression
logger.warn(f"Detected float labels. Doing regression.")
args.num_labels = 1
args.do_regression = True
@@ -117,6 +117,7 @@ def train_model(args):
# multi-gpu training
if num_gpus > 1:
model = torch.nn.DataParallel(model)
logger.info("Using torch.nn.DataParallel.")
logger.info(f"Training model across {num_gpus} GPUs")
num_train_optimization_steps = (
@@ -171,6 +172,12 @@ def train_model(args):
args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
# Save original args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(args_dict, indent=2) + "\n")
logger.info(f"Wrote original training args to {args_save_path}.")
tb_writer.add_hparams(args_dict, {})
# Start training
@@ -356,6 +363,8 @@ def train_model(args):
break
# read the saved model and report its eval performance
logger.info("Finished training. Re-loading and evaluating model from disk.")
model = model_from_args(args, args.num_labels)
model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
eval_score = get_eval_score()
logger.info(
@@ -375,8 +384,7 @@ def train_model(args):
write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
# Save args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(final_args_dict, indent=2) + "\n")
logger.info(f"Wrote training args to {args_save_path}.")
logger.info(f"Wrote final training args to {args_save_path}.")

View File

@@ -71,7 +71,7 @@ def dataset_from_args(args):
args.dataset_dev_split = "validation"
except KeyError:
try:
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
eval_dataset = textattack.datasets.HuggingFaceNlpDataset(
*dataset_args, split="test"
)
args.dataset_dev_split = "test"

View File

@@ -66,6 +66,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
<section>
- AG News (`albert-base-v2-ag-news`)
- nlp dataset `ag_news`, split `test`
- Successes: 943/1000
- Accuracy: 94.30%
- CoLA (`albert-base-v2-cola`)
- nlp dataset `glue`, subset `cola`, split `validation`
- Successes: 829/1000
@@ -113,6 +117,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
<section>
- AG News (`bert-base-uncased-ag-news`)
- nlp dataset `ag_news`, split `test`
- Successes: 942/1000
- Accuracy: 94.20%
- CoLA (`bert-base-uncased-cola`)
- nlp dataset `glue`, subset `cola`, split `validation`
- Successes: 812/1000
@@ -172,6 +180,7 @@ All evaluations shown are on the full validation or test set up to 1000 examples
<section>
- CoLA (`distilbert-base-cased-cola`)
- nlp dataset `glue`, subset `cola`, split `validation`
- Successes: 786/1000
@@ -203,6 +212,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
<section>
- AG News (`distilbert-base-uncased-ag-news`)
- nlp dataset `ag_news`, split `test`
- Successes: 944/1000
- Accuracy: 94.40%
- CoLA (`distilbert-base-uncased-cola`)
- nlp dataset `glue`, subset `cola`, split `validation`
- Successes: 786/1000
@@ -242,6 +255,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
<section>
- AG News (`roberta-base-ag-news`)
- nlp dataset `ag_news`, split `test`
- Successes: 947/1000
- Accuracy: 94.70%
- CoLA (`roberta-base-cola`)
- nlp dataset `glue`, subset `cola`, split `validation`
- Successes: 857/1000

View File

@@ -61,7 +61,7 @@ class LSTMForClassification(nn.Module):
def forward(self, _input):
# ensure RNN module weights are part of single contiguous chunk of memory
# self.encoder.flatten_parameters()
self.encoder.flatten_parameters()
emb = self.emb_layer(_input.t())
emb = self.drop(emb)