1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update tests

This commit is contained in:
Jack Morris
2020-05-09 08:42:34 -04:00
parent f847e4c472
commit a4f72facbd
11 changed files with 267 additions and 43 deletions

View File

View File

@@ -50,10 +50,26 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-e
# (takes about 171s)
#
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
'--enable-csv --model bert-mnli --num-examples 4 --transformation word-swap-wordnet '
'--constraints lang-tool --attack beam-search:beam_width=2'),
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt',
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_4.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples')
)
#
# test: run_attack non-overlapping output of class 2 on T5 en->de translation with
# attack_n set, using the WordSwapRandomCharacterSubstitution transformation
# and greedy word swap, using edit distance constraint, on 6 samples
# (takes about 100s)
#
register_test(('python -m textattack --attack-n --goal-function non-overlapping-output '
'--model t5-en2de --num-examples 6 --transformation word-swap-random-char-substitution '
'--constraints edit-distance:12 words-perturbed:max_percent=0.75 --attack greedy-word'),
name='run_attack_nonoverlapping_t5en2de_randomcharsub_editdistance_wordsperturbed_greedyword',
output_file='local_tests/sample_outputs/run_attack_nonoverlapping_t5ende_editdistance_bleu.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples')

View File

@@ -11,7 +11,6 @@ def register_test(function, name=None, output_file=None, desc=None):
function, name=name, output=output, desc=desc
))
#######################################
## BEGIN TESTS ##
#######################################
@@ -23,7 +22,7 @@ def check_gpu_count():
import torch
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
print(f'Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')
raise ValueError(f'detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')
register_test(check_gpu_count, name='check CUDA',
output_file='local_tests/sample_outputs/empty_file.txt',

View File

@@ -1 +1,2 @@
colored
side_by_side

View File

@@ -1,6 +1,8 @@
import argparse
import os
import time
from test_lists import tests
from test_models import color_text
def log_sep():
@@ -9,7 +11,7 @@ def log_sep():
def print_gray(s):
    """Print ``s`` after passing it through color_text with 'light_gray'."""
    gray_text = color_text(s, 'light_gray')
    print(gray_text)
def main():
def change_to_root_dir():
# Change to TextAttack root directory.
this_file_path = os.path.abspath(__file__)
test_directory_name = os.path.dirname(this_file_path)
@@ -17,11 +19,11 @@ def main():
os.chdir(textattack_root_directory_name)
print_gray(f'Executing tests from {textattack_root_directory_name}.')
# Execute tests.
def run_all_tests():
change_to_root_dir()
start_time = time.time()
passed_tests = 0
from tests import tests
for test in tests:
log_sep()
test_passed = test()
@@ -32,7 +34,38 @@ def main():
print_gray(f'Passed {passed_tests}/{len(tests)} in {end_time-start_time}s.')
def run_tests_by_name(test_names):
    """Run only the registered tests whose name appears in ``test_names``.

    Prints a passed/executed summary with the elapsed time, then lists any
    requested names that did not match a registered test.

    Args:
        test_names: iterable of test-name strings to execute.
    """
    # Same setup as run_all_tests: execute from the TextAttack root directory.
    change_to_root_dir()
    remaining_names = set(test_names)
    start_time = time.time()
    passed_tests = 0
    executed_tests = 0
    for test in tests:
        if test.name not in remaining_names:
            continue
        log_sep()
        if test():
            passed_tests += 1
        executed_tests += 1
        # Whatever is left in the set afterwards was requested but never ran.
        remaining_names.remove(test.name)
    log_sep()
    end_time = time.time()
    print_gray(f'Passed {passed_tests}/{executed_tests} in {end_time-start_time}s.')
    if remaining_names:
        # Sort so the report is deterministic (set iteration order is not).
        print(f'Tests not executed: {",".join(sorted(remaining_names))}')
def parse_args():
    """Parse command-line arguments for the local test runner.

    Returns:
        argparse.Namespace whose ``tests`` attribute is a list of test names,
        or None when ``--tests`` was not given.
    """
    all_test_names = [t.name for t in tests]
    parser = argparse.ArgumentParser(description='Run TextAttack local tests.')
    # The default must be None rather than the `str` type: `str` is truthy, so
    # an omitted --tests would look like a request for specific tests.
    parser.add_argument('--tests', default=None, type=str, nargs='+', choices=all_test_names,
        help='names of specific tests to run')
    return parser.parse_args()
if __name__ == '__main__':
# @TODO add argparser and test sizes.
main()
args = parse_args()
if args.tests:
run_tests_by_name(args.tests)
else:
run_all_tests()

View File

@@ -0,0 +1,64 @@
GreedyWordSwap(
(goal_function): NonOverlappingOutput
(transformation): WordSwapRandomCharacterSubstitution(
(replace_stopwords): False
)
(constraints):
(0): LevenshteinEditDistance(
(max_edit_distance): 12
)
(1): WordsPerturbed(
(max_percent): 0.75
)
(is_black_box): True
)
Load time: /.*/s
--------------------------------------------- Result 1 ---------------------------------------------
Eine republikanische Strategie, um der Wiederwahl Obamas entgegenzuwirken-->[FAILED]
A Republican strategy to counter the re-election of Obama
--------------------------------------------- Result 2 ---------------------------------------------
Die republikanischen Führer rechtfertigten ihre Politik durch die Not-->[FAILED]
Republican leaders justified their policy by the need to combat electoral fraud.
--------------------------------------------- Result 3 ---------------------------------------------
Das Brennan-Zentrum betrachtet dies jedoch als Mythos und behaupt-->Allerdings hält das Brennan Centre dies für einen Mythos, indem e
However, the Brennan Centre considers this a myth, stating that electoral fraud is rarer in the United States than the number of people killed by lightning.
However, the Brennan Centre cTnsiders this a myth, stating that electoral fraud is rarer in the United States than the number of people killed by lightning.
--------------------------------------------- Result 4 ---------------------------------------------
Tatsächlich identifizierten republikanische Anwälte-->In einer DecOde identifizierten republikanische Anwält
Indeed, Republican lawyers identified only 300 cases of electoral fraud in the United States in a decade.
Indedd, Republican lawyers identified only 300 cases of electoral fraud in the United Ttates in a decOde.
--------------------------------------------- Result 5 ---------------------------------------------
Eines ist sicher: Diese neuen Bestimmungen werden sich negativ auf die Wahlbeteiligung aus-->Ein Hhing ist sicher: Diese neuen Bestimmungen werden sich negativ auf die Wahlbeteil
One thing is certain: these new provisions will have a negative impact on voter turn-out.
One Hhing is certain: these new provisions will have a negative impact on voter turn-out.
--------------------------------------------- Result 6 ---------------------------------------------
In diesem Sinne werden die Maßnahmen das demokratische System der USA teilweise untergraben-->[FAILED]
In this sense, the measures will partially undermine the American democratic system.
(0lqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwqqqqqqqqk(B
(0x(B Attack Results (0x(B (0x(B
(0tqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqnqqqqqqqqu(B
(0x(B Number of successful attacks: (0x(B 3 (0x(B
(0x(B Number of failed attacks: (0x(B 3 (0x(B
(0x(B Number of skipped attacks: (0x(B 0 (0x(B
(0x(B Original accuracy: (0x(B 100.0% (0x(B
(0x(B Accuracy under attack: (0x(B 50.0% (0x(B
(0x(B Attack success rate: (0x(B 50.0% (0x(B
(0x(B Average perturbed word %: (0x(B 9.62% (0x(B
(0x(B Average num. words per input: (0x(B 15.33 (0x(B
(0x(B Avg num queries: (0x(B 23.67 (0x(B
(0mqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqvqqqqqqqqj(B
Attack time: /.*/s

View File

@@ -0,0 +1,65 @@
BeamSearch(
(goal_function): TargetedClassification(
(target_class): 2
)
(transformation): WordSwapWordNet(
(replace_stopwords): False
)
(constraints):
(0): LanguageTool(
(grammar_error_threshold): 0
)
(is_black_box): True
)
Logging to CSV at path /.*/csv.
Load time: /.*/s
--------------------------------------------- Result 1 ---------------------------------------------
0-->2
In Temple Bar , the bookshop at the Gallery of Photography carries a large selection of photographic publications , and the Flying Pig is a secondhand bookshop .
There is a bookshop at the gallery .
In Temple Bar , the bookshop at the drift of Photography carries a large selection of photographic publications , and the Flying Pig is a secondhand bookshop .
There is a bookshop at the gallery .
--------------------------------------------- Result 2 ---------------------------------------------
0-->[FAILED]
On Naxos , you can walk through the pretty villages of the Tragea Valley and the foothills of Mount Zas , admiring Byzantine churches and exploring olive groves at your leisure .
Naxos is a place with beautiful scenery for leisure .
--------------------------------------------- Result 3 ---------------------------------------------
1-->[FAILED]
Impossible .
Impossible , unless circumstances are met .
--------------------------------------------- Result 4 ---------------------------------------------
0-->2
Expenses included in calculating net cost for education and training programs that are intended to increase or maintain national economic productive capacity shall be reported as investments in human capital as required supplementary stewardship information accompanying the financial statements of the Federal Government and its component units .
Net cost for education programs can be calculated as a way to increase productivity .
Expenses included in calculating net cost for education and training programs that are intended to increase or maintain national economic productive capacity shall be reported as investments in human capital as required supplementary stewardship information accompanying the financial statements of the Federal Government and its component units .
Net cost for education programs can be calculated as a way to increment productivity .
(0lqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwqqqqqqqqk(B
(0x(B Attack Results (0x(B (0x(B
(0tqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqnqqqqqqqqu(B
(0x(B Number of successful attacks: (0x(B 2 (0x(B
(0x(B Number of failed attacks: (0x(B 2 (0x(B
(0x(B Number of skipped attacks: (0x(B 0 (0x(B
(0x(B Original accuracy: (0x(B 100.0% (0x(B
(0x(B Accuracy under attack: (0x(B 50.0% (0x(B
(0x(B Attack success rate: (0x(B 50.0% (0x(B
(0x(B Average perturbed word %: (0x(B 2.38% (0x(B
(0x(B Average num. words per input: (0x(B 34.25 (0x(B
(0x(B Avg num queries: (0x(B 278.5 (0x(B
(0mqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqvqqqqqqqqj(B
Attack time: /.*/s

View File

@@ -2,16 +2,20 @@ import colored
import io
import os
import re
import signal
import sys
import subprocess
import traceback
from side_by_side import print_side_by_side
def color_text(s, color):
    """Return ``s`` stylized with the foreground color named ``color``."""
    foreground = colored.fg(color)
    return colored.stylize(s, foreground)
FNULL = open('err.txt', 'w')
stderr_file_name = 'err.out.txt'
MAGIC_STRING = '/.*/'
def compare_outputs(desired_output, test_output):
def outputs_are_equivalent(desired_output, test_output):
""" Desired outputs have the magic string '/.*/' inserted wherever the
output at that position doesn't actually matter. (For example, when the
time to execute is printed, or another non-deterministic feature of the
@@ -51,7 +55,7 @@ class TextAttackTest:
""" Runs test and prints success or failure. """
self.log_start()
test_output, errored = self.execute()
if compare_outputs(self.output, test_output):
if (not errored) and outputs_are_equivalent(self.output, test_output):
self.log_success()
return True
else:
@@ -68,12 +72,19 @@ class TextAttackTest:
def log_failure(self, test_output, errored):
fail_text = f'✗ Failed.'
print(color_text(fail_text, 'red'))
print('\n')
if errored:
print(f'Test exited early with error: {test_output}')
else:
print(f'Test output: {test_output}.')
print(f'Correct output: {self.output}.')
output1 = f'Test output: {test_output}.'
output2 = f'Correct output: {self.output}.'
### begin delete
print()
print(output1)
print()
print(output2)
print()
### end delete
print_side_by_side(output1, output2)
class CommandLineTest(TextAttackTest):
""" Runs a command-line command to check for desired output. """
@@ -84,7 +95,7 @@ class CommandLineTest(TextAttackTest):
super().__init__(name=name, output=output, desc=desc)
def execute(self):
stderr_file = open('err.out', 'w+')
stderr_file = open(stderr_file_name, 'w+')
result = subprocess.run(
self.command.split(),
stdout=subprocess.PIPE,
@@ -94,7 +105,7 @@ class CommandLineTest(TextAttackTest):
stderr_file.seek(0) # go back to beginning of file so we can read the whole thing
stderr_str = stderr_file.read()
# Remove temp file.
os.unlink(stderr_file.name)
remove_stderr_file()
if result.returncode == 0:
# If the command succeeds, return stdout.
return result.stdout.decode(), False
@@ -132,6 +143,20 @@ class PythonFunctionTest(TextAttackTest):
output = '\n'.join(output_lines)
return output, False
except: # catch *all* exceptions
e = sys.exc_info()[0]
return str(e), True
exc_str_lines = traceback.format_exc().splitlines()
exc_str = '\n'.join(exc_str_lines)
return exc_str, True
def remove_stderr_file():
    """Delete the temporary stderr capture file, ignoring its absence."""
    try:
        os.remove(stderr_file_name)
    except FileNotFoundError:
        # Nothing to clean up: the file was never created or is already gone.
        pass
def exit_handler(_, __):
    """SIGINT handler: remove the temporary stderr file, then interrupt.

    The two positional arguments supplied by the signal machinery (signal
    number and current stack frame) are unused.

    Raises:
        KeyboardInterrupt: always, so that Ctrl-C still stops the program.
    """
    remove_stderr_file()
    # Installing a handler replaces Python's default SIGINT behavior. Without
    # re-raising here the interrupt would be swallowed after cleanup and the
    # program would keep running.
    raise KeyboardInterrupt

# If the program exits early, make sure it didn't create any unneeded files.
signal.signal(signal.SIGINT, exit_handler)

View File

@@ -88,31 +88,52 @@ DATASET_BY_MODEL = {
}
TRANSFORMATION_CLASS_NAMES = {
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet',
'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding',
'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph',
'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap',
'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding',
'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph',
'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap',
'word-swap-random-char-deletion': 'textattack.transformations.WordSwapRandomCharacterDeletion',
'word-swap-random-char-insertion': 'textattack.transformations.WordSwapRandomCharacterInsertion',
'word-swap-random-char-substitution': 'textattack.transformations.WordSwapRandomCharacterSubstitution',
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet',
}
CONSTRAINT_CLASS_NAMES = {
'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance',
'goog-lm': 'textattack.constraints.semantics.language_models.GoogleLanguageModel',
'bert': 'textattack.constraints.semantics.sentence_encoders.BERT',
'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent',
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
'lang-tool': 'textattack.constraints.syntax.LanguageTool',
#
# Semantics constraints
#
'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance',
'bert': 'textattack.constraints.semantics.sentence_encoders.BERT',
'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent',
'thought-vector': 'textattack.constraints.semantics.sentence_encoders.ThoughtVector',
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
#
# Grammaticality constraints
#
'lang-tool': 'textattack.constraints.grammaticality.LanguageTool',
'part-of-speech': 'textattack.constraints.grammaticality.PartOfSpeech',
'goog-lm': 'textattack.constraints.grammaticality.language_models.GoogleLanguageModel',
'gpt2': 'textattack.constraints.grammaticality.language_models.GPT2',
#
# Overlap constraints
#
'bleu': 'textattack.constraints.overlap.BLEU',
'chrf': 'textattack.constraints.overlap.chrF',
'edit-distance': 'textattack.constraints.overlap.LevenshteinEditDistance',
'meteor': 'textattack.constraints.overlap.METEOR',
'words-perturbed': 'textattack.constraints.overlap.WordsPerturbed',
}
SEARCH_CLASS_NAMES = {
'beam-search': 'textattack.search_methods.BeamSearch',
'greedy-word': 'textattack.search_methods.GreedyWordSwap',
'ga-word': 'textattack.search_methods.GeneticAlgorithm',
'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR',
'beam-search': 'textattack.search_methods.BeamSearch',
'greedy-word': 'textattack.search_methods.GreedyWordSwap',
'ga-word': 'textattack.search_methods.GeneticAlgorithm',
'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR',
}
GOAL_FUNCTION_CLASS_NAMES = {
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification',
'targeted-classification': 'textattack.goal_functions.TargetedClassification',
'non-overlapping-output': 'textattack.goal_functions.NonOverlappingOutput',
'targeted-classification': 'textattack.goal_functions.TargetedClassification',
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification',
}
def set_seed(random_seed):
@@ -133,8 +154,8 @@ def get_args():
choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.')
parser.add_argument('--constraints', type=str, required=False, nargs='*',
default=[], choices=CONSTRAINT_CLASS_NAMES.keys(),
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}"'))
default=[],
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ' + str(CONSTRAINT_CLASS_NAMES.keys())))
parser.add_argument('--out-dir', type=str, required=False, default=None,
help='A directory to output results to.')
@@ -149,7 +170,7 @@ def get_args():
help='Disable logging to stdout')
parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
help='Enable logging to csv. Use --enable-csv plain to remove [[]] around words.')
parser.add_argument('--num-examples', '-n', type=int, required=False,
default='5', help='The number of examples to process.')
@@ -170,11 +191,11 @@ def get_args():
help='Run attack using multiple GPUs.')
goal_function_choices = ', '.join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument('--goal_function', '-g', default='untargeted-classification',
parser.add_argument('--goal-function', '-g', default='untargeted-classification',
help=f'The goal function to use. choices: {goal_function_choices}')
def str_to_int(s): return sum((ord(c) for c in s))
parser.add_argument('--random_seed', default=str_to_int('TEXTATTACK'))
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
attack_group = parser.add_mutually_exclusive_group(required=False)