mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
AttackLogger --> AttackLogManager; add W&B logging
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -31,3 +31,6 @@ tensorflow-hub
|
||||
# build outputs for PyPI
|
||||
build/
|
||||
dist/
|
||||
|
||||
# Weights & Biases outputs
|
||||
wandb/
|
||||
|
||||
@@ -16,3 +16,4 @@ tensorflow_hub
|
||||
terminaltables
|
||||
tqdm
|
||||
visdom
|
||||
wandb
|
||||
@@ -18,6 +18,16 @@ class GoalFunctionResult:
|
||||
if isinstance(self.score, torch.Tensor):
|
||||
self.score = self.score.item()
|
||||
|
||||
def statistics(self):
|
||||
""" A dictionary of statistics about this result.
|
||||
|
||||
Used by the `AttackLogManager` to print aggregate attack statistics.
|
||||
"""
|
||||
return { # @ TODO implement
|
||||
'num_words_changed', None,
|
||||
'Words changed %', None
|
||||
}
|
||||
|
||||
def get_text_color_input(self):
|
||||
""" A string representing the color this result's changed
|
||||
portion should be if it represents the original input.
|
||||
|
||||
@@ -2,7 +2,7 @@ from .csv_logger import CSVLogger
|
||||
from .file_logger import FileLogger
|
||||
from .logger import Logger
|
||||
from .visdom_logger import VisdomLogger
|
||||
from .weights_and_biases_logger import WeightsAndBiasesLogger
|
||||
|
||||
# The AttackLogger must be imported last,
|
||||
# since it imports the other loggers.
|
||||
from .attack_logger import AttackLogger
|
||||
# AttackLogManager must be imported last, since it imports the other loggers.
|
||||
from .attack_log_manager import AttackLogManager
|
||||
@@ -3,9 +3,9 @@ import torch
|
||||
|
||||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
|
||||
|
||||
from . import CSVLogger, FileLogger, VisdomLogger
|
||||
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
|
||||
|
||||
class AttackLogger:
|
||||
class AttackLogManager:
|
||||
def __init__(self):
|
||||
""" Logs the results of an attack to all attached loggers
|
||||
"""
|
||||
@@ -20,6 +20,9 @@ class AttackLogger:
|
||||
def enable_visdom(self):
|
||||
self.loggers.append(VisdomLogger())
|
||||
|
||||
def enable_wandb(self):
|
||||
self.loggers.append(WeightsAndBiasesLogger())
|
||||
|
||||
def add_output_file(self, filename):
|
||||
self.loggers.append(FileLogger(filename=filename))
|
||||
|
||||
@@ -27,18 +30,22 @@ class AttackLogger:
|
||||
self.loggers.append(CSVLogger(filename=filename, color_method=color_method))
|
||||
|
||||
def log_result(self, result):
|
||||
""" Logs an `AttackResult` on each of `self.loggers`. """
|
||||
self.results.append(result)
|
||||
for logger in self.loggers:
|
||||
logger.log_attack_result(result)
|
||||
|
||||
def log_results(self, results):
|
||||
""" Logs an iterable of `AttackResult` objects on each of
|
||||
`self.loggers`.
|
||||
"""
|
||||
for result in results:
|
||||
self.log_result(result)
|
||||
self.log_summary()
|
||||
|
||||
def _log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
for logger in self.loggers:
|
||||
logger.log_rows(rows, title, window_id)
|
||||
logger.log_summary_rows(rows, title, window_id)
|
||||
|
||||
def log_sep(self):
|
||||
for logger in self.loggers:
|
||||
@@ -49,11 +56,12 @@ class AttackLogger:
|
||||
logger.flush()
|
||||
|
||||
def log_attack_details(self, attack_name, model_name):
|
||||
# @TODO log a more complete set of attack details
|
||||
attack_detail_rows = [
|
||||
['Attack algorithm:', attack_name],
|
||||
['Model:', model_name],
|
||||
]
|
||||
self._log_rows(attack_detail_rows, 'Attack Details', 'attack_details')
|
||||
self.log_summary_rows(attack_detail_rows, 'Attack Details', 'attack_details')
|
||||
|
||||
def log_summary(self):
|
||||
total_attacks = len(self.results)
|
||||
@@ -112,6 +120,8 @@ class AttackLogger:
|
||||
['Number of successful attacks:', str(successful_attacks)],
|
||||
['Number of failed attacks:', str(failed_attacks)],
|
||||
['Number of skipped attacks:', str(skipped_attacks)],
|
||||
### @TODO: everything below this should be computed in a customizable
|
||||
### way via overriding `GoalFunctionResult.statistics`.
|
||||
['Original accuracy:', original_accuracy],
|
||||
['Accuracy under attack:', accuracy_under_attack],
|
||||
['Attack success rate:', attack_success_rate],
|
||||
@@ -123,9 +133,10 @@ class AttackLogger:
|
||||
avg_num_queries = num_queries.mean()
|
||||
avg_num_queries = str(round(avg_num_queries, 2))
|
||||
summary_table_rows.append(['Avg num queries:', avg_num_queries])
|
||||
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
|
||||
self.log_summary_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
|
||||
# Show histogram of words changed.
|
||||
numbins = max(self.max_words_changed, 10)
|
||||
for logger in self.loggers:
|
||||
logger.log_hist(num_words_changed_until_success[:numbins],
|
||||
numbins=numbins, title='Num Words Perturbed', window_id='num_words_perturbed')
|
||||
|
||||
@@ -25,7 +25,7 @@ class FileLogger(Logger):
|
||||
self.fout.write(result.__str__(color_method=color_method))
|
||||
self.fout.write('\n')
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
if self.stdout:
|
||||
table_rows = [[title, '']] + rows
|
||||
table = terminaltables.SingleTable(table_rows)
|
||||
|
||||
@@ -5,7 +5,7 @@ class Logger:
|
||||
def log_attack_result(self, result, examples_completed):
|
||||
pass
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
pass
|
||||
|
||||
def log_hist(self, arr, numbins, title, window_id):
|
||||
|
||||
@@ -34,7 +34,7 @@ class VisdomLogger(Logger):
|
||||
result_str = result.goal_function_result_str(color_method='html')
|
||||
self.sample_rows.append([result_str,text_a,text_b])
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
self.table(rows, title=title, window_id=window_id)
|
||||
|
||||
def flush(self):
|
||||
|
||||
29
textattack/loggers/weights_and_biases_logger.py
Normal file
29
textattack/loggers/weights_and_biases_logger.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import wandb
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
class WeightsAndBiasesLogger(Logger):
|
||||
def __init__(self, filename='', stdout=False):
|
||||
wandb.init()
|
||||
|
||||
def log_attack_result(self, result):
|
||||
original_text_colored, perturbed_text_colored = result.diff_color(color_method='html')
|
||||
original_text_colored = wandb.Html(original_text_colored)
|
||||
perturbed_text_colored = wandb.Html(perturbed_text_colored)
|
||||
wandb.log({
|
||||
'type': type(result).__name__,
|
||||
'original_text': original_text_colored,
|
||||
'perturbed_text': perturbed_text_colored,
|
||||
'original_output': result.original_result.output,
|
||||
'perturbed_output': result.perturbed_result.output,
|
||||
})
|
||||
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
print('w&b skipping summary')
|
||||
pass
|
||||
# @TODO: should we log some summary to W&B? It seems to automatically
|
||||
# calculate its own summary statistics.
|
||||
|
||||
def log_sep(self):
|
||||
self.fout.write('-' * 90 + '\n')
|
||||
|
||||
@@ -139,6 +139,9 @@ def get_args():
|
||||
parser.add_argument('--enable_visdom', action='store_true',
|
||||
help='Enable logging to visdom.')
|
||||
|
||||
parser.add_argument('--enable_wandb', action='store_true',
|
||||
help='Enable logging to Weights & Biases.')
|
||||
|
||||
parser.add_argument('--disable_stdout', action='store_true',
|
||||
help='Disable logging to stdout')
|
||||
|
||||
@@ -276,32 +279,35 @@ def parse_goal_function_and_attack_from_args(args):
|
||||
return goal_function, attack
|
||||
|
||||
def parse_logger_from_args(args):# Create logger
|
||||
attack_logger = textattack.loggers.AttackLogger()
|
||||
attack_log_manager = textattack.loggers.AttackLogManager()
|
||||
# Set default output directory to `textattack/outputs`.
|
||||
if not args.out_dir:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, 'outputs')
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
|
||||
args.out_dir = outputs_dir
|
||||
|
||||
# Output file.
|
||||
out_time = int(time.time()*1000) # Output file
|
||||
outfile_name = 'attack-{}.txt'.format(out_time)
|
||||
attack_logger.add_output_file(os.path.join(args.out_dir, outfile_name))
|
||||
attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))
|
||||
|
||||
# CSV
|
||||
if args.enable_csv:
|
||||
outfile_name = 'attack-{}.csv'.format(out_time)
|
||||
color_method = None if args.enable_csv == 'plain' else 'file'
|
||||
csv_path = os.path.join(args.out_dir, outfile_name)
|
||||
attack_logger.add_output_csv(csv_path, color_method)
|
||||
attack_log_manager.add_output_csv(csv_path, color_method)
|
||||
print('Logging to CSV at path {}.'.format(csv_path))
|
||||
|
||||
|
||||
# Visdom
|
||||
if args.enable_visdom:
|
||||
attack_logger.enable_visdom()
|
||||
attack_log_manager.enable_visdom()
|
||||
|
||||
# Weights & Biases
|
||||
if args.enable_wandb:
|
||||
attack_log_manager.enable_wandb()
|
||||
|
||||
# Stdout
|
||||
if not args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
return attack_logger
|
||||
attack_log_manager.enable_stdout()
|
||||
return attack_log_manager
|
||||
|
||||
@@ -47,7 +47,7 @@ def run(args):
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
attack_logger = parse_logger_from_args(args)
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
# We reserve the first GPU for coordinating workers.
|
||||
num_gpus = torch.cuda.device_count()
|
||||
@@ -80,7 +80,7 @@ def run(args):
|
||||
result = out_queue.get(block=True)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
attack_logger.log_result(result)
|
||||
attack_log_manager.log_result(result)
|
||||
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
@@ -96,9 +96,9 @@ def run(args):
|
||||
print()
|
||||
# Enable summary stdout.
|
||||
if args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
attack_logger.log_summary()
|
||||
attack_logger.flush()
|
||||
attack_log_manager.enable_stdout()
|
||||
attack_log_manager.log_summary()
|
||||
attack_log_manager.flush()
|
||||
print()
|
||||
finish_time = time.time()
|
||||
print(f'Attack time: {time.time() - load_time}s')
|
||||
|
||||
@@ -27,7 +27,7 @@ def run(args):
|
||||
print(attack, '\n')
|
||||
|
||||
# Logger
|
||||
attack_logger = parse_logger_from_args(args)
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
load_time = time.time()
|
||||
print(f'Load time: {load_time - start_time}s')
|
||||
@@ -69,7 +69,7 @@ def run(args):
|
||||
num_examples=args.num_examples,
|
||||
shuffle=args.shuffle,
|
||||
attack_n=args.attack_n):
|
||||
attack_logger.log_result(result)
|
||||
attack_log_manager.log_result(result)
|
||||
if not args.disable_stdout:
|
||||
print('\n')
|
||||
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
@@ -84,9 +84,9 @@ def run(args):
|
||||
print()
|
||||
# Enable summary stdout
|
||||
if args.disable_stdout:
|
||||
attack_logger.enable_stdout()
|
||||
attack_logger.log_summary()
|
||||
attack_logger.flush()
|
||||
attack_log_manager.enable_stdout()
|
||||
attack_log_manager.log_summary()
|
||||
attack_log_manager.flush()
|
||||
print()
|
||||
finish_time = time.time()
|
||||
print(f'Attack time: {time.time() - load_time}s')
|
||||
|
||||
Reference in New Issue
Block a user