mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update tests
This commit is contained in:
@@ -88,31 +88,52 @@ DATASET_BY_MODEL = {
|
||||
}
|
||||
|
||||
TRANSFORMATION_CLASS_NAMES = {
|
||||
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet',
|
||||
'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding',
|
||||
'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph',
|
||||
'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap',
|
||||
'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding',
|
||||
'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph',
|
||||
'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap',
|
||||
'word-swap-random-char-deletion': 'textattack.transformations.WordSwapRandomCharacterDeletion',
|
||||
'word-swap-random-char-insertion': 'textattack.transformations.WordSwapRandomCharacterInsertion',
|
||||
'word-swap-random-char-substitution': 'textattack.transformations.WordSwapRandomCharacterSubstitution',
|
||||
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet',
|
||||
}
|
||||
|
||||
CONSTRAINT_CLASS_NAMES = {
|
||||
'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance',
|
||||
'goog-lm': 'textattack.constraints.semantics.language_models.GoogleLanguageModel',
|
||||
'bert': 'textattack.constraints.semantics.sentence_encoders.BERT',
|
||||
'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent',
|
||||
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
|
||||
'lang-tool': 'textattack.constraints.syntax.LanguageTool',
|
||||
#
|
||||
# Semantics constraints
|
||||
#
|
||||
'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance',
|
||||
'bert': 'textattack.constraints.semantics.sentence_encoders.BERT',
|
||||
'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent',
|
||||
'thought-vector': 'textattack.constraints.semantics.sentence_encoders.ThoughtVector',
|
||||
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
|
||||
#
|
||||
# Grammaticality constraints
|
||||
#
|
||||
'lang-tool': 'textattack.constraints.grammaticality.LanguageTool',
|
||||
'part-of-speech': 'textattack.constraints.grammaticality.PartOfSpeech',
|
||||
'goog-lm': 'textattack.constraints.grammaticality.language_models.GoogleLanguageModel',
|
||||
'gpt2': 'textattack.constraints.grammaticality.language_models.GPT2',
|
||||
#
|
||||
# Overlap constraints
|
||||
#
|
||||
'bleu': 'textattack.constraints.overlap.BLEU',
|
||||
'chrf': 'textattack.constraints.overlap.chrF',
|
||||
'edit-distance': 'textattack.constraints.overlap.LevenshteinEditDistance',
|
||||
'meteor': 'textattack.constraints.overlap.METEOR',
|
||||
'words-perturbed': 'textattack.constraints.overlap.WordsPerturbed',
|
||||
}
|
||||
|
||||
SEARCH_CLASS_NAMES = {
|
||||
'beam-search': 'textattack.search_methods.BeamSearch',
|
||||
'greedy-word': 'textattack.search_methods.GreedyWordSwap',
|
||||
'ga-word': 'textattack.search_methods.GeneticAlgorithm',
|
||||
'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR',
|
||||
'beam-search': 'textattack.search_methods.BeamSearch',
|
||||
'greedy-word': 'textattack.search_methods.GreedyWordSwap',
|
||||
'ga-word': 'textattack.search_methods.GeneticAlgorithm',
|
||||
'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR',
|
||||
}
|
||||
|
||||
GOAL_FUNCTION_CLASS_NAMES = {
|
||||
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification',
|
||||
'targeted-classification': 'textattack.goal_functions.TargetedClassification',
|
||||
'non-overlapping-output': 'textattack.goal_functions.NonOverlappingOutput',
|
||||
'targeted-classification': 'textattack.goal_functions.TargetedClassification',
|
||||
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification',
|
||||
}
|
||||
|
||||
def set_seed(random_seed):
|
||||
@@ -133,8 +154,8 @@ def get_args():
|
||||
choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.')
|
||||
|
||||
parser.add_argument('--constraints', type=str, required=False, nargs='*',
|
||||
default=[], choices=CONSTRAINT_CLASS_NAMES.keys(),
|
||||
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}"'))
|
||||
default=[],
|
||||
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ' + str(CONSTRAINT_CLASS_NAMES.keys())))
|
||||
|
||||
parser.add_argument('--out-dir', type=str, required=False, default=None,
|
||||
help='A directory to output results to.')
|
||||
@@ -149,7 +170,7 @@ def get_args():
|
||||
help='Disable logging to stdout')
|
||||
|
||||
parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
|
||||
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
|
||||
help='Enable logging to csv. Use --enable-csv plain to remove [[]] around words.')
|
||||
|
||||
parser.add_argument('--num-examples', '-n', type=int, required=False,
|
||||
default='5', help='The number of examples to process.')
|
||||
@@ -170,11 +191,11 @@ def get_args():
|
||||
help='Run attack using multiple GPUs.')
|
||||
|
||||
goal_function_choices = ', '.join(GOAL_FUNCTION_CLASS_NAMES.keys())
|
||||
parser.add_argument('--goal_function', '-g', default='untargeted-classification',
|
||||
parser.add_argument('--goal-function', '-g', default='untargeted-classification',
|
||||
help=f'The goal function to use. choices: {goal_function_choices}')
|
||||
|
||||
def str_to_int(s): return sum((ord(c) for c in s))
|
||||
parser.add_argument('--random_seed', default=str_to_int('TEXTATTACK'))
|
||||
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user