mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #87 from QData/logging
Logging updates; add Weights & Biases
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -31,3 +31,6 @@ tensorflow-hub
|
||||
# build outputs for PyPI
|
||||
build/
|
||||
dist/
|
||||
|
||||
# Weights & Biases outputs
|
||||
wandb/
|
||||
|
||||
@@ -20,7 +20,7 @@ def register_test(command, name=None, output_file=None, desc=None):
|
||||
# test: run_attack_parallel textfooler attack on 10 samples from BERT MR
|
||||
# (takes about 81s)
|
||||
#
|
||||
register_test('python -m textattack --model bert-mr --recipe textfooler --num_examples 10',
|
||||
register_test('python -m textattack --model bert-mr --recipe textfooler --num-examples 10',
|
||||
name='run_attack_textfooler_bert_mr_10',
|
||||
output_file='local_tests/sample_outputs/run_attack_textfooler_bert_mr_10.txt',
|
||||
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset')
|
||||
@@ -29,7 +29,7 @@ register_test('python -m textattack --model bert-mr --recipe textfooler --num_ex
|
||||
# test: run_attack_parallel textfooler attack on 10 samples from BERT SNLI
|
||||
# (takes about 51s)
|
||||
#
|
||||
register_test('python -m textattack --model bert-snli --recipe textfooler --num_examples 10',
|
||||
register_test('python -m textattack --model bert-snli --recipe textfooler --num-examples 10',
|
||||
name='run_attack_textfooler_bert_snli_10',
|
||||
output_file='local_tests/sample_outputs/run_attack_textfooler_bert_snli_10.txt',
|
||||
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')
|
||||
@@ -38,7 +38,7 @@ register_test('python -m textattack --model bert-snli --recipe textfooler --num_
|
||||
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
|
||||
# (takes about 41s)
|
||||
#
|
||||
register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num_examples 10',
|
||||
register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-examples 10',
|
||||
name='run_attack_deepwordbug_lstm_mr_10',
|
||||
output_file='local_tests/sample_outputs/run_attack_deepwordbug_lstm_mr_10.txt',
|
||||
desc='Runs attack using DeepWordBug recipe on LSTM using 10 examples from the MR dataset')
|
||||
@@ -49,7 +49,7 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num_e
|
||||
# beam width 2, using language tool constraint, on 10 samples
|
||||
# (takes about 171s)
|
||||
#
|
||||
register_test(('python -m textattack --attack_n --goal_function targeted-classification:target_class=2 '
|
||||
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
|
||||
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
|
||||
'--constraints lang-tool --attack beam-search:beam_width=2'),
|
||||
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
|
||||
|
||||
@@ -16,3 +16,4 @@ tensorflow_hub
|
||||
terminaltables
|
||||
tqdm
|
||||
visdom
|
||||
wandb
|
||||
@@ -2,7 +2,7 @@ from .csv_logger import CSVLogger
|
||||
from .file_logger import FileLogger
|
||||
from .logger import Logger
|
||||
from .visdom_logger import VisdomLogger
|
||||
from .weights_and_biases_logger import WeightsAndBiasesLogger
|
||||
|
||||
# The AttackLogger must be imported last,
|
||||
# since it imports the other loggers.
|
||||
from .attack_logger import AttackLogger
|
||||
# AttackLogManager must be imported last, since it imports the other loggers.
|
||||
from .attack_log_manager import AttackLogManager
|
||||
@@ -3,9 +3,9 @@ import torch
|
||||
|
||||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
|
||||
|
||||
from . import CSVLogger, FileLogger, VisdomLogger
|
||||
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
|
||||
|
||||
class AttackLogger:
|
||||
class AttackLogManager:
|
||||
def __init__(self):
|
||||
""" Logs the results of an attack to all attached loggers
|
||||
"""
|
||||
@@ -20,6 +20,9 @@ class AttackLogger:
|
||||
def enable_visdom(self):
|
||||
self.loggers.append(VisdomLogger())
|
||||
|
||||
def enable_wandb(self):
|
||||
self.loggers.append(WeightsAndBiasesLogger())
|
||||
|
||||
def add_output_file(self, filename):
|
||||
self.loggers.append(FileLogger(filename=filename))
|
||||
|
||||
@@ -27,18 +30,22 @@ class AttackLogger:
|
||||
self.loggers.append(CSVLogger(filename=filename, color_method=color_method))
|
||||
|
||||
def log_result(self, result):
|
||||
""" Logs an `AttackResult` on each of `self.loggers`. """
|
||||
self.results.append(result)
|
||||
for logger in self.loggers:
|
||||
logger.log_attack_result(result)
|
||||
|
||||
def log_results(self, results):
|
||||
""" Logs an iterable of `AttackResult` objects on each of
|
||||
`self.loggers`.
|
||||
"""
|
||||
for result in results:
|
||||
self.log_result(result)
|
||||
self.log_summary()
|
||||
|
||||
def _log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
for logger in self.loggers:
|
||||
logger.log_rows(rows, title, window_id)
|
||||
logger.log_summary_rows(rows, title, window_id)
|
||||
|
||||
def log_sep(self):
|
||||
for logger in self.loggers:
|
||||
@@ -49,11 +56,12 @@ class AttackLogger:
|
||||
logger.flush()
|
||||
|
||||
def log_attack_details(self, attack_name, model_name):
|
||||
# @TODO log a more complete set of attack details
|
||||
attack_detail_rows = [
|
||||
['Attack algorithm:', attack_name],
|
||||
['Model:', model_name],
|
||||
]
|
||||
self._log_rows(attack_detail_rows, 'Attack Details', 'attack_details')
|
||||
self.log_summary_rows(attack_detail_rows, 'Attack Details', 'attack_details')
|
||||
|
||||
def log_summary(self):
|
||||
total_attacks = len(self.results)
|
||||
@@ -123,9 +131,10 @@ class AttackLogger:
|
||||
avg_num_queries = num_queries.mean()
|
||||
avg_num_queries = str(round(avg_num_queries, 2))
|
||||
summary_table_rows.append(['Avg num queries:', avg_num_queries])
|
||||
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
|
||||
self.log_summary_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
|
||||
# Show histogram of words changed.
|
||||
numbins = max(self.max_words_changed, 10)
|
||||
for logger in self.loggers:
|
||||
logger.log_hist(num_words_changed_until_success[:numbins],
|
||||
numbins=numbins, title='Num Words Perturbed', window_id='num_words_perturbed')
|
||||
|
||||
@@ -25,7 +25,7 @@ class FileLogger(Logger):
|
||||
self.fout.write(result.__str__(color_method=color_method))
|
||||
self.fout.write('\n')
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
if self.stdout:
|
||||
table_rows = [[title, '']] + rows
|
||||
table = terminaltables.SingleTable(table_rows)
|
||||
|
||||
@@ -5,7 +5,7 @@ class Logger:
|
||||
def log_attack_result(self, result, examples_completed):
|
||||
pass
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
pass
|
||||
|
||||
def log_hist(self, arr, numbins, title, window_id):
|
||||
|
||||
@@ -1,19 +1,9 @@
|
||||
import socket
|
||||
from visdom import Visdom
|
||||
|
||||
from textattack.shared.utils import html_table_from_rows
|
||||
from .logger import Logger
|
||||
|
||||
def style_from_dict(style_dict):
|
||||
""" Turns
|
||||
{ 'color': 'red', 'height': '100px'}
|
||||
into
|
||||
style: "color: red; height: 100px"
|
||||
"""
|
||||
style_str = ''
|
||||
for key in style_dict:
|
||||
style_str += key + ': ' + style_dict[key] + ';'
|
||||
return 'style="{}"'.format(style_str)
|
||||
|
||||
def port_is_open(port_num, hostname='127.0.0.1'):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
result = sock.connect_ex((hostname, port_num))
|
||||
@@ -34,7 +24,7 @@ class VisdomLogger(Logger):
|
||||
result_str = result.goal_function_result_str(color_method='html')
|
||||
self.sample_rows.append([result_str,text_a,text_b])
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
self.table(rows, title=title, window_id=window_id)
|
||||
|
||||
def flush(self):
|
||||
@@ -60,35 +50,7 @@ class VisdomLogger(Logger):
|
||||
|
||||
if not window_id: window_id = title # Can provide either of these,
|
||||
if not title: title = window_id # or both.
|
||||
|
||||
# Stylize the container div.
|
||||
if style:
|
||||
table_html = '<div {}>'.format(style_from_dict(style))
|
||||
else:
|
||||
table_html = '<div>'
|
||||
# Print the title string.
|
||||
if title:
|
||||
table_html += '<h1>{}</h1>'.format(title)
|
||||
|
||||
# Construct each row as HTML.
|
||||
table_html = '<table class="table">'
|
||||
if header:
|
||||
table_html += '<tr>'
|
||||
for element in header:
|
||||
table_html += '<th>'
|
||||
table_html += str(element)
|
||||
table_html += '</th>'
|
||||
table_html += '</tr>'
|
||||
for row in rows:
|
||||
table_html += '<tr>'
|
||||
for element in row:
|
||||
table_html += '<td>'
|
||||
table_html += str(element)
|
||||
table_html += '</td>'
|
||||
table_html += '</tr>'
|
||||
|
||||
# Close the table and print to screen.
|
||||
table_html += '</table></div>'
|
||||
table = html_table_from_rows(rows, title=title, header=header, style_dict=style)
|
||||
self.text(table_html, title=title, window_id=window_id)
|
||||
|
||||
def bar(self, X_data, numbins=10, title=None, window_id=None):
|
||||
|
||||
41
textattack/loggers/weights_and_biases_logger.py
Normal file
41
textattack/loggers/weights_and_biases_logger.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import wandb
|
||||
|
||||
from textattack.shared.utils import html_table_from_rows
|
||||
from .logger import Logger
|
||||
|
||||
class WeightsAndBiasesLogger(Logger):
|
||||
def __init__(self, filename='', stdout=False):
|
||||
wandb.init(project='textattack')
|
||||
self._result_table_rows = []
|
||||
|
||||
def _log_result_table(self):
|
||||
""" Weights & Biases doesn't have a feature to automatically
|
||||
aggregate results across timesteps and display the full table.
|
||||
Therefore, we have to do it manually.
|
||||
"""
|
||||
result_table = html_table_from_rows(self._result_table_rows)
|
||||
wandb.log({ 'results': wandb.Html(result_table, inject=False) })
|
||||
|
||||
def log_attack_result(self, result):
|
||||
original_text_colored, perturbed_text_colored = result.diff_color(color_method='html')
|
||||
result_num = len(self._result_table_rows)
|
||||
self._result_table_rows.append([f'<b>Result {result_num}</b>', original_text_colored, perturbed_text_colored])
|
||||
result_diff_table = html_table_from_rows([[original_text_colored, perturbed_text_colored]])
|
||||
result_diff_table = wandb.Html(result_diff_table)
|
||||
wandb.log({
|
||||
'resultType': type(result).__name__,
|
||||
'result': result_diff_table,
|
||||
'original_output': result.original_result.output,
|
||||
'perturbed_output': result.perturbed_result.output,
|
||||
})
|
||||
self._log_result_table()
|
||||
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
print('w&b skipping summary')
|
||||
pass
|
||||
# @TODO: should we log some summary to W&B? It seems to automatically
|
||||
# calculate its own summary statistics.
|
||||
|
||||
def log_sep(self):
|
||||
self.fout.write('-' * 90 + '\n')
|
||||
|
||||
@@ -134,22 +134,25 @@ def get_args():
|
||||
default=[], choices=CONSTRAINT_CLASS_NAMES.keys(),
|
||||
help=('Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}"'))
|
||||
|
||||
parser.add_argument('--out_dir', type=str, required=False, default=None,
|
||||
parser.add_argument('--out-dir', type=str, required=False, default=None,
|
||||
help='A directory to output results to.')
|
||||
|
||||
parser.add_argument('--enable_visdom', action='store_true',
|
||||
parser.add_argument('--enable-visdom', action='store_true',
|
||||
help='Enable logging to visdom.')
|
||||
|
||||
parser.add_argument('--disable_stdout', action='store_true',
|
||||
parser.add_argument('--enable-wandb', action='store_true',
|
||||
help='Enable logging to Weights & Biases.')
|
||||
|
||||
parser.add_argument('--disable-stdout', action='store_true',
|
||||
help='Disable logging to stdout')
|
||||
|
||||
parser.add_argument('--enable_csv', nargs='?', default=None, const='fancy', type=str,
|
||||
parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
|
||||
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
|
||||
|
||||
parser.add_argument('--num_examples', '-n', type=int, required=False,
|
||||
parser.add_argument('--num-examples', '-n', type=int, required=False,
|
||||
default='5', help='The number of examples to process.')
|
||||
|
||||
parser.add_argument('--num_examples_offset', '-o', type=int, required=False,
|
||||
parser.add_argument('--num-examples-offset', '-o', type=int, required=False,
|
||||
default=0, help='The offset to start at in the dataset.')
|
||||
|
||||
parser.add_argument('--shuffle', action='store_true', required=False,
|
||||
@@ -158,7 +161,7 @@ def get_args():
|
||||
parser.add_argument('--interactive', action='store_true', default=False,
|
||||
help='Whether to run attacks interactively.')
|
||||
|
||||
parser.add_argument('--attack_n', action='store_true', default=False,
|
||||
parser.add_argument('--attack-n', action='store_true', default=False,
|
||||
help='Whether to run attack until `n` examples have been attacked (not skipped).')
|
||||
|
||||
parser.add_argument('--parallel', action='store_true', default=False,
|
||||
@@ -277,32 +280,35 @@ def parse_goal_function_and_attack_from_args(args):
|
||||
return goal_function, attack
|
||||
|
||||
def parse_logger_from_args(args):# Create logger
|
||||
attack_logger = textattack.loggers.AttackLogger()
|
||||
attack_log_manager = textattack.loggers.AttackLogManager()
|
||||
# Set default output directory to `textattack/outputs`.
|
||||
if not args.out_dir:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, 'outputs')
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
|
||||
args.out_dir = outputs_dir
|
||||
|
||||
# Output file.
|
||||
out_time = int(time.time()*1000) # Output file
|
||||
outfile_name = 'attack-{}.txt'.format(out_time)
|
||||
attack_logger.add_output_file(os.path.join(args.out_dir, outfile_name))
|
||||
attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))
|
||||
|
||||
# CSV
|
||||
if args.enable_csv:
|
||||
outfile_name = 'attack-{}.csv'.format(out_time)
|
||||
color_method = None if args.enable_csv == 'plain' else 'file'
|
||||
csv_path = os.path.join(args.out_dir, outfile_name)
|
||||
attack_logger.add_output_csv(csv_path, color_method)
|
||||
attack_log_manager.add_output_csv(csv_path, color_method)
|
||||
print('Logging to CSV at path {}.'.format(csv_path))
|
||||
|
||||
|
||||
# Visdom
|
||||
if args.enable_visdom:
|
||||
attack_logger.enable_visdom()
|
||||
attack_log_manager.enable_visdom()
|
||||
|
||||
# Weights & Biases
|
||||
if args.enable_wandb:
|
||||
attack_log_manager.enable_wandb()
|
||||
|
||||
# Stdout
|
||||
if not args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
return attack_logger
|
||||
attack_log_manager.enable_stdout()
|
||||
return attack_log_manager
|
||||
|
||||
@@ -47,7 +47,7 @@ def run(args):
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
attack_logger = parse_logger_from_args(args)
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
# We reserve the first GPU for coordinating workers.
|
||||
num_gpus = torch.cuda.device_count()
|
||||
@@ -80,7 +80,7 @@ def run(args):
|
||||
result = out_queue.get(block=True)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
attack_logger.log_result(result)
|
||||
attack_log_manager.log_result(result)
|
||||
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
@@ -96,9 +96,9 @@ def run(args):
|
||||
print()
|
||||
# Enable summary stdout.
|
||||
if args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
attack_logger.log_summary()
|
||||
attack_logger.flush()
|
||||
attack_log_manager.enable_stdout()
|
||||
attack_log_manager.log_summary()
|
||||
attack_log_manager.flush()
|
||||
print()
|
||||
finish_time = time.time()
|
||||
print(f'Attack time: {time.time() - load_time}s')
|
||||
|
||||
@@ -27,7 +27,7 @@ def run(args):
|
||||
print(attack, '\n')
|
||||
|
||||
# Logger
|
||||
attack_logger = parse_logger_from_args(args)
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
load_time = time.time()
|
||||
print(f'Load time: {load_time - start_time}s')
|
||||
@@ -69,7 +69,7 @@ def run(args):
|
||||
num_examples=args.num_examples,
|
||||
shuffle=args.shuffle,
|
||||
attack_n=args.attack_n):
|
||||
attack_logger.log_result(result)
|
||||
attack_log_manager.log_result(result)
|
||||
if not args.disable_stdout:
|
||||
print('\n')
|
||||
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
@@ -84,9 +84,9 @@ def run(args):
|
||||
print()
|
||||
# Enable summary stdout
|
||||
if args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
attack_logger.log_summary()
|
||||
attack_logger.flush()
|
||||
attack_log_manager.enable_stdout()
|
||||
attack_log_manager.log_summary()
|
||||
attack_log_manager.flush()
|
||||
print()
|
||||
finish_time = time.time()
|
||||
print(f'Attack time: {time.time() - load_time}s')
|
||||
|
||||
@@ -171,3 +171,46 @@ class ANSI_ESCAPE_CODES:
|
||||
UNDERLINE = '\033[4m'
|
||||
""" This color stops the current color sequence. """
|
||||
STOP = '\033[0m'
|
||||
|
||||
def html_style_from_dict(style_dict):
|
||||
""" Turns
|
||||
{ 'color': 'red', 'height': '100px'}
|
||||
into
|
||||
style: "color: red; height: 100px"
|
||||
"""
|
||||
style_str = ''
|
||||
for key in style_dict:
|
||||
style_str += key + ': ' + style_dict[key] + ';'
|
||||
return 'style="{}"'.format(style_str)
|
||||
|
||||
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
|
||||
# Stylize the container div.
|
||||
if style_dict:
|
||||
table_html = '<div {}>'.format(style_from_dict(style_dict))
|
||||
else:
|
||||
table_html = '<div>'
|
||||
# Print the title string.
|
||||
if title:
|
||||
table_html += '<h1>{}</h1>'.format(title)
|
||||
|
||||
# Construct each row as HTML.
|
||||
table_html = '<table class="table">'
|
||||
if header:
|
||||
table_html += '<tr>'
|
||||
for element in header:
|
||||
table_html += '<th>'
|
||||
table_html += str(element)
|
||||
table_html += '</th>'
|
||||
table_html += '</tr>'
|
||||
for row in rows:
|
||||
table_html += '<tr>'
|
||||
for element in row:
|
||||
table_html += '<td>'
|
||||
table_html += str(element)
|
||||
table_html += '</td>'
|
||||
table_html += '</tr>'
|
||||
|
||||
# Close the table and print to screen.
|
||||
table_html += '</table></div>'
|
||||
|
||||
return table_html
|
||||
Reference in New Issue
Block a user