# Mirror of https://github.com/QData/TextAttack.git
# (synced 2021-10-13 00:05:06 +03:00; 291 lines, 10 KiB, Python)
import textattack
# Short recipe names mapped to the dotted path of the corresponding class
# under `textattack.attack_recipes`. Values are strings, not class objects.
RECIPE_NAMES = {
    "alzantot": "textattack.attack_recipes.Alzantot2018",
    "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
    "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
    "kuleshov": "textattack.attack_recipes.Kuleshov2017",
    "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
    "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
}

# AG News and MR are the last datasets self-hosted by textattack. Once they
# join `nlp`, we'll remove them from our hosting.
#
# Model names mapped to the dotted path (a string) of the model class under
# `textattack.models`.
TEXTATTACK_MODEL_CLASS_NAMES = {
    #
    # Text classification models
    #
    # BERT models - default uncased
    "bert-base-uncased-ag-news": "textattack.models.classification.bert.BERTForAGNewsClassification",
    "bert-base-uncased-mr": "textattack.models.classification.bert.BERTForMRSentimentClassification",
    # CNN models
    "cnn-ag-news": "textattack.models.classification.cnn.WordCNNForAGNewsClassification",
    "cnn-mr": "textattack.models.classification.cnn.WordCNNForMRSentimentClassification",
    # LSTM models
    "lstm-ag-news": "textattack.models.classification.lstm.LSTMForAGNewsClassification",
    "lstm-mr": "textattack.models.classification.lstm.LSTMForMRSentimentClassification",
    #
    # Translation models
    #
    "t5-en2fr": "textattack.models.translation.t5.T5EnglishToFrench",
    "t5-en2de": "textattack.models.translation.t5.T5EnglishToGerman",
    "t5-en2ro": "textattack.models.translation.t5.T5EnglishToRomanian",
    #
    # Summarization models
    #
    "t5-summ": "textattack.models.summarization.T5Summarization",
}

# Model names mapped to dataset class objects from `textattack.datasets`.
# Unlike the table above, these values are attribute lookups evaluated when
# this module is imported, not dotted-path strings.
DATASET_BY_MODEL = {
    # AG News
    "bert-base-uncased-ag-news": textattack.datasets.classification.AGNews,
    "cnn-ag-news": textattack.datasets.classification.AGNews,
    "lstm-ag-news": textattack.datasets.classification.AGNews,
    # MR
    "bert-base-uncased-mr": textattack.datasets.classification.MovieReviewSentiment,
    "cnn-mr": textattack.datasets.classification.MovieReviewSentiment,
    "lstm-mr": textattack.datasets.classification.MovieReviewSentiment,
    #
    # Translation datasets
    #
    # This entry previously sat at the end of TEXTATTACK_MODEL_CLASS_NAMES,
    # where its duplicate "t5-en2de" key overwrote the model class path with a
    # dataset class. It belongs with the other dataset class objects.
    "t5-en2de": textattack.datasets.translation.NewsTest2013EnglishToGerman,
}

# Model name -> (model identifier, dataset spec).
#
# Every identifier here has the form "textattack/<name>", and every dataset
# spec starts with ("glue", <subset>, <split>). Some specs carry extra
# trailing items:
#   - MNLI entries append [1, 2, 0]   (presumably a label remapping between
#     model outputs and dataset labels -- TODO confirm against the loader)
#   - STS-B entries append None, 5.0  (presumably "no remapping" plus an
#     output scale factor -- TODO confirm against the loader)
HUGGINGFACE_DATASET_BY_MODEL = {
    #
    # bert-base-uncased
    #
    "bert-base-uncased-cola": (
        "textattack/bert-base-uncased-CoLA",
        ("glue", "cola", "validation"),
    ),
    "bert-base-uncased-mnli": (
        "textattack/bert-base-uncased-MNLI",
        ("glue", "mnli", "validation_matched", [1, 2, 0]),
    ),
    "bert-base-uncased-mrpc": (
        "textattack/bert-base-uncased-MRPC",
        ("glue", "mrpc", "validation"),
    ),
    "bert-base-uncased-qnli": (
        "textattack/bert-base-uncased-QNLI",
        ("glue", "qnli", "validation"),
    ),
    "bert-base-uncased-qqp": (
        "textattack/bert-base-uncased-QQP",
        ("glue", "qqp", "validation"),
    ),
    "bert-base-uncased-rte": (
        "textattack/bert-base-uncased-RTE",
        ("glue", "rte", "validation"),
    ),
    "bert-base-uncased-sst2": (
        "textattack/bert-base-uncased-SST-2",
        ("glue", "sst2", "validation"),
    ),
    "bert-base-uncased-stsb": (
        "textattack/bert-base-uncased-STS-B",
        ("glue", "stsb", "validation", None, 5.0),
    ),
    "bert-base-uncased-wnli": (
        "textattack/bert-base-uncased-WNLI",
        ("glue", "wnli", "validation"),
    ),
    #
    # distilbert-base-cased
    #
    "distilbert-base-cased-cola": (
        "textattack/distilbert-base-cased-CoLA",
        ("glue", "cola", "validation"),
    ),
    "distilbert-base-cased-mrpc": (
        "textattack/distilbert-base-cased-MRPC",
        ("glue", "mrpc", "validation"),
    ),
    "distilbert-base-cased-qqp": (
        "textattack/distilbert-base-cased-QQP",
        ("glue", "qqp", "validation"),
    ),
    "distilbert-base-cased-sst2": (
        "textattack/distilbert-base-cased-SST-2",
        ("glue", "sst2", "validation"),
    ),
    "distilbert-base-cased-stsb": (
        "textattack/distilbert-base-cased-STS-B",
        ("glue", "stsb", "validation", None, 5.0),
    ),
    #
    # distilbert-base-uncased
    #
    "distilbert-base-uncased-mnli": (
        "textattack/distilbert-base-uncased-MNLI",
        ("glue", "mnli", "validation_matched", [1, 2, 0]),
    ),
    "distilbert-base-uncased-mrpc": (
        "textattack/distilbert-base-uncased-MRPC",
        ("glue", "mrpc", "validation"),
    ),
    "distilbert-base-uncased-qnli": (
        "textattack/distilbert-base-uncased-QNLI",
        ("glue", "qnli", "validation"),
    ),
    "distilbert-base-uncased-qqp": (
        "textattack/distilbert-base-uncased-QQP",
        ("glue", "qqp", "validation"),
    ),
    "distilbert-base-uncased-rte": (
        "textattack/distilbert-base-uncased-RTE",
        ("glue", "rte", "validation"),
    ),
    "distilbert-base-uncased-sst2": (
        "textattack/distilbert-base-uncased-SST-2",
        ("glue", "sst2", "validation"),
    ),
    "distilbert-base-uncased-stsb": (
        "textattack/distilbert-base-uncased-STS-B",
        ("glue", "stsb", "validation", None, 5.0),
    ),
    "distilbert-base-uncased-wnli": (
        "textattack/distilbert-base-uncased-WNLI",
        ("glue", "wnli", "validation"),
    ),
    #
    # roberta-base (RoBERTa is cased by default)
    #
    "roberta-base-cola": (
        "textattack/roberta-base-CoLA",
        ("glue", "cola", "validation"),
    ),
    "roberta-base-mrpc": (
        "textattack/roberta-base-MRPC",
        ("glue", "mrpc", "validation"),
    ),
    "roberta-base-qnli": (
        "textattack/roberta-base-QNLI",
        ("glue", "qnli", "validation"),
    ),
    "roberta-base-rte": ("textattack/roberta-base-RTE", ("glue", "rte", "validation")),
    "roberta-base-sst2": (
        "textattack/roberta-base-SST-2",
        ("glue", "sst2", "validation"),
    ),
    "roberta-base-stsb": (
        "textattack/roberta-base-STS-B",
        ("glue", "stsb", "validation", None, 5.0),
    ),
    "roberta-base-wnli": (
        "textattack/roberta-base-WNLI",
        ("glue", "wnli", "validation"),
    ),
}

# Model name -> (model path, dataset spec), where the dataset spec is a
# (dataset name, subset or None, split) triple.
TEXTATTACK_DATASET_BY_MODEL = {
    #
    # LSTMs
    #
    "lstm-sst": ("models/classification/lstm/sst", ("glue", "sst2", "validation")),
    "lstm-yelp-sentiment": (
        "models/classification/lstm/yelp_polarity",
        ("yelp_polarity", None, "test"),
    ),
    "lstm-imdb": ("models/classification/lstm/imdb", ("imdb", None, "test")),
    #
    # CNNs
    #
    "cnn-sst": ("models/classification/cnn/sst", ("glue", "sst2", "validation")),
    "cnn-imdb": ("models/classification/cnn/imdb", ("imdb", None, "test")),
    "cnn-yelp-sentiment": (
        "models/classification/cnn/yelp_polarity",
        # Was "yelp"; aligned with this entry's own model path and with the
        # "lstm-yelp-sentiment" sibling, which both name `yelp_polarity`.
        ("yelp_polarity", None, "test"),
    ),
    #
    # Textual entailment models
    #
    # BERT models
    # NOTE(review): unlike every other entry, this value is a bare dataset
    # triple with no model path in front of it. Left unchanged because the
    # intended path is not visible here -- confirm and wrap as
    # (model_path, ("snli", None, "test")).
    "bert-base-uncased-snli": ("snli", None, "test"),
    #
    # Text classification models
    #
    "bert-base-cased-imdb": (
        "models/classification/bert/imdb-cased",
        ("imdb", None, "test"),
    ),
    "bert-base-uncased-imdb": (
        "models/classification/bert/imdb-uncased",
        ("imdb", None, "test"),
    ),
    # NOTE(review): the two yelp entries below name the dataset "yelp" while
    # the LSTM/CNN entries use "yelp_polarity" -- confirm which is intended.
    "bert-base-cased-yelp": (
        "models/classification/bert/yelp-cased",
        ("yelp", None, "test"),
    ),
    "bert-base-uncased-yelp": (
        "models/classification/bert/yelp-uncased",
        ("yelp", None, "test"),
    ),
    #
    # Translation models
    # TODO add proper datasets
    #
    # Summarization models
    #
    #'t5-summ': 'textattack.models.summarization.T5Summarization',
}

# Transformation names mapped to the dotted path (a string) of the
# corresponding class under `textattack.transformations`.
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
    "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
    "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
    "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
    "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
    "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
    "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
}

# Same layout as the black-box transformation table: transformation name
# mapped to a dotted class path under `textattack.transformations`.
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-gradient": "textattack.transformations.WordSwapGradientBased",
}

# Constraint names mapped to the dotted path (a string) of the corresponding
# class under `textattack.constraints`, grouped by constraint family.
CONSTRAINT_CLASS_NAMES = {
    #
    # Semantics constraints
    #
    "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
    "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
    "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
    "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
    "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
    #
    # Grammaticality constraints
    #
    "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    #
    # Overlap constraints
    #
    "bleu": "textattack.constraints.overlap.BLEU",
    "chrf": "textattack.constraints.overlap.chrF",
    "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
    "meteor": "textattack.constraints.overlap.METEOR",
    "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
    #
    # Pre-transformation constraints
    #
    "repeat": "textattack.constraints.pre_transformation.RepeatModification",
    "stopword": "textattack.constraints.pre_transformation.StopwordModification",
}

# Search method names mapped to the dotted path (a string) of the
# corresponding class under `textattack.search_methods`.
SEARCH_CLASS_NAMES = {
    "beam-search": "textattack.search_methods.BeamSearch",
    "greedy": "textattack.search_methods.GreedySearch",
    "ga-word": "textattack.search_methods.GeneticAlgorithm",
    "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
}

# Goal function names mapped to the dotted path (a string) of the
# corresponding class under `textattack.goal_functions`.
GOAL_FUNCTION_CLASS_NAMES = {
    "non-overlapping-output": "textattack.goal_functions.NonOverlappingOutput",
    "targeted-classification": "textattack.goal_functions.TargetedClassification",
    "untargeted-classification": "textattack.goal_functions.UntargetedClassification",
}