mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add interactive mode + test
This commit is contained in:
@@ -16,6 +16,16 @@ def register_test(command, name=None, output_file=None, desc=None):
|
||||
## BEGIN TESTS ##
|
||||
#######################################
|
||||
|
||||
|
||||
#
|
||||
# test: run_attack --interactive
|
||||
#
|
||||
register_test(('printf "All that glitters is not gold\nq\n"',
|
||||
'python -m textattack --recipe textfooler --model bert-imdb --interactive'),
|
||||
name='interactive_mode',
|
||||
output_file='local_tests/sample_outputs/interactive_mode.txt',
|
||||
desc='Runs textfooler attack on BERT trained on IMDB using interactive mode')
|
||||
|
||||
#
|
||||
# test: run_attack_parallel textfooler attack on 10 samples from BERT MR
|
||||
# (takes about 81s)
|
||||
|
||||
41
local_tests/sample_outputs/interactive_mode.txt
Normal file
41
local_tests/sample_outputs/interactive_mode.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
Attack(
|
||||
(search_method): GreedyWordSwapWIR(
|
||||
(wir_method): unk
|
||||
)
|
||||
(goal_function): UntargetedClassification
|
||||
(transformation): WordSwapEmbedding(
|
||||
(max_candidates): 50
|
||||
(embedding_type): paragramcf
|
||||
)
|
||||
(constraints):
|
||||
(0): WordEmbeddingDistance(
|
||||
(embedding_type): paragramcf
|
||||
(min_cos_sim): 0.5
|
||||
(cased): False
|
||||
(include_unknown_words): True
|
||||
)
|
||||
(1): PartOfSpeech(
|
||||
(tagset): universal
|
||||
(allow_verb_noun_swap): True
|
||||
)
|
||||
(2): UniversalSentenceEncoder(
|
||||
(metric): angular
|
||||
(threshold): 0.904458599
|
||||
(compare_with_original): False
|
||||
(window_size): 15
|
||||
(skip_text_shorter_than_window): True
|
||||
)
|
||||
(3): RepeatModification
|
||||
(4): StopwordModification
|
||||
(is_black_box): True
|
||||
)
|
||||
|
||||
Load time: /.*/s
|
||||
Running in interactive mode
|
||||
----------------------------
|
||||
Enter a sentence to attack or "q" to quit:
|
||||
Attacking...
|
||||
[92m1[0m-->[91m0[0m
|
||||
All that [92mglitters[0m is not gold
|
||||
All that [91mglisten[0m is not gold
|
||||
Enter a sentence to attack or "q" to quit:
|
||||
@@ -2,6 +2,7 @@ import colored
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import signal
|
||||
import sys
|
||||
import subprocess
|
||||
@@ -87,11 +88,29 @@ class CommandLineTest(TextAttackTest):
|
||||
|
||||
def execute(self):
|
||||
stderr_file = open(stderr_file_name, 'w+')
|
||||
result = subprocess.run(
|
||||
self.command.split(),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=stderr_file
|
||||
)
|
||||
if isinstance(self.command, tuple):
|
||||
# Support pipes via tuple of commands
|
||||
procs = []
|
||||
for i in range(len(self.command) - 1):
|
||||
if i == 0:
|
||||
proc = subprocess.Popen(shlex.split(self.command[i]), stdout=subprocess.PIPE)
|
||||
else:
|
||||
proc = subprocess.Popen(shlex.split(self.command[i]), stdout=subprocess.PIPE, stdin=proc.stdout)
|
||||
procs.append(proc)
|
||||
# Run last commmand
|
||||
result = subprocess.run(
|
||||
shlex.split(self.command[-1]), stdin=procs[-1].stdout,
|
||||
stdout=subprocess.PIPE, stderr=stderr_file
|
||||
)
|
||||
# Wait for all intermittent processes
|
||||
for proc in procs:
|
||||
proc.wait()
|
||||
else:
|
||||
result = subprocess.run(
|
||||
shlex.split(self.command.split),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=stderr_file
|
||||
)
|
||||
stderr_file.seek(0) # go back to beginning of file so we can read the whole thing
|
||||
stderr_str = stderr_file.read()
|
||||
# Remove temp file.
|
||||
|
||||
@@ -162,6 +162,8 @@ class Attack:
|
||||
|
||||
if shuffle:
|
||||
random.shuffle(dataset.examples)
|
||||
|
||||
num_examples = num_examples or len(dataset)
|
||||
|
||||
if num_examples <= 0:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user