From c0edd5e08b1fea0c978f00d3648bd7c77cc17ee2 Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Sun, 28 Jun 2020 19:22:57 -0400 Subject: [PATCH] add MR models and SNLI --- textattack/commands/attack/attack_args.py | 17 ++++++++++++++--- .../commands/attack/attack_args_helpers.py | 5 +++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py index 69e4873a..94df7d7f 100644 --- a/textattack/commands/attack/attack_args.py +++ b/textattack/commands/attack/attack_args.py @@ -54,9 +54,13 @@ HUGGINGFACE_DATASET_BY_MODEL = { ("glue", "wnli", "validation"), ), "bert-base-uncased-mr": ( - "textattack/bert-base-uncased-rotten_tomatoes", + "textattack/bert-base-uncased-rotten-tomatoes", ("rotten_tomatoes", None, "test"), ), + "bert-base-uncased-snli": ( + "textattack/bert-base-uncased-snli", + ("snli", None, "test", [1, 2, 0]), + ), # # distilbert-base-cased # @@ -144,14 +148,21 @@ HUGGINGFACE_DATASET_BY_MODEL = { ("glue", "wnli", "validation"), ), "roberta-base-mr": ( - "textattack/roberta-base-rotten_tomatoes", + "textattack/roberta-base-rotten-tomatoes", ("rotten_tomatoes", None, "test"), ), # # albert-base-v2 (ALBERT is cased by default) # "albert-base-v2-mr": ( - "textattack/albert-base-v2-rotten_tomatoes", + "textattack/albert-base-v2-rotten-tomatoes", + ("rotten_tomatoes", None, "test"), + ), + # + # xlnet-base-cased + # + "xlnet-base-cased-mr": ( + "textattack/xlnet-base-cased-rotten-tomatoes", ("rotten_tomatoes", None, "test"), ), } diff --git a/textattack/commands/attack/attack_args_helpers.py b/textattack/commands/attack/attack_args_helpers.py index adf2012b..6490b0fa 100644 --- a/textattack/commands/attack/attack_args_helpers.py +++ b/textattack/commands/attack/attack_args_helpers.py @@ -333,9 +333,10 @@ def parse_dataset_from_args(args): _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model] elif args.model in TEXTATTACK_DATASET_BY_MODEL: _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model] - + + # Automatically detect dataset for models trained with textattack. - if args.model and os.path.exists(args.model): + elif args.model and os.path.exists(args.model): model_args_json_path = os.path.join(args.model, "train_args.json") if not os.path.exists(model_args_json_path): raise FileNotFoundError(