1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add MR models and SNLI

This commit is contained in:
Jack Morris
2020-06-28 19:22:57 -04:00
parent 3f15bb34f8
commit c0edd5e08b
2 changed files with 17 additions and 5 deletions

View File

@@ -54,9 +54,13 @@ HUGGINGFACE_DATASET_BY_MODEL = {
("glue", "wnli", "validation"),
),
"bert-base-uncased-mr": (
"textattack/bert-base-uncased-rotten_tomatoes",
"textattack/bert-base-uncased-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
"bert-base-uncased-snli": (
"textattack/bert-base-uncased-snli",
("snli", None, "test", [1, 2, 0]),
),
#
# distilbert-base-cased
#
@@ -144,14 +148,21 @@ HUGGINGFACE_DATASET_BY_MODEL = {
("glue", "wnli", "validation"),
),
"roberta-base-mr": (
"textattack/roberta-base-rotten_tomatoes",
"textattack/roberta-base-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
#
# albert-base-v2 (ALBERT is cased by default)
#
"albert-base-v2-mr": (
"textattack/albert-base-v2-rotten_tomatoes",
"textattack/albert-base-v2-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
#
# xlnet-base-cased
#
"xlnet-base-cased-mr": (
"textattack/xlnet-base-cased-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
}

View File

@@ -333,9 +333,10 @@ def parse_dataset_from_args(args):
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
# Automatically detect dataset for models trained with textattack.
if args.model and os.path.exists(args.model):
elif args.model and os.path.exists(args.model):
model_args_json_path = os.path.join(args.model, "train_args.json")
if not os.path.exists(model_args_json_path):
raise FileNotFoundError(