1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

AttackLogger --> AttackLogManager; add W&B logging

This commit is contained in:
Jack Morris
2020-05-04 14:21:05 -04:00
parent c82aad99b4
commit ac797c8840
12 changed files with 90 additions and 30 deletions

3
.gitignore vendored
View File

@@ -31,3 +31,6 @@ tensorflow-hub
# build outputs for PyPI
build/
dist/
# Weights & Biases outputs
wandb/

View File

@@ -16,3 +16,4 @@ tensorflow_hub
terminaltables
tqdm
visdom
wandb

View File

@@ -18,6 +18,16 @@ class GoalFunctionResult:
if isinstance(self.score, torch.Tensor):
self.score = self.score.item()
def statistics(self):
""" A dictionary of statistics about this result.
Used by the `AttackLogManager` to print aggregate attack statistics.
"""
return { # @ TODO implement
'num_words_changed', None,
'Words changed %', None
}
def get_text_color_input(self):
""" A string representing the color this result's changed
portion should be if it represents the original input.

View File

@@ -2,7 +2,7 @@ from .csv_logger import CSVLogger
from .file_logger import FileLogger
from .logger import Logger
from .visdom_logger import VisdomLogger
from .weights_and_biases_logger import WeightsAndBiasesLogger
# The AttackLogger must be imported last,
# since it imports the other loggers.
from .attack_logger import AttackLogger
# AttackLogManager must be imported last, since it imports the other loggers.
from .attack_log_manager import AttackLogManager

View File

@@ -3,9 +3,9 @@ import torch
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from . import CSVLogger, FileLogger, VisdomLogger
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
class AttackLogger:
class AttackLogManager:
def __init__(self):
""" Logs the results of an attack to all attached loggers
"""
@@ -20,6 +20,9 @@ class AttackLogger:
def enable_visdom(self):
self.loggers.append(VisdomLogger())
def enable_wandb(self):
self.loggers.append(WeightsAndBiasesLogger())
def add_output_file(self, filename):
self.loggers.append(FileLogger(filename=filename))
@@ -27,18 +30,22 @@ class AttackLogger:
self.loggers.append(CSVLogger(filename=filename, color_method=color_method))
def log_result(self, result):
""" Logs an `AttackResult` on each of `self.loggers`. """
self.results.append(result)
for logger in self.loggers:
logger.log_attack_result(result)
def log_results(self, results):
""" Logs an iterable of `AttackResult` objects on each of
`self.loggers`.
"""
for result in results:
self.log_result(result)
self.log_summary()
def _log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
for logger in self.loggers:
logger.log_rows(rows, title, window_id)
logger.log_summary_rows(rows, title, window_id)
def log_sep(self):
for logger in self.loggers:
@@ -49,11 +56,12 @@ class AttackLogger:
logger.flush()
def log_attack_details(self, attack_name, model_name):
# @TODO log a more complete set of attack details
attack_detail_rows = [
['Attack algorithm:', attack_name],
['Model:', model_name],
]
self._log_rows(attack_detail_rows, 'Attack Details', 'attack_details')
self.log_summary_rows(attack_detail_rows, 'Attack Details', 'attack_details')
def log_summary(self):
total_attacks = len(self.results)
@@ -112,6 +120,8 @@ class AttackLogger:
['Number of successful attacks:', str(successful_attacks)],
['Number of failed attacks:', str(failed_attacks)],
['Number of skipped attacks:', str(skipped_attacks)],
### @TODO: everything below this should be computed in a customizable
### way via overriding `GoalFunctionResult.statistics`.
['Original accuracy:', original_accuracy],
['Accuracy under attack:', accuracy_under_attack],
['Attack success rate:', attack_success_rate],
@@ -123,9 +133,10 @@ class AttackLogger:
avg_num_queries = num_queries.mean()
avg_num_queries = str(round(avg_num_queries, 2))
summary_table_rows.append(['Avg num queries:', avg_num_queries])
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
self.log_summary_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
# Show histogram of words changed.
numbins = max(self.max_words_changed, 10)
for logger in self.loggers:
logger.log_hist(num_words_changed_until_success[:numbins],
numbins=numbins, title='Num Words Perturbed', window_id='num_words_perturbed')

View File

@@ -25,7 +25,7 @@ class FileLogger(Logger):
self.fout.write(result.__str__(color_method=color_method))
self.fout.write('\n')
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
if self.stdout:
table_rows = [[title, '']] + rows
table = terminaltables.SingleTable(table_rows)

View File

@@ -5,7 +5,7 @@ class Logger:
def log_attack_result(self, result, examples_completed):
pass
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
pass
def log_hist(self, arr, numbins, title, window_id):

View File

@@ -34,7 +34,7 @@ class VisdomLogger(Logger):
result_str = result.goal_function_result_str(color_method='html')
self.sample_rows.append([result_str,text_a,text_b])
def log_rows(self, rows, title, window_id):
def log_summary_rows(self, rows, title, window_id):
self.table(rows, title=title, window_id=window_id)
def flush(self):

View File

@@ -0,0 +1,29 @@
import wandb
from .logger import Logger
class WeightsAndBiasesLogger(Logger):
def __init__(self, filename='', stdout=False):
wandb.init()
def log_attack_result(self, result):
original_text_colored, perturbed_text_colored = result.diff_color(color_method='html')
original_text_colored = wandb.Html(original_text_colored)
perturbed_text_colored = wandb.Html(perturbed_text_colored)
wandb.log({
'type': type(result).__name__,
'original_text': original_text_colored,
'perturbed_text': perturbed_text_colored,
'original_output': result.original_result.output,
'perturbed_output': result.perturbed_result.output,
})
def log_summary_rows(self, rows, title, window_id):
print('w&b skipping summary')
pass
# @TODO: should we log some summary to W&B? It seems to automatically
# calculate its own summary statistics.
def log_sep(self):
self.fout.write('-' * 90 + '\n')

View File

@@ -139,6 +139,9 @@ def get_args():
parser.add_argument('--enable_visdom', action='store_true',
help='Enable logging to visdom.')
parser.add_argument('--enable_wandb', action='store_true',
help='Enable logging to Weights & Biases.')
parser.add_argument('--disable_stdout', action='store_true',
help='Disable logging to stdout')
@@ -276,32 +279,35 @@ def parse_goal_function_and_attack_from_args(args):
return goal_function, attack
def parse_logger_from_args(args):# Create logger
attack_logger = textattack.loggers.AttackLogger()
attack_log_manager = textattack.loggers.AttackLogManager()
# Set default output directory to `textattack/outputs`.
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(current_dir, os.pardir, 'outputs')
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
args.out_dir = outputs_dir
# Output file.
out_time = int(time.time()*1000) # Output file
outfile_name = 'attack-{}.txt'.format(out_time)
attack_logger.add_output_file(os.path.join(args.out_dir, outfile_name))
attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))
# CSV
if args.enable_csv:
outfile_name = 'attack-{}.csv'.format(out_time)
color_method = None if args.enable_csv == 'plain' else 'file'
csv_path = os.path.join(args.out_dir, outfile_name)
attack_logger.add_output_csv(csv_path, color_method)
attack_log_manager.add_output_csv(csv_path, color_method)
print('Logging to CSV at path {}.'.format(csv_path))
# Visdom
if args.enable_visdom:
attack_logger.enable_visdom()
attack_log_manager.enable_visdom()
# Weights & Biases
if args.enable_wandb:
attack_log_manager.enable_wandb()
# Stdout
if not args.disable_stdout:
attack_logger.enable_stdout()
return attack_logger
attack_log_manager.enable_stdout()
return attack_log_manager

View File

@@ -47,7 +47,7 @@ def run(args):
)
start_time = time.time()
attack_logger = parse_logger_from_args(args)
attack_log_manager = parse_logger_from_args(args)
# We reserve the first GPU for coordinating workers.
num_gpus = torch.cuda.device_count()
@@ -80,7 +80,7 @@ def run(args):
result = out_queue.get(block=True)
if isinstance(result, Exception):
raise result
attack_logger.log_result(result)
attack_log_manager.log_result(result)
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
pbar.update()
num_results += 1
@@ -96,9 +96,9 @@ def run(args):
print()
# Enable summary stdout.
if args.disable_stdout:
attack_logger.enable_stdout()
attack_logger.log_summary()
attack_logger.flush()
attack_log_manager.enable_stdout()
attack_log_manager.log_summary()
attack_log_manager.flush()
print()
finish_time = time.time()
print(f'Attack time: {time.time() - load_time}s')

View File

@@ -27,7 +27,7 @@ def run(args):
print(attack, '\n')
# Logger
attack_logger = parse_logger_from_args(args)
attack_log_manager = parse_logger_from_args(args)
load_time = time.time()
print(f'Load time: {load_time - start_time}s')
@@ -69,7 +69,7 @@ def run(args):
num_examples=args.num_examples,
shuffle=args.shuffle,
attack_n=args.attack_n):
attack_logger.log_result(result)
attack_log_manager.log_result(result)
if not args.disable_stdout:
print('\n')
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
@@ -84,9 +84,9 @@ def run(args):
print()
# Enable summary stdout
if args.disable_stdout:
attack_logger.enable_stdout()
attack_logger.log_summary()
attack_logger.flush()
attack_log_manager.enable_stdout()
attack_log_manager.log_summary()
attack_log_manager.flush()
print()
finish_time = time.time()
print(f'Attack time: {time.time() - load_time}s')