mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add MR models and SNLI
This commit is contained in:
@@ -54,9 +54,13 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
("glue", "wnli", "validation"),
|
||||
),
|
||||
"bert-base-uncased-mr": (
|
||||
"textattack/bert-base-uncased-rotten_tomatoes",
|
||||
"textattack/bert-base-uncased-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
"bert-base-uncased-snli": (
|
||||
"textattack/bert-base-uncased-snli",
|
||||
("snli", None, "test", [1, 2, 0]),
|
||||
),
|
||||
#
|
||||
# distilbert-base-cased
|
||||
#
|
||||
@@ -144,14 +148,21 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
("glue", "wnli", "validation"),
|
||||
),
|
||||
"roberta-base-mr": (
|
||||
"textattack/roberta-base-rotten_tomatoes",
|
||||
"textattack/roberta-base-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
#
|
||||
# albert-base-v2 (ALBERT is cased by default)
|
||||
#
|
||||
"albert-base-v2-mr": (
|
||||
"textattack/albert-base-v2-rotten_tomatoes",
|
||||
"textattack/albert-base-v2-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
#
|
||||
# xlnet-base-cased
|
||||
#
|
||||
"xlnet-base-cased-mr": (
|
||||
"textattack/xlnet-base-cased-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -333,9 +333,10 @@ def parse_dataset_from_args(args):
|
||||
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
|
||||
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
|
||||
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
|
||||
|
||||
|
||||
|
||||
# Automatically detect dataset for models trained with textattack.
|
||||
if args.model and os.path.exists(args.model):
|
||||
elif args.model and os.path.exists(args.model):
|
||||
model_args_json_path = os.path.join(args.model, "train_args.json")
|
||||
if not os.path.exists(model_args_json_path):
|
||||
raise FileNotFoundError(
|
||||
|
||||
Reference in New Issue
Block a user