1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/loggers/attack_log_manager.py
2020-10-26 08:18:55 -04:00

170 lines
6.1 KiB
Python

"""
Managing Attack Logs.
========================
"""
import numpy as np
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
class AttackLogManager:
    """Logs the results of an attack to all attached loggers.

    Results passed to :meth:`log_result` are forwarded to every attached
    logger and also kept in ``self.results`` so that :meth:`log_summary`
    can compute aggregate statistics over them.
    """

    def __init__(self):
        # Loggers that every call is fanned out to.
        self.loggers = []
        # Every result seen so far; consumed by ``log_summary``.
        self.results = []

    def enable_stdout(self):
        """Attach a ``FileLogger`` constructed with ``stdout=True``."""
        self.loggers.append(FileLogger(stdout=True))

    def enable_visdom(self):
        """Attach a ``VisdomLogger``."""
        self.loggers.append(VisdomLogger())

    def enable_wandb(self):
        """Attach a ``WeightsAndBiasesLogger``."""
        self.loggers.append(WeightsAndBiasesLogger())

    def add_output_file(self, filename):
        """Attach a ``FileLogger`` constructed with ``filename``."""
        self.loggers.append(FileLogger(filename=filename))

    def add_output_csv(self, filename, color_method):
        """Attach a ``CSVLogger`` built from ``filename`` and ``color_method``."""
        self.loggers.append(CSVLogger(filename=filename, color_method=color_method))

    def log_result(self, result):
        """Logs an ``AttackResult`` on each of `self.loggers`."""
        self.results.append(result)
        for logger in self.loggers:
            logger.log_attack_result(result)

    def log_results(self, results):
        """Logs an iterable of ``AttackResult`` objects on each of
        `self.loggers`, then logs a summary of all stored results."""
        for result in results:
            self.log_result(result)
        self.log_summary()

    def log_summary_rows(self, rows, title, window_id):
        """Forward a table of summary ``rows`` to every attached logger."""
        for logger in self.loggers:
            logger.log_summary_rows(rows, title, window_id)

    def log_sep(self):
        """Ask every attached logger to emit a separator."""
        for logger in self.loggers:
            logger.log_sep()

    def flush(self):
        """Flush every attached logger."""
        for logger in self.loggers:
            logger.flush()

    def log_attack_details(self, attack_name, model_name):
        """Log a two-row table naming the attack algorithm and the model."""
        # @TODO log a more complete set of attack details
        attack_detail_rows = [
            ["Attack algorithm:", attack_name],
            ["Model:", model_name],
        ]
        self.log_summary_rows(attack_detail_rows, "Attack Details", "attack_details")

    @staticmethod
    def _mean_or_zero(values):
        """Return the mean of ``values``, or ``0.0`` when ``values`` is empty.

        ``numpy`` returns ``nan`` (with a ``RuntimeWarning``) for the mean of
        an empty array; reporting ``0.0`` keeps the summary table readable.
        """
        values = np.asarray(values)
        if values.size == 0:
            return 0.0
        return values.mean()

    @staticmethod
    def _words_changed_histogram(words_changed_counts, numbins):
        """Count how many attacks changed 1, 2, ..., ``numbins`` words.

        Bin ``k - 1`` holds the number of entries equal to ``k``. Entries
        outside ``1..numbins`` (notably ``0``) are ignored rather than being
        allowed to wrap around to the end of the array via a negative index.
        """
        histogram = np.zeros(numbins)
        for count in words_changed_counts:
            if 0 < count <= numbins:
                histogram[count - 1] += 1
        return histogram

    def log_summary(self):
        """Compute aggregate statistics over ``self.results`` and log them.

        Emits an "Attack Results" table through :meth:`log_summary_rows` and
        a "Num Words Perturbed" histogram through each logger's ``log_hist``.
        Does nothing when no results have been logged yet.
        """
        total_attacks = len(self.results)
        if total_attacks == 0:
            return

        # Count things about attacks.
        all_num_words = np.zeros(total_attacks)
        perturbed_word_percentages = np.zeros(total_attacks)
        # Number of changed words for each successful attack.
        words_changed_counts = []
        failed_attacks = 0
        skipped_attacks = 0
        successful_attacks = 0
        max_words_changed = 0

        for i, result in enumerate(self.results):
            num_words = len(result.original_result.attacked_text.words)
            all_num_words[i] = num_words
            if isinstance(result, FailedAttackResult):
                failed_attacks += 1
                continue
            if isinstance(result, SkippedAttackResult):
                skipped_attacks += 1
                continue

            successful_attacks += 1
            num_words_changed = len(
                result.original_result.attacked_text.all_words_diff(
                    result.perturbed_result.attacked_text
                )
            )
            words_changed_counts.append(num_words_changed)
            max_words_changed = max(max_words_changed, num_words_changed)
            # Entries stay 0 for empty inputs; zeros are filtered out below.
            if num_words > 0:
                perturbed_word_percentages[i] = num_words_changed * 100.0 / num_words

        # Original classifier success rate on these samples.
        original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
        original_accuracy = str(round(original_accuracy, 2)) + "%"

        # New classifier success rate on these samples.
        accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
        accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"

        # Attack success rate.
        if successful_attacks + failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                successful_attacks * 100.0 / (successful_attacks + failed_attacks)
            )
        attack_success_rate = str(round(attack_success_rate, 2)) + "%"

        # Average only over results that actually perturbed something; this
        # selection is empty when no attack succeeded.
        perturbed_word_percentages = perturbed_word_percentages[
            perturbed_word_percentages > 0
        ]
        average_perc_words_perturbed = self._mean_or_zero(perturbed_word_percentages)
        average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"

        average_num_words = all_num_words.mean()
        average_num_words = str(round(average_num_words, 2))

        summary_table_rows = [
            ["Number of successful attacks:", str(successful_attacks)],
            ["Number of failed attacks:", str(failed_attacks)],
            ["Number of skipped attacks:", str(skipped_attacks)],
            ["Original accuracy:", original_accuracy],
            ["Accuracy under attack:", accuracy_under_attack],
            ["Attack success rate:", attack_success_rate],
            ["Average perturbed word %:", average_perc_words_perturbed],
            ["Average num. words per input:", average_num_words],
        ]

        # Skipped results are excluded; the array is empty if all were skipped.
        num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        avg_num_queries = self._mean_or_zero(num_queries)
        avg_num_queries = str(round(avg_num_queries, 2))
        summary_table_rows.append(["Avg num queries:", avg_num_queries])

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )

        # Show histogram of words changed (always at least 10 bins wide).
        numbins = max(max_words_changed, 10)
        num_words_changed_until_success = self._words_changed_histogram(
            words_changed_counts, numbins
        )
        for logger in self.loggers:
            logger.log_hist(
                num_words_changed_until_success,
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
            )