1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update tests

This commit is contained in:
Jack Morris
2020-05-09 08:42:34 -04:00
parent f847e4c472
commit a4f72facbd
11 changed files with 267 additions and 43 deletions

View File

View File

@@ -50,10 +50,26 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-e
# (takes about 171s) # (takes about 171s)
# #
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 ' register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet ' '--enable-csv --model bert-mnli --num-examples 4 --transformation word-swap-wordnet '
'--constraints lang-tool --attack beam-search:beam_width=2'), '--constraints lang-tool --attack beam-search:beam_width=2'),
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn', name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt', output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_4.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples')
)
#
# test: run_attack non-overlapping output on T5 en->de translation with attack_n
# set, using the WordSwapRandomCharacterSubstitution transformation and greedy
# word swap, using edit distance and words-perturbed constraints, on 6 samples
# (takes about 100s)
#
register_test(('python -m textattack --attack-n --goal-function non-overlapping-output '
'--model t5-en2de --num-examples 6 --transformation word-swap-random-char-substitution '
'--constraints edit-distance:12 words-perturbed:max_percent=0.75 --attack greedy-word'),
name='run_attack_nonoverlapping_t5en2de_randomcharsub_editdistance_wordsperturbed_greedyword',
output_file='local_tests/sample_outputs/run_attack_nonoverlapping_t5ende_editdistance_bleu.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with' desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam ' 'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples') 'search with beam width 2, using language tool constraint, on 10 samples')

View File

@@ -11,7 +11,6 @@ def register_test(function, name=None, output_file=None, desc=None):
function, name=name, output=output, desc=desc function, name=name, output=output, desc=desc
)) ))
####################################### #######################################
## BEGIN TESTS ## ## BEGIN TESTS ##
####################################### #######################################
@@ -23,7 +22,7 @@ def check_gpu_count():
import torch import torch
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
if num_gpus == 0: if num_gpus == 0:
print(f'Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?') raise ValueError(f'detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')
register_test(check_gpu_count, name='check CUDA', register_test(check_gpu_count, name='check CUDA',
output_file='local_tests/sample_outputs/empty_file.txt', output_file='local_tests/sample_outputs/empty_file.txt',

View File

@@ -1 +1,2 @@
colored colored
side_by_side

View File

@@ -1,6 +1,8 @@
import argparse
import os import os
import time import time
from test_lists import tests
from test_models import color_text from test_models import color_text
def log_sep(): def log_sep():
@@ -9,7 +11,7 @@ def log_sep():
def print_gray(s): def print_gray(s):
print(color_text(s, 'light_gray')) print(color_text(s, 'light_gray'))
def main(): def change_to_root_dir():
# Change to TextAttack root directory. # Change to TextAttack root directory.
this_file_path = os.path.abspath(__file__) this_file_path = os.path.abspath(__file__)
test_directory_name = os.path.dirname(this_file_path) test_directory_name = os.path.dirname(this_file_path)
@@ -17,11 +19,11 @@ def main():
os.chdir(textattack_root_directory_name) os.chdir(textattack_root_directory_name)
print_gray(f'Executing tests from {textattack_root_directory_name}.') print_gray(f'Executing tests from {textattack_root_directory_name}.')
# Execute tests. def run_all_tests():
change_to_root_dir()
start_time = time.time() start_time = time.time()
passed_tests = 0 passed_tests = 0
from tests import tests
for test in tests: for test in tests:
log_sep() log_sep()
test_passed = test() test_passed = test()
@@ -32,7 +34,38 @@ def main():
print_gray(f'Passed {passed_tests}/{len(tests)} in {end_time-start_time}s.') print_gray(f'Passed {passed_tests}/{len(tests)} in {end_time-start_time}s.')
def run_tests_by_name(test_names):
    """Run only the registered tests whose names appear in ``test_names``.

    Prints a passed/executed summary with the elapsed time, then lists any
    requested names that were never matched by a registered test.
    """
    pending = set(test_names)
    began = time.time()
    num_passed = 0
    num_run = 0
    for test in tests:
        if test.name in pending:
            log_sep()
            if test():
                num_passed += 1
            num_run += 1
            # A requested name is consumed as soon as its test has run.
            pending.remove(test.name)
    log_sep()
    finished = time.time()
    print_gray(f'Passed {num_passed}/{num_run} in {finished-began}s.')
    if pending:
        print(f'Tests not executed: {",".join(pending)}')
def parse_args():
    """Parse the command-line arguments of the local test runner.

    Returns:
        argparse.Namespace whose ``tests`` attribute is the list of test
        names passed via ``--tests``, or ``None`` when the flag is omitted.
    """
    all_test_names = [t.name for t in tests]
    parser = argparse.ArgumentParser(description='Run TextAttack local tests.')
    # ``str`` is the argument *type*, not its default: passing ``default=str``
    # made ``args.tests`` the (truthy) ``str`` class whenever ``--tests`` was
    # omitted. The default must be falsy so that case can be detected.
    parser.add_argument('--tests', type=str, default=None, nargs='+',
                        choices=all_test_names,
                        help='names of specific tests to run')
    return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
# @TODO add argparser and test sizes. args = parse_args()
main() if args.tests:
run_tests_by_name(args.tests)
else:
run_all_tests()

View File

@@ -0,0 +1,64 @@
GreedyWordSwap(
(goal_function): NonOverlappingOutput
(transformation): WordSwapRandomCharacterSubstitution(
(replace_stopwords): False
)
(constraints):
(0): LevenshteinEditDistance(
(max_edit_distance): 12
)
(1): WordsPerturbed(
(max_percent): 0.75
)
(is_black_box): True
)
Load time: /.*/s
--------------------------------------------- Result 1 ---------------------------------------------
Eine republikanische Strategie, um der Wiederwahl Obamas entgegenzuwirken-->[FAILED]
A Republican strategy to counter the re-election of Obama
--------------------------------------------- Result 2 ---------------------------------------------
Die republikanischen Führer rechtfertigten ihre Politik durch die Not-->[FAILED]
Republican leaders justified their policy by the need to combat electoral fraud.
--------------------------------------------- Result 3 ---------------------------------------------
Das Brennan-Zentrum betrachtet dies jedoch als Mythos und behaupt-->Allerdings hält das Brennan Centre dies für einen Mythos, indem e
However, the Brennan Centre considers this a myth, stating that electoral fraud is rarer in the United States than the number of people killed by lightning.
However, the Brennan Centre cTnsiders this a myth, stating that electoral fraud is rarer in the United States than the number of people killed by lightning.
--------------------------------------------- Result 4 ---------------------------------------------
Tatsächlich identifizierten republikanische Anwälte-->In einer DecOde identifizierten republikanische Anwält
Indeed, Republican lawyers identified only 300 cases of electoral fraud in the United States in a decade.
Indedd, Republican lawyers identified only 300 cases of electoral fraud in the United Ttates in a decOde.
--------------------------------------------- Result 5 ---------------------------------------------
Eines ist sicher: Diese neuen Bestimmungen werden sich negativ auf die Wahlbeteiligung aus-->Ein Hhing ist sicher: Diese neuen Bestimmungen werden sich negativ auf die Wahlbeteil
One thing is certain: these new provisions will have a negative impact on voter turn-out.
One Hhing is certain: these new provisions will have a negative impact on voter turn-out.
--------------------------------------------- Result 6 ---------------------------------------------
In diesem Sinne werden die Maßnahmen das demokratische System der USA teilweise untergraben-->[FAILED]
In this sense, the measures will partially undermine the American democratic system.
(0lqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwqqqqqqqqk(B
(0x(B Attack Results (0x(B (0x(B
(0tqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqnqqqqqqqqu(B
(0x(B Number of successful attacks: (0x(B 3 (0x(B
(0x(B Number of failed attacks: (0x(B 3 (0x(B
(0x(B Number of skipped attacks: (0x(B 0 (0x(B
(0x(B Original accuracy: (0x(B 100.0% (0x(B
(0x(B Accuracy under attack: (0x(B 50.0% (0x(B
(0x(B Attack success rate: (0x(B 50.0% (0x(B
(0x(B Average perturbed word %: (0x(B 9.62% (0x(B
(0x(B Average num. words per input: (0x(B 15.33 (0x(B
(0x(B Avg num queries: (0x(B 23.67 (0x(B
(0mqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqvqqqqqqqqj(B
Attack time: /.*/s

View File

@@ -0,0 +1,65 @@
BeamSearch(
(goal_function): TargetedClassification(
(target_class): 2
)
(transformation): WordSwapWordNet(
(replace_stopwords): False
)
(constraints):
(0): LanguageTool(
(grammar_error_threshold): 0
)
(is_black_box): True
)
Logging to CSV at path /.*/csv.
Load time: /.*/s
--------------------------------------------- Result 1 ---------------------------------------------
0-->2
In Temple Bar , the bookshop at the Gallery of Photography carries a large selection of photographic publications , and the Flying Pig is a secondhand bookshop .
There is a bookshop at the gallery .
In Temple Bar , the bookshop at the drift of Photography carries a large selection of photographic publications , and the Flying Pig is a secondhand bookshop .
There is a bookshop at the gallery .
--------------------------------------------- Result 2 ---------------------------------------------
0-->[FAILED]
On Naxos , you can walk through the pretty villages of the Tragea Valley and the foothills of Mount Zas , admiring Byzantine churches and exploring olive groves at your leisure .
Naxos is a place with beautiful scenery for leisure .
--------------------------------------------- Result 3 ---------------------------------------------
1-->[FAILED]
Impossible .
Impossible , unless circumstances are met .
--------------------------------------------- Result 4 ---------------------------------------------
0-->2
Expenses included in calculating net cost for education and training programs that are intended to increase or maintain national economic productive capacity shall be reported as investments in human capital as required supplementary stewardship information accompanying the financial statements of the Federal Government and its component units .
Net cost for education programs can be calculated as a way to increase productivity .
Expenses included in calculating net cost for education and training programs that are intended to increase or maintain national economic productive capacity shall be reported as investments in human capital as required supplementary stewardship information accompanying the financial statements of the Federal Government and its component units .
Net cost for education programs can be calculated as a way to increment productivity .
(0lqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwqqqqqqqqk(B
(0x(B Attack Results (0x(B (0x(B
(0tqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqnqqqqqqqqu(B
(0x(B Number of successful attacks: (0x(B 2 (0x(B
(0x(B Number of failed attacks: (0x(B 2 (0x(B
(0x(B Number of skipped attacks: (0x(B 0 (0x(B
(0x(B Original accuracy: (0x(B 100.0% (0x(B
(0x(B Accuracy under attack: (0x(B 50.0% (0x(B
(0x(B Attack success rate: (0x(B 50.0% (0x(B
(0x(B Average perturbed word %: (0x(B 2.38% (0x(B
(0x(B Average num. words per input: (0x(B 34.25 (0x(B
(0x(B Avg num queries: (0x(B 278.5 (0x(B
(0mqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqvqqqqqqqqj(B
Attack time: /.*/s

View File

@@ -2,16 +2,20 @@ import colored
import io import io
import os import os
import re import re
import signal
import sys import sys
import subprocess import subprocess
import traceback
from side_by_side import print_side_by_side
def color_text(s, color): def color_text(s, color):
return colored.stylize(s, colored.fg(color)) return colored.stylize(s, colored.fg(color))
FNULL = open('err.txt', 'w') stderr_file_name = 'err.out.txt'
MAGIC_STRING = '/.*/' MAGIC_STRING = '/.*/'
def compare_outputs(desired_output, test_output): def outputs_are_equivalent(desired_output, test_output):
""" Desired outputs have the magic string '/.*/' inserted wherever the """ Desired outputs have the magic string '/.*/' inserted wherever the
outputat that position doesn't actually matter. (For example, when the outputat that position doesn't actually matter. (For example, when the
time to execute is printed, or another non-deterministic feature of the time to execute is printed, or another non-deterministic feature of the
@@ -51,7 +55,7 @@ class TextAttackTest:
""" Runs test and prints success or failure. """ """ Runs test and prints success or failure. """
self.log_start() self.log_start()
test_output, errored = self.execute() test_output, errored = self.execute()
if compare_outputs(self.output, test_output): if (not errored) and outputs_are_equivalent(self.output, test_output):
self.log_success() self.log_success()
return True return True
else: else:
@@ -68,12 +72,19 @@ class TextAttackTest:
def log_failure(self, test_output, errored): def log_failure(self, test_output, errored):
fail_text = f'✗ Failed.' fail_text = f'✗ Failed.'
print(color_text(fail_text, 'red')) print(color_text(fail_text, 'red'))
print('\n')
if errored: if errored:
print(f'Test exited early with error: {test_output}') print(f'Test exited early with error: {test_output}')
else: else:
print(f'Test output: {test_output}.') output1 = f'Test output: {test_output}.'
print(f'Correct output: {self.output}.') output2 = f'Correct output: {self.output}.'
### begin delete
print()
print(output1)
print()
print(output2)
print()
### end delete
print_side_by_side(output1, output2)
class CommandLineTest(TextAttackTest): class CommandLineTest(TextAttackTest):
""" Runs a command-line command to check for desired output. """ """ Runs a command-line command to check for desired output. """
@@ -84,7 +95,7 @@ class CommandLineTest(TextAttackTest):
super().__init__(name=name, output=output, desc=desc) super().__init__(name=name, output=output, desc=desc)
def execute(self): def execute(self):
stderr_file = open('err.out', 'w+') stderr_file = open(stderr_file_name, 'w+')
result = subprocess.run( result = subprocess.run(
self.command.split(), self.command.split(),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@@ -94,7 +105,7 @@ class CommandLineTest(TextAttackTest):
stderr_file.seek(0) # go back to beginning of file so we can read the whole thing stderr_file.seek(0) # go back to beginning of file so we can read the whole thing
stderr_str = stderr_file.read() stderr_str = stderr_file.read()
# Remove temp file. # Remove temp file.
os.unlink(stderr_file.name) remove_stderr_file()
if result.returncode == 0: if result.returncode == 0:
# If the command succeeds, return stdout. # If the command succeeds, return stdout.
return result.stdout.decode(), False return result.stdout.decode(), False
@@ -132,6 +143,20 @@ class PythonFunctionTest(TextAttackTest):
output = '\n'.join(output_lines) output = '\n'.join(output_lines)
return output, False return output, False
except: # catch *all* exceptions except: # catch *all* exceptions
e = sys.exc_info()[0] exc_str_lines = traceback.format_exc().splitlines()
return str(e), True exc_str = '\n'.join(exc_str_lines)
return exc_str, True
def remove_stderr_file():
    """Delete the temporary stderr capture file, ignoring its absence."""
    try:
        os.remove(stderr_file_name)
    except FileNotFoundError:
        # File doesn't exist - it was never created or was already cleaned up.
        return
def exit_handler(_, __):
    """SIGINT handler: remove the temporary stderr file, then interrupt.

    The two positional arguments (signal number and stack frame) are supplied
    by the ``signal`` module and are not used.

    Raises:
        KeyboardInterrupt: always, after the cleanup has run.
    """
    remove_stderr_file()
    # A handler that merely returns replaces Python's default SIGINT behaviour
    # and would swallow Ctrl-C; re-raise so the interrupt still stops the run.
    raise KeyboardInterrupt
# If the program exits early, make sure it didn't create any unneeded files.
# Registering exit_handler for SIGINT runs the stderr-file cleanup on Ctrl-C.
signal.signal(signal.SIGINT, exit_handler)

View File

@@ -88,31 +88,52 @@ DATASET_BY_MODEL = {
} }
TRANSFORMATION_CLASS_NAMES = { TRANSFORMATION_CLASS_NAMES = {
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet', 'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding',
'word-swap-embedding': 'textattack.transformations.WordSwapEmbedding', 'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph',
'word-swap-homoglyph': 'textattack.transformations.WordSwapHomoglyph', 'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap',
'word-swap-neighboring-char-swap': 'textattack.transformations.WordSwapNeighboringCharacterSwap', 'word-swap-random-char-deletion': 'textattack.transformations.WordSwapRandomCharacterDeletion',
'word-swap-random-char-insertion': 'textattack.transformations.WordSwapRandomCharacterInsertion',
'word-swap-random-char-substitution': 'textattack.transformations.WordSwapRandomCharacterSubstitution',
'word-swap-wordnet': 'textattack.transformations.WordSwapWordNet',
} }
CONSTRAINT_CLASS_NAMES = { CONSTRAINT_CLASS_NAMES = {
'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance', #
'goog-lm': 'textattack.constraints.semantics.language_models.GoogleLanguageModel', # Semantics constraints
'bert': 'textattack.constraints.semantics.sentence_encoders.BERT', #
'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent', 'embedding': 'textattack.constraints.semantics.WordEmbeddingDistance',
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder', 'bert': 'textattack.constraints.semantics.sentence_encoders.BERT',
'lang-tool': 'textattack.constraints.syntax.LanguageTool', 'infer-sent': 'textattack.constraints.semantics.sentence_encoders.InferSent',
'thought-vector': 'textattack.constraints.semantics.sentence_encoders.ThoughtVector',
'use': 'textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder',
#
# Grammaticality constraints
#
'lang-tool': 'textattack.constraints.grammaticality.LanguageTool',
'part-of-speech': 'textattack.constraints.grammaticality.PartOfSpeech',
'goog-lm': 'textattack.constraints.grammaticality.language_models.GoogleLanguageModel',
'gpt2': 'textattack.constraints.grammaticality.language_models.GPT2',
#
# Overlap constraints
#
'bleu': 'textattack.constraints.overlap.BLEU',
'chrf': 'textattack.constraints.overlap.chrF',
'edit-distance': 'textattack.constraints.overlap.LevenshteinEditDistance',
'meteor': 'textattack.constraints.overlap.METEOR',
'words-perturbed': 'textattack.constraints.overlap.WordsPerturbed',
} }
SEARCH_CLASS_NAMES = { SEARCH_CLASS_NAMES = {
'beam-search': 'textattack.search_methods.BeamSearch', 'beam-search': 'textattack.search_methods.BeamSearch',
'greedy-word': 'textattack.search_methods.GreedyWordSwap', 'greedy-word': 'textattack.search_methods.GreedyWordSwap',
'ga-word': 'textattack.search_methods.GeneticAlgorithm', 'ga-word': 'textattack.search_methods.GeneticAlgorithm',
'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR', 'greedy-word-wir': 'textattack.search_methods.GreedyWordSwapWIR',
} }
GOAL_FUNCTION_CLASS_NAMES = { GOAL_FUNCTION_CLASS_NAMES = {
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification', 'non-overlapping-output': 'textattack.goal_functions.NonOverlappingOutput',
'targeted-classification': 'textattack.goal_functions.TargetedClassification', 'targeted-classification': 'textattack.goal_functions.TargetedClassification',
'untargeted-classification': 'textattack.goal_functions.UntargetedClassification',
} }
def set_seed(random_seed): def set_seed(random_seed):
@@ -133,8 +154,8 @@ def get_args():
choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.') choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.')
parser.add_argument('--constraints', type=str, required=False, nargs='*', parser.add_argument('--constraints', type=str, required=False, nargs='*',
default=[], choices=CONSTRAINT_CLASS_NAMES.keys(), default=[],
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}"')) help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ' + str(CONSTRAINT_CLASS_NAMES.keys())))
parser.add_argument('--out-dir', type=str, required=False, default=None, parser.add_argument('--out-dir', type=str, required=False, default=None,
help='A directory to output results to.') help='A directory to output results to.')
@@ -149,7 +170,7 @@ def get_args():
help='Disable logging to stdout') help='Disable logging to stdout')
parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str, parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.') help='Enable logging to csv. Use --enable-csv plain to remove [[]] around words.')
parser.add_argument('--num-examples', '-n', type=int, required=False, parser.add_argument('--num-examples', '-n', type=int, required=False,
default='5', help='The number of examples to process.') default='5', help='The number of examples to process.')
@@ -170,11 +191,11 @@ def get_args():
help='Run attack using multiple GPUs.') help='Run attack using multiple GPUs.')
goal_function_choices = ', '.join(GOAL_FUNCTION_CLASS_NAMES.keys()) goal_function_choices = ', '.join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument('--goal_function', '-g', default='untargeted-classification', parser.add_argument('--goal-function', '-g', default='untargeted-classification',
help=f'The goal function to use. choices: {goal_function_choices}') help=f'The goal function to use. choices: {goal_function_choices}')
def str_to_int(s): return sum((ord(c) for c in s)) def str_to_int(s): return sum((ord(c) for c in s))
parser.add_argument('--random_seed', default=str_to_int('TEXTATTACK')) parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
attack_group = parser.add_mutually_exclusive_group(required=False) attack_group = parser.add_mutually_exclusive_group(required=False)