mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add a test that includes --attack_n, --enable_csv, targeted goal function
This commit is contained in:
@@ -33,6 +33,7 @@ register_test('python scripts/run_attack.py --model bert-snli --recipe textfoole
|
|||||||
name='run_attack_textfooler_bert_snli_10',
|
name='run_attack_textfooler_bert_snli_10',
|
||||||
output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt',
|
output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt',
|
||||||
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')
|
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')
|
||||||
|
|
||||||
#
|
#
|
||||||
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
|
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
|
||||||
# (takes about 41s)
|
# (takes about 41s)
|
||||||
@@ -41,3 +42,19 @@ register_test('python scripts/run_attack.py --model lstm-mr --recipe deepwordbug
|
|||||||
name='run_attack_deepwordbug_lstm_mr_10',
|
name='run_attack_deepwordbug_lstm_mr_10',
|
||||||
output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt',
|
output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt',
|
||||||
desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset')
|
desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset')
|
||||||
|
|
||||||
|
#
|
||||||
|
# test: run_attack targeted classification of class 2 on BERT MNLI with enable_csv
|
||||||
|
# and attack_n set, using the WordNet transformation and beam search with
|
||||||
|
# beam width 2, using language tool constraint, on 10 samples
|
||||||
|
# (takes about 171s)
|
||||||
|
#
|
||||||
|
register_test(('python scripts/run_attack.py --attack_n --goal_function targeted-classification:target_class=2 '
|
||||||
|
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
|
||||||
|
'--constraints lang-tool --attack beam-search:beam_width=2'),
|
||||||
|
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
|
||||||
|
output_file='local_tests/outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt',
|
||||||
|
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
|
||||||
|
'enable_csv and attack_n set, using the WordNet transformation and beam '
|
||||||
|
'search with beam width 2, using language tool constraint, on 10 samples')
|
||||||
|
)
|
||||||
@@ -1,13 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from tests import tests
|
|
||||||
from test_models import color_text
|
from test_models import color_text
|
||||||
|
|
||||||
def log_sep():
|
def log_sep():
|
||||||
print('\n' + ('-' * 60) + '\n')
|
print('\n' + ('-' * 60) + '\n')
|
||||||
|
|
||||||
|
|
||||||
def print_gray(s):
|
def print_gray(s):
|
||||||
print(color_text(s, 'light_gray'))
|
print(color_text(s, 'light_gray'))
|
||||||
|
|
||||||
@@ -22,6 +20,8 @@ def main():
|
|||||||
# Execute tests.
|
# Execute tests.
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
passed_tests = 0
|
passed_tests = 0
|
||||||
|
|
||||||
|
from tests import tests
|
||||||
for test in tests:
|
for test in tests:
|
||||||
log_sep()
|
log_sep()
|
||||||
test_passed = test()
|
test_passed = test()
|
||||||
|
|||||||
Reference in New Issue
Block a user