mirror of https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger

class AttackLogManager:
    """Logs the results of an attack to all attached loggers."""

    def __init__(self):
        self.loggers = []
        self.results = []
        # Largest number of words changed by any successful attack so far;
        # used to size the histogram bins in `log_summary`.
        self.max_words_changed = 0
        self.max_seq_len = 2**16

    def enable_stdout(self):
        self.loggers.append(FileLogger(stdout=True))

    def enable_visdom(self):
        self.loggers.append(VisdomLogger())

    def enable_wandb(self):
        self.loggers.append(WeightsAndBiasesLogger())

    def add_output_file(self, filename):
        self.loggers.append(FileLogger(filename=filename))

    def add_output_csv(self, filename, color_method):
        self.loggers.append(CSVLogger(filename=filename, color_method=color_method))
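
    # Configuration sketch (illustrative, not part of the original module):
    # attach a stdout logger plus a CSV output file. The filename and the
    # `color_method` value below are assumptions for illustration only.
    #
    #     manager = AttackLogManager()
    #     manager.enable_stdout()
    #     manager.add_output_csv('results.csv', color_method='file')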

    def log_result(self, result):
        """Logs an `AttackResult` on each of `self.loggers`."""
        self.results.append(result)
        for logger in self.loggers:
            logger.log_attack_result(result)

    def log_results(self, results):
        """Logs an iterable of `AttackResult` objects on each of
        `self.loggers`, then logs a summary of the whole batch.
        """
        for result in results:
            self.log_result(result)
        self.log_summary()

    def log_summary_rows(self, rows, title, window_id):
        for logger in self.loggers:
            logger.log_summary_rows(rows, title, window_id)

    def log_sep(self):
        for logger in self.loggers:
            logger.log_sep()

    def flush(self):
        for logger in self.loggers:
            logger.flush()
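
    # `AttackLogManager` only relies on the duck-typed logger interface used
    # above and in `log_summary`: `log_attack_result`, `log_summary_rows`,
    # `log_sep`, `flush`, and `log_hist`. A minimal in-memory logger sketch,
    # assuming just that interface (this class is hypothetical, not part of
    # TextAttack):
    #
    #     class ListLogger:
    #         def __init__(self):
    #             self.entries = []
    #         def log_attack_result(self, result):
    #             self.entries.append(result)
    #         def log_summary_rows(self, rows, title, window_id):
    #             self.entries.extend(rows)
    #         def log_hist(self, arr, numbins, title, window_id):
    #             self.entries.append((title, list(arr)))
    #         def log_sep(self):
    #             pass
    #         def flush(self):
    #             pass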

    def log_attack_details(self, attack_name, model_name):
        # @TODO log a more complete set of attack details
        attack_detail_rows = [
            ['Attack algorithm:', attack_name],
            ['Model:', model_name],
        ]
        self.log_summary_rows(attack_detail_rows, 'Attack Details', 'attack_details')

    def log_summary(self):
        total_attacks = len(self.results)
        if total_attacks == 0:
            return
        # Count things about attacks.
        all_num_words = np.zeros(len(self.results))
        perturbed_word_percentages = np.zeros(len(self.results))
        num_words_changed_until_success = np.zeros(self.max_seq_len)
        failed_attacks = 0
        skipped_attacks = 0
        successful_attacks = 0
        for i, result in enumerate(self.results):
            all_num_words[i] = len(result.original_result.tokenized_text.words)
            if isinstance(result, FailedAttackResult):
                failed_attacks += 1
                continue
            elif isinstance(result, SkippedAttackResult):
                skipped_attacks += 1
                continue
            else:
                successful_attacks += 1
            num_words_changed = len(
                result.original_result.tokenized_text.all_words_diff(
                    result.perturbed_result.tokenized_text
                )
            )
            num_words_changed_until_success[num_words_changed - 1] += 1
            # Track the widest perturbation seen so far; read back below when
            # sizing the histogram bins.
            self.max_words_changed = max(self.max_words_changed, num_words_changed)
            if len(result.original_result.tokenized_text.words) > 0:
                perturbed_word_percentage = (
                    num_words_changed
                    * 100.0
                    / len(result.original_result.tokenized_text.words)
                )
            else:
                perturbed_word_percentage = 0
            perturbed_word_percentages[i] = perturbed_word_percentage
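
        # For example, if two successful attacks each changed 3 words, then
        # num_words_changed_until_success[2] == 2.0 at this point, and
        # self.max_words_changed == 3.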

        # Original classifier accuracy on these samples: attacks are only
        # skipped when the model already misclassifies the input.
        original_accuracy = (total_attacks - skipped_attacks) * 100.0 / total_attacks
        original_accuracy = str(round(original_accuracy, 2)) + '%'

        # Classifier accuracy under attack: only failed attacks leave the
        # model's prediction intact.
        accuracy_under_attack = failed_attacks * 100.0 / total_attacks
        accuracy_under_attack = str(round(accuracy_under_attack, 2)) + '%'

        # Attack success rate, measured over attempted (non-skipped) attacks.
        if successful_attacks + failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                successful_attacks * 100.0 / (successful_attacks + failed_attacks)
            )
        attack_success_rate = str(round(attack_success_rate, 2)) + '%'
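
        # Worked example with illustrative numbers: out of 100 results split
        # into 10 skipped, 30 failed, and 60 successful attacks:
        #   original_accuracy     = (100 - 10) * 100.0 / 100 -> '90.0%'
        #   accuracy_under_attack = 30 * 100.0 / 100         -> '30.0%'
        #   attack_success_rate   = 60 * 100.0 / (60 + 30)   -> '66.67%'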

        # Average percentage of words perturbed, over successful attacks only
        # (skipped and failed results left zeros in the array).
        perturbed_word_percentages = perturbed_word_percentages[
            perturbed_word_percentages > 0
        ]
        if len(perturbed_word_percentages) > 0:
            average_perc_words_perturbed = perturbed_word_percentages.mean()
        else:
            average_perc_words_perturbed = 0
        average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + '%'

        average_num_words = all_num_words.mean()
        average_num_words = str(round(average_num_words, 2))

        summary_table_rows = [
            ['Number of successful attacks:', str(successful_attacks)],
            ['Number of failed attacks:', str(failed_attacks)],
            ['Number of skipped attacks:', str(skipped_attacks)],
            ### @TODO: everything below this should be computed in a customizable
            ### way via overriding `GoalFunctionResult.statistics`.
            ['Original accuracy:', original_accuracy],
            ['Accuracy under attack:', accuracy_under_attack],
            ['Attack success rate:', attack_success_rate],
            ['Average perturbed word %:', average_perc_words_perturbed],
            ['Average num. words per input:', average_num_words],
        ]

        # Average number of model queries, over results where an attack was
        # actually attempted.
        num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        if len(num_queries) > 0:
            avg_num_queries = str(round(num_queries.mean(), 2))
            summary_table_rows.append(['Avg num queries:', avg_num_queries])
        self.log_summary_rows(
            summary_table_rows, 'Attack Results', 'attack_results_summary'
        )

        # Show a histogram of how many words successful attacks changed.
        numbins = max(self.max_words_changed, 10)
        for logger in self.loggers:
            logger.log_hist(
                num_words_changed_until_success[:numbins],
                numbins=numbins,
                title='Num Words Perturbed',
                window_id='num_words_perturbed',
            )
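
# End-to-end usage sketch (illustrative; `attack` and `dataset` stand in for
# objects constructed elsewhere with TextAttack and are assumptions here):
#
#     manager = AttackLogManager()
#     manager.enable_stdout()
#     manager.add_output_file('attack_log.txt')
#     manager.log_attack_details('greedy-word-swap', 'bert-base-uncased')
#     manager.log_results(attack.attack_dataset(dataset))
#     manager.flush()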