1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

improve logging

This commit is contained in:
Jack Morris
2020-02-07 12:38:29 -05:00
parent 7677330113
commit a35073a6a5
4 changed files with 15 additions and 6 deletions

View File

@@ -8,7 +8,8 @@ sentence_transformers
spacy
torch
transformers>=2.0.0
tensorflow-gpu==2
tensorflow-gpu>=2
tensorflow_hub
terminaltables
tqdm
visdom

View File

@@ -71,7 +71,8 @@ def main():
if not args.disable_stdout:
print('\n')
else:
pbar.update(1)
if not isinstance(result, textattack.attack_results.SkippedAttackResult):
pbar.update(1)
pbar.close()
print()
# Enable summary stdout

View File

@@ -111,6 +111,10 @@ class AttackLogger:
attack_success_rate = str(round(attack_success_rate, 2)) + '%'
average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + '%'
all_num_words = np.array([len(result.original_text.words) for result in self.results])
average_num_words = all_num_words.mean()
average_num_words = str(round(average_num_words, 2))
summary_table_rows = [
['Number of successful attacks:', str(self.successful_attacks)],
['Number of failed attacks:', str(self.failed_attacks)],
@@ -119,11 +123,12 @@ class AttackLogger:
['Accuracy under attack:', accuracy_under_attack],
['Attack success rate:', attack_success_rate],
['Average perturbed word %:', average_perc_words_perturbed],
['Average num. words', average_num_words],
]
num_queries = [r.num_queries for r in self.results]
avg_num_queries = statistics.mean(num_queries) if len(num_queries) else 0
avg_num_queries = str(round(avg_num_queries, 2))
summary_table_rows.append(['Avg num queries:', avg_num_queries])
self._log_rows(summary_table_rows, 'Attack Results Summary', 'attack_results_summary')
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
self._log_num_words_changed()

View File

@@ -1,5 +1,6 @@
import sys
import os
import sys
import terminaltables
from .logger import Logger
@@ -23,8 +24,9 @@ class FileLogger(Logger):
self.fout.write('\n')
def log_rows(self, rows, title, window_id):
for row in rows:
self.fout.write(f'{row[0]} {row[1]}\n')
table_rows = [[title, '']] + rows
table = terminaltables.SingleTable(table_rows)
self.fout.write(table.table)
def log_sep(self):
self.fout.write('-' * 90 + '\n')