1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #87 from QData/logging

Logging updates; add Weights & Biases
This commit is contained in:
Jack Morris
2020-05-06 11:28:31 -04:00
committed by GitHub
13 changed files with 146 additions and 81 deletions

3
.gitignore vendored
View File

@@ -31,3 +31,6 @@ tensorflow-hub
# build outputs for PyPI
build/
dist/
# Weights & Biases outputs
wandb/

View File

@@ -20,7 +20,7 @@ def register_test(command, name=None, output_file=None, desc=None):
# test: run_attack_parallel textfooler attack on 10 samples from BERT MR
# (takes about 81s)
#
register_test('python -m textattack --model bert-mr --recipe textfooler --num_examples 10',
register_test('python -m textattack --model bert-mr --recipe textfooler --num-examples 10',
name='run_attack_textfooler_bert_mr_10',
output_file='local_tests/sample_outputs/run_attack_textfooler_bert_mr_10.txt',
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset')
@@ -29,7 +29,7 @@ register_test('python -m textattack --model bert-mr --recipe textfooler --num_ex
# test: run_attack_parallel textfooler attack on 10 samples from BERT SNLI
# (takes about 51s)
#
register_test('python -m textattack --model bert-snli --recipe textfooler --num_examples 10',
register_test('python -m textattack --model bert-snli --recipe textfooler --num-examples 10',
name='run_attack_textfooler_bert_snli_10',
output_file='local_tests/sample_outputs/run_attack_textfooler_bert_snli_10.txt',
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')
@@ -38,7 +38,7 @@ register_test('python -m textattack --model bert-snli --recipe textfooler --num_
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
# (takes about 41s)
#
register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num_examples 10',
register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-examples 10',
name='run_attack_deepwordbug_lstm_mr_10',
output_file='local_tests/sample_outputs/run_attack_deepwordbug_lstm_mr_10.txt',
desc='Runs attack using DeepWordBug recipe on LSTM using 10 examples from the MR dataset')
@@ -49,7 +49,7 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num_e
# beam width 2, using language tool constraint, on 10 samples
# (takes about 171s)
#
register_test(('python -m textattack --attack_n --goal_function targeted-classification:target_class=2 '
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
'--constraints lang-tool --attack beam-search:beam_width=2'),
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',

View File

@@ -16,3 +16,4 @@ tensorflow_hub
terminaltables
tqdm
visdom
wandb

View File

@@ -2,7 +2,7 @@ from .csv_logger import CSVLogger
from .file_logger import FileLogger
from .logger import Logger
from .visdom_logger import VisdomLogger
from .weights_and_biases_logger import WeightsAndBiasesLogger
# The AttackLogger must be imported last,
# since it imports the other loggers.
from .attack_logger import AttackLogger
# AttackLogManager must be imported last, since it imports the other loggers.
from .attack_log_manager import AttackLogManager

View File

@@ -3,9 +3,9 @@ import torch
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from . import CSVLogger, FileLogger, VisdomLogger
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
class AttackLogger:
class AttackLogManager:
def __init__(self):
""" Logs the results of an attack to all attached loggers
"""
@@ -20,6 +20,9 @@ class AttackLogger:
def enable_visdom(self):
self.loggers.append(VisdomLogger())
def enable_wandb(self):
self.loggers.append(WeightsAndBiasesLogger())
def add_output_file(self, filename):
self.loggers.append(FileLogger(filename=filename))
@@ -27,18 +30,22 @@ class AttackLogger:
self.loggers.append(CSVLogger(filename=filename, color_method=color_method))
def log_result(self, result):
""" Logs an `AttackResult` on each of `self.loggers`. """
self.results.append(result)
for logger in self.loggers:
logger.log_attack_result(result)
def log_results(self, results):
""" Logs an iterable of `AttackResult` objects on each of
`self.loggers`.
"""
for result in results:
self.log_result(result)
self.log_summary()
def _log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
for logger in self.loggers:
logger.log_rows(rows, title, window_id)
logger.log_summary_rows(rows, title, window_id)
def log_sep(self):
for logger in self.loggers:
@@ -49,11 +56,12 @@ class AttackLogger:
logger.flush()
def log_attack_details(self, attack_name, model_name):
# @TODO log a more complete set of attack details
attack_detail_rows = [
['Attack algorithm:', attack_name],
['Model:', model_name],
]
self._log_rows(attack_detail_rows, 'Attack Details', 'attack_details')
self.log_summary_rows(attack_detail_rows, 'Attack Details', 'attack_details')
def log_summary(self):
total_attacks = len(self.results)
@@ -123,9 +131,10 @@ class AttackLogger:
avg_num_queries = num_queries.mean()
avg_num_queries = str(round(avg_num_queries, 2))
summary_table_rows.append(['Avg num queries:', avg_num_queries])
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
self.log_summary_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
# Show histogram of words changed.
numbins = max(self.max_words_changed, 10)
for logger in self.loggers:
logger.log_hist(num_words_changed_until_success[:numbins],
numbins=numbins, title='Num Words Perturbed', window_id='num_words_perturbed')

View File

@@ -25,7 +25,7 @@ class FileLogger(Logger):
self.fout.write(result.__str__(color_method=color_method))
self.fout.write('\n')
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
if self.stdout:
table_rows = [[title, '']] + rows
table = terminaltables.SingleTable(table_rows)

View File

@@ -5,7 +5,7 @@ class Logger:
def log_attack_result(self, result, examples_completed):
pass
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
pass
def log_hist(self, arr, numbins, title, window_id):

View File

@@ -1,19 +1,9 @@
import socket
from visdom import Visdom
from textattack.shared.utils import html_table_from_rows
from .logger import Logger
def style_from_dict(style_dict):
""" Turns
{ 'color': 'red', 'height': '100px'}
into
style: "color: red; height: 100px"
"""
style_str = ''
for key in style_dict:
style_str += key + ': ' + style_dict[key] + ';'
return 'style="{}"'.format(style_str)
def port_is_open(port_num, hostname='127.0.0.1'):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex((hostname, port_num))
@@ -34,7 +24,7 @@ class VisdomLogger(Logger):
result_str = result.goal_function_result_str(color_method='html')
self.sample_rows.append([result_str,text_a,text_b])
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
self.table(rows, title=title, window_id=window_id)
def flush(self):
@@ -60,35 +50,7 @@ class VisdomLogger(Logger):
if not window_id: window_id = title # Can provide either of these,
if not title: title = window_id # or both.
# Stylize the container div.
if style:
table_html = '<div {}>'.format(style_from_dict(style))
else:
table_html = '<div>'
# Print the title string.
if title:
table_html += '<h1>{}</h1>'.format(title)
# Construct each row as HTML.
table_html = '<table class="table">'
if header:
table_html += '<tr>'
for element in header:
table_html += '<th>'
table_html += str(element)
table_html += '</th>'
table_html += '</tr>'
for row in rows:
table_html += '<tr>'
for element in row:
table_html += '<td>'
table_html += str(element)
table_html += '</td>'
table_html += '</tr>'
# Close the table and print to screen.
table_html += '</table></div>'
table = html_table_from_rows(rows, title=title, header=header, style_dict=style)
self.text(table_html, title=title, window_id=window_id)
def bar(self, X_data, numbins=10, title=None, window_id=None):

View File

@@ -0,0 +1,41 @@
import wandb
from textattack.shared.utils import html_table_from_rows
from .logger import Logger
class WeightsAndBiasesLogger(Logger):
    """ Logs attack results to Weights & Biases (`wandb`).

        Each attack result is logged as its own W&B step, and a running
        HTML table of every result seen so far is re-logged after each one.
    """

    def __init__(self, filename='', stdout=False):
        """ Starts a W&B run under the `textattack` project.

            `filename` and `stdout` are accepted but not used by this logger.
        """
        wandb.init(project='textattack')
        # One row per logged result: [label, original text, perturbed text].
        self._result_table_rows = []

    def _log_result_table(self):
        """ Weights & Biases doesn't have a feature to automatically
            aggregate results across timesteps and display the full table.
            Therefore, we have to do it manually.
        """
        result_table = html_table_from_rows(self._result_table_rows)
        wandb.log({ 'results': wandb.Html(result_table, inject=False) })

    def log_attack_result(self, result):
        """ Logs a single attack result to W&B, then re-logs the
            aggregated table of all results logged so far.
        """
        original_text_colored, perturbed_text_colored = result.diff_color(color_method='html')
        # Results are numbered from 0, in the order they were logged.
        result_num = len(self._result_table_rows)
        self._result_table_rows.append([f'<b>Result {result_num}</b>', original_text_colored, perturbed_text_colored])
        result_diff_table = html_table_from_rows([[original_text_colored, perturbed_text_colored]])
        result_diff_table = wandb.Html(result_diff_table)
        wandb.log({
            'resultType': type(result).__name__,
            'result': result_diff_table,
            'original_output': result.original_result.output,
            'perturbed_output': result.perturbed_result.output,
        })
        self._log_result_table()

    def log_summary_rows(self, rows, title, window_id):
        """ Summary rows are not sent to W&B; only a notice is printed. """
        # @TODO: should we log some summary to W&B? It seems to automatically
        # calculate its own summary statistics.
        print('w&b skipping summary')

    def log_sep(self):
        """ No-op: this logger has no text stream to write a separator to.

            (It never opens a file, so there is no `self.fout` to write to;
            writing to it would raise `AttributeError`.)
        """
        pass

View File

@@ -134,22 +134,25 @@ def get_args():
default=[], choices=CONSTRAINT_CLASS_NAMES.keys(),
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}"'))
parser.add_argument('--out_dir', type=str, required=False, default=None,
parser.add_argument('--out-dir', type=str, required=False, default=None,
help='A directory to output results to.')
parser.add_argument('--enable_visdom', action='store_true',
parser.add_argument('--enable-visdom', action='store_true',
help='Enable logging to visdom.')
parser.add_argument('--disable_stdout', action='store_true',
parser.add_argument('--enable-wandb', action='store_true',
help='Enable logging to Weights & Biases.')
parser.add_argument('--disable-stdout', action='store_true',
help='Disable logging to stdout')
parser.add_argument('--enable_csv', nargs='?', default=None, const='fancy', type=str,
parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
parser.add_argument('--num_examples', '-n', type=int, required=False,
parser.add_argument('--num-examples', '-n', type=int, required=False,
default='5', help='The number of examples to process.')
parser.add_argument('--num_examples_offset', '-o', type=int, required=False,
parser.add_argument('--num-examples-offset', '-o', type=int, required=False,
default=0, help='The offset to start at in the dataset.')
parser.add_argument('--shuffle', action='store_true', required=False,
@@ -158,7 +161,7 @@ def get_args():
parser.add_argument('--interactive', action='store_true', default=False,
help='Whether to run attacks interactively.')
parser.add_argument('--attack_n', action='store_true', default=False,
parser.add_argument('--attack-n', action='store_true', default=False,
help='Whether to run attack until `n` examples have been attacked (not skipped).')
parser.add_argument('--parallel', action='store_true', default=False,
@@ -277,32 +280,35 @@ def parse_goal_function_and_attack_from_args(args):
return goal_function, attack
def parse_logger_from_args(args):# Create logger
attack_logger = textattack.loggers.AttackLogger()
attack_log_manager = textattack.loggers.AttackLogManager()
# Set default output directory to `textattack/outputs`.
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(current_dir, os.pardir, 'outputs')
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
args.out_dir = outputs_dir
# Output file.
out_time = int(time.time()*1000) # Output file
outfile_name = 'attack-{}.txt'.format(out_time)
attack_logger.add_output_file(os.path.join(args.out_dir, outfile_name))
attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))
# CSV
if args.enable_csv:
outfile_name = 'attack-{}.csv'.format(out_time)
color_method = None if args.enable_csv == 'plain' else 'file'
csv_path = os.path.join(args.out_dir, outfile_name)
attack_logger.add_output_csv(csv_path, color_method)
attack_log_manager.add_output_csv(csv_path, color_method)
print('Logging to CSV at path {}.'.format(csv_path))
# Visdom
if args.enable_visdom:
attack_logger.enable_visdom()
attack_log_manager.enable_visdom()
# Weights & Biases
if args.enable_wandb:
attack_log_manager.enable_wandb()
# Stdout
if not args.disable_stdout:
attack_logger.enable_stdout()
return attack_logger
attack_log_manager.enable_stdout()
return attack_log_manager

View File

@@ -47,7 +47,7 @@ def run(args):
)
start_time = time.time()
attack_logger = parse_logger_from_args(args)
attack_log_manager = parse_logger_from_args(args)
# We reserve the first GPU for coordinating workers.
num_gpus = torch.cuda.device_count()
@@ -80,7 +80,7 @@ def run(args):
result = out_queue.get(block=True)
if isinstance(result, Exception):
raise result
attack_logger.log_result(result)
attack_log_manager.log_result(result)
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
pbar.update()
num_results += 1
@@ -96,9 +96,9 @@ def run(args):
print()
# Enable summary stdout.
if args.disable_stdout:
attack_logger.enable_stdout()
attack_logger.log_summary()
attack_logger.flush()
attack_log_manager.enable_stdout()
attack_log_manager.log_summary()
attack_log_manager.flush()
print()
finish_time = time.time()
print(f'Attack time: {time.time() - load_time}s')

View File

@@ -27,7 +27,7 @@ def run(args):
print(attack, '\n')
# Logger
attack_logger = parse_logger_from_args(args)
attack_log_manager = parse_logger_from_args(args)
load_time = time.time()
print(f'Load time: {load_time - start_time}s')
@@ -69,7 +69,7 @@ def run(args):
num_examples=args.num_examples,
shuffle=args.shuffle,
attack_n=args.attack_n):
attack_logger.log_result(result)
attack_log_manager.log_result(result)
if not args.disable_stdout:
print('\n')
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
@@ -84,9 +84,9 @@ def run(args):
print()
# Enable summary stdout
if args.disable_stdout:
attack_logger.enable_stdout()
attack_logger.log_summary()
attack_logger.flush()
attack_log_manager.enable_stdout()
attack_log_manager.log_summary()
attack_log_manager.flush()
print()
finish_time = time.time()
print(f'Attack time: {time.time() - load_time}s')

View File

@@ -171,3 +171,46 @@ class ANSI_ESCAPE_CODES:
UNDERLINE = '\033[4m'
""" This color stops the current color sequence. """
STOP = '\033[0m'
def html_style_from_dict(style_dict):
    """ Builds an HTML `style` attribute from a dict of CSS properties.

        For example,
            { 'color': 'red', 'height': '100px' }
        becomes
            style="color: red;height: 100px;"

        Keys and values must be strings; declarations are emitted in the
        dict's iteration order, each terminated by a semicolon.
    """
    declarations = ''.join(
        prop + ': ' + style_dict[prop] + ';' for prop in style_dict
    )
    return 'style="{}"'.format(declarations)
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """ Renders `rows` as an HTML table wrapped in a `<div>`.

        Args:
            rows: iterable of rows; each row is an iterable of cell values.
            title: optional heading, rendered as an `<h1>` above the table.
            header: optional iterable of column headings (`<th>` cells).
            style_dict: optional dict of CSS properties applied to the
                container `<div>` via `html_style_from_dict`.

        Returns:
            The table as an HTML string.

        Cell values are converted with `str()` and inserted without HTML
        escaping, so callers can pass pre-rendered markup.
    """
    # Stylize the container div.
    if style_dict:
        table_html = '<div {}>'.format(html_style_from_dict(style_dict))
    else:
        table_html = '<div>'
    # Print the title string.
    if title:
        table_html += '<h1>{}</h1>'.format(title)
    # Append the table to the container; assigning here would discard the
    # opening <div> and title and leave the closing </div> unmatched.
    table_html += '<table class="table">'
    if header:
        header_cells = ''.join('<th>' + str(element) + '</th>' for element in header)
        table_html += '<tr>' + header_cells + '</tr>'
    # Construct each row as HTML.
    for row in rows:
        row_cells = ''.join('<td>' + str(element) + '</td>' for element in row)
        table_html += '<tr>' + row_cells + '</tr>'
    # Close the table and its container.
    table_html += '</table></div>'
    return table_html