mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
adding AG News models
This commit is contained in:
@@ -24,6 +24,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
#
|
||||
# bert-base-uncased
|
||||
#
|
||||
"bert-base-uncased-ag-news": (
|
||||
"textattack/bert-base-uncased-ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"bert-base-uncased-cola": (
|
||||
"textattack/bert-base-uncased-CoLA",
|
||||
("glue", "cola", "validation"),
|
||||
@@ -106,6 +110,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
#
|
||||
# distilbert-base-uncased
|
||||
#
|
||||
"distilbert-base-uncased-ag-news": (
|
||||
"textattack/distilbert-base-uncased-ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"distilbert-base-uncased-cola": (
|
||||
"textattack/distilbert-base-cased-CoLA",
|
||||
("glue", "cola", "validation"),
|
||||
@@ -137,11 +145,18 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
#
|
||||
# roberta-base (RoBERTa is cased by default)
|
||||
#
|
||||
"roberta-base-ag-news": (
|
||||
"textattack/roberta-base-ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"roberta-base-cola": (
|
||||
"textattack/roberta-base-CoLA",
|
||||
("glue", "cola", "validation"),
|
||||
),
|
||||
"roberta-base-imdb": ("textattack/roberta-base-imdb", ("imdb", None, "test"),),
|
||||
"roberta-base-imdb":
|
||||
("textattack/roberta-base-imdb",
|
||||
("imdb", None, "test"),
|
||||
),
|
||||
"roberta-base-mr": (
|
||||
"textattack/textattack/roberta-base-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
@@ -154,7 +169,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
"textattack/roberta-base-QNLI",
|
||||
("glue", "qnli", "validation"),
|
||||
),
|
||||
"roberta-base-rte": ("textattack/roberta-base-RTE", ("glue", "rte", "validation")),
|
||||
"roberta-base-rte":
|
||||
("textattack/roberta-base-RTE",
|
||||
("glue", "rte", "validation")
|
||||
),
|
||||
"roberta-base-sst2": (
|
||||
"textattack/roberta-base-SST-2",
|
||||
("glue", "sst2", "validation"),
|
||||
@@ -174,11 +192,18 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
#
|
||||
# albert-base-v2 (ALBERT is cased by default)
|
||||
#
|
||||
"albert-base-v2-ag-news": (
|
||||
"textattack/albert-base-v2-ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"albert-base-v2-cola": (
|
||||
"textattack/albert-base-v2-CoLA",
|
||||
("glue", "cola", "validation"),
|
||||
),
|
||||
"albert-base-v2-imdb": ("textattack/albert-base-v2-imdb", ("imdb", None, "test"),),
|
||||
"albert-base-v2-imdb":
|
||||
("textattack/albert-base-v2-imdb",
|
||||
("imdb", None, "test"),
|
||||
),
|
||||
"albert-base-v2-mr": (
|
||||
"textattack/albert-base-v2-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
@@ -191,7 +216,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
"textattack/albert-base-v2-QQP",
|
||||
("glue", "qqp", "validation"),
|
||||
),
|
||||
"albert-base-v2-snli": ("textattack/albert-base-v2-snli", ("snli", None, "test"),),
|
||||
"albert-base-v2-snli":
|
||||
("textattack/albert-base-v2-snli",
|
||||
("snli", None, "test"),
|
||||
),
|
||||
"albert-base-v2-sst2": (
|
||||
"textattack/albert-base-v2-SST-2",
|
||||
("glue", "sst2", "validation"),
|
||||
@@ -246,24 +274,52 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
# Models hosted by textattack.
|
||||
#
|
||||
TEXTATTACK_DATASET_BY_MODEL = {
|
||||
# @TODO restore ag-news models after agnews joins `nlp` as a dataset.
|
||||
#
|
||||
# CNNs
|
||||
#
|
||||
"lstm-sst2": ("models/classification/lstm/sst2", ("glue", "sst2", "validation")),
|
||||
"lstm-yelp": ("models/classification/lstm/yelp", ("yelp_polarity", None, "test"),),
|
||||
"lstm-imdb": ("models/classification/lstm/imdb", ("imdb", None, "test")),
|
||||
"lstm-mr": ("models/classification/lstm/mr", ("rotten_tomatoes", None, "test"),),
|
||||
"lstm-ag-news": (
|
||||
"models/classification/lstm/ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"lstm-imdb": (
|
||||
"models/classification/lstm/imdb",
|
||||
("imdb", None, "test")
|
||||
),
|
||||
"lstm-mr": (
|
||||
"models/classification/lstm/mr",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
"lstm-sst2": (
|
||||
"models/classification/lstm/sst2",
|
||||
("glue", "sst2", "validation")
|
||||
),
|
||||
"lstm-yelp": (
|
||||
"models/classification/lstm/yelp",
|
||||
("yelp_polarity", None, "test"),
|
||||
),
|
||||
#
|
||||
# LSTMs
|
||||
#
|
||||
"cnn-sst2": ("models/classification/cnn/sst", ("glue", "sst2", "validation")),
|
||||
"cnn-imdb": ("models/classification/cnn/imdb", ("imdb", None, "test")),
|
||||
"cnn-yelp": ("models/classification/cnn/yelp", ("yelp_polarity", None, "test"),),
|
||||
"cnn-ag-news": (
|
||||
"models/classification/cnn/ag-news",
|
||||
("ag_news", None, "test"),
|
||||
),
|
||||
"cnn-imdb": (
|
||||
"models/classification/cnn/imdb",
|
||||
("imdb", None, "test")
|
||||
),
|
||||
"cnn-mr": (
|
||||
"models/classification/cnn/rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
"cnn-sst2": (
|
||||
"models/classification/cnn/sst",
|
||||
("glue", "sst2", "validation")
|
||||
),
|
||||
"cnn-yelp": (
|
||||
"models/classification/cnn/yelp",
|
||||
("yelp_polarity", None, "test"),
|
||||
),
|
||||
#
|
||||
# T5 for translation
|
||||
#
|
||||
@@ -282,7 +338,10 @@ TEXTATTACK_DATASET_BY_MODEL = {
|
||||
#
|
||||
# T5 for summarization
|
||||
#
|
||||
"t5-summarization": ("summarization", ("gigaword", None, "test")),
|
||||
"t5-summarization":
|
||||
("summarization",
|
||||
("gigaword", None, "test")
|
||||
),
|
||||
}
|
||||
|
||||
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
|
||||
|
||||
@@ -314,6 +314,8 @@ def parse_model_from_args(args):
|
||||
f"Tried to load model from path {args.model} - could not find train_args.json."
|
||||
)
|
||||
model_train_args = json.loads(open(model_args_json_path).read())
|
||||
if model_train_args["model"] not in {"cnn", "lstm"}:
|
||||
# for huggingface models, set args.model to the path of the model
|
||||
model_train_args["model"] = args.model
|
||||
num_labels = model_train_args["num_labels"]
|
||||
from textattack.commands.train_model.train_args_helpers import (
|
||||
|
||||
@@ -25,6 +25,7 @@ class EvalModelCommand(TextAttackCommand):
|
||||
def get_preds(self, model, inputs):
|
||||
with torch.no_grad():
|
||||
preds = textattack.shared.utils.model_predict(model, inputs)
|
||||
breakpoint()
|
||||
return preds
|
||||
|
||||
def test_model_on_dataset(self, args):
|
||||
@@ -65,6 +66,9 @@ class EvalModelCommand(TextAttackCommand):
|
||||
perc_accuracy = "{:.2f}%".format(perc_accuracy)
|
||||
logger.info(f"Successes {successes}/{len(preds)} ({_cb(perc_accuracy)})")
|
||||
|
||||
import collections
|
||||
print(collections.Counter(guess_labels.tolist()))
|
||||
|
||||
def run(self, args):
|
||||
# Default to 'all' if no model chosen.
|
||||
if not (args.model or args.model_from_huggingface or args.model_from_file):
|
||||
|
||||
@@ -86,7 +86,7 @@ def train_model(args):
|
||||
)
|
||||
|
||||
if isinstance(train_labels[0], float):
|
||||
# TODO come up with a more sophisticated scheme for when to do regression
|
||||
# TODO come up with a more sophisticated scheme for knowing when to do regression
|
||||
logger.warn(f"Detected float labels. Doing regression.")
|
||||
args.num_labels = 1
|
||||
args.do_regression = True
|
||||
@@ -117,6 +117,7 @@ def train_model(args):
|
||||
# multi-gpu training
|
||||
if num_gpus > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
logger.info("Using torch.nn.DataParallel.")
|
||||
logger.info(f"Training model across {num_gpus} GPUs")
|
||||
|
||||
num_train_optimization_steps = (
|
||||
@@ -171,6 +172,12 @@ def train_model(args):
|
||||
|
||||
args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
|
||||
|
||||
# Save original args to file
|
||||
args_save_path = os.path.join(args.output_dir, "train_args.json")
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(args_dict, indent=2) + "\n")
|
||||
logger.info(f"Wrote original training args to {args_save_path}.")
|
||||
|
||||
tb_writer.add_hparams(args_dict, {})
|
||||
|
||||
# Start training
|
||||
@@ -356,6 +363,8 @@ def train_model(args):
|
||||
break
|
||||
|
||||
# read the saved model and report its eval performance
|
||||
logger.info("Finished training. Re-loading and evaluating model from disk.")
|
||||
model = model_from_args(args, args.num_labels)
|
||||
model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
|
||||
eval_score = get_eval_score()
|
||||
logger.info(
|
||||
@@ -375,8 +384,7 @@ def train_model(args):
|
||||
write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
|
||||
|
||||
# Save args to file
|
||||
args_save_path = os.path.join(args.output_dir, "train_args.json")
|
||||
final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(final_args_dict, indent=2) + "\n")
|
||||
logger.info(f"Wrote training args to {args_save_path}.")
|
||||
logger.info(f"Wrote final training args to {args_save_path}.")
|
||||
|
||||
@@ -71,7 +71,7 @@ def dataset_from_args(args):
|
||||
args.dataset_dev_split = "validation"
|
||||
except KeyError:
|
||||
try:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
eval_dataset = textattack.datasets.HuggingFaceNlpDataset(
|
||||
*dataset_args, split="test"
|
||||
)
|
||||
args.dataset_dev_split = "test"
|
||||
|
||||
@@ -66,6 +66,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
|
||||
|
||||
<section>
|
||||
|
||||
- AG News (`albert-base-v2-ag-news`)
|
||||
- nlp dataset `ag_news`, split `test`
|
||||
- Successes: 943/1000
|
||||
- Accuracy: 94.30%
|
||||
- CoLA (`albert-base-v2-cola`)
|
||||
- nlp dataset `glue`, subset `cola`, split `validation`
|
||||
- Successes: 829/1000
|
||||
@@ -113,6 +117,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
|
||||
|
||||
<section>
|
||||
|
||||
- AG News (`bert-base-uncased-ag-news`)
|
||||
- nlp dataset `ag_news`, split `test`
|
||||
- Successes: 942/1000
|
||||
- Accuracy: 94.20%
|
||||
- CoLA (`bert-base-uncased-cola`)
|
||||
- nlp dataset `glue`, subset `cola`, split `validation`
|
||||
- Successes: 812/1000
|
||||
@@ -172,6 +180,7 @@ All evaluations shown are on the full validation or test set up to 1000 examples
|
||||
|
||||
<section>
|
||||
|
||||
|
||||
- CoLA (`distilbert-base-cased-cola`)
|
||||
- nlp dataset `glue`, subset `cola`, split `validation`
|
||||
- Successes: 786/1000
|
||||
@@ -203,6 +212,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
|
||||
|
||||
<section>
|
||||
|
||||
- AG News (`distilbert-base-uncased-ag-news`)
|
||||
- nlp dataset `ag_news`, split `test`
|
||||
- Successes: 944/1000
|
||||
- Accuracy: 94.40%
|
||||
- CoLA (`distilbert-base-uncased-cola`)
|
||||
- nlp dataset `glue`, subset `cola`, split `validation`
|
||||
- Successes: 786/1000
|
||||
@@ -242,6 +255,10 @@ All evaluations shown are on the full validation or test set up to 1000 examples
|
||||
|
||||
<section>
|
||||
|
||||
- AG News (`roberta-base-ag-news`)
|
||||
- nlp dataset `ag_news`, split `test`
|
||||
- Successes: 947/1000
|
||||
- Accuracy: 94.70%
|
||||
- CoLA (`roberta-base-cola`)
|
||||
- nlp dataset `glue`, subset `cola`, split `validation`
|
||||
- Successes: 857/1000
|
||||
|
||||
@@ -61,7 +61,7 @@ class LSTMForClassification(nn.Module):
|
||||
|
||||
def forward(self, _input):
|
||||
# ensure RNN module weights are part of single contiguous chunk of memory
|
||||
# self.encoder.flatten_parameters()
|
||||
self.encoder.flatten_parameters()
|
||||
|
||||
emb = self.emb_layer(_input.t())
|
||||
emb = self.drop(emb)
|
||||
|
||||
Reference in New Issue
Block a user