mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
improve logging
This commit is contained in:
@@ -8,7 +8,8 @@ sentence_transformers
|
||||
spacy
|
||||
torch
|
||||
transformers>=2.0.0
|
||||
tensorflow-gpu==2
|
||||
tensorflow-gpu>=2
|
||||
tensorflow_hub
|
||||
terminaltables
|
||||
tqdm
|
||||
visdom
|
||||
|
||||
@@ -71,7 +71,8 @@ def main():
|
||||
if not args.disable_stdout:
|
||||
print('\n')
|
||||
else:
|
||||
pbar.update(1)
|
||||
if not isinstance(result, textattack.attack_results.SkippedAttackResult):
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
print()
|
||||
# Enable summary stdout
|
||||
|
||||
@@ -111,6 +111,10 @@ class AttackLogger:
|
||||
attack_success_rate = str(round(attack_success_rate, 2)) + '%'
|
||||
average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + '%'
|
||||
|
||||
all_num_words = np.array([len(result.original_text.words) for result in self.results])
|
||||
average_num_words = all_num_words.mean()
|
||||
average_num_words = str(round(average_num_words, 2))
|
||||
|
||||
summary_table_rows = [
|
||||
['Number of successful attacks:', str(self.successful_attacks)],
|
||||
['Number of failed attacks:', str(self.failed_attacks)],
|
||||
@@ -119,11 +123,12 @@ class AttackLogger:
|
||||
['Accuracy under attack:', accuracy_under_attack],
|
||||
['Attack success rate:', attack_success_rate],
|
||||
['Average perturbed word %:', average_perc_words_perturbed],
|
||||
['Average num. words', average_num_words],
|
||||
]
|
||||
|
||||
num_queries = [r.num_queries for r in self.results]
|
||||
avg_num_queries = statistics.mean(num_queries) if len(num_queries) else 0
|
||||
avg_num_queries = str(round(avg_num_queries, 2))
|
||||
summary_table_rows.append(['Avg num queries:', avg_num_queries])
|
||||
self._log_rows(summary_table_rows, 'Attack Results Summary', 'attack_results_summary')
|
||||
self._log_rows(summary_table_rows, 'Attack Results', 'attack_results_summary')
|
||||
self._log_num_words_changed()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
import terminaltables
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
@@ -23,8 +24,9 @@ class FileLogger(Logger):
|
||||
self.fout.write('\n')
|
||||
|
||||
def log_rows(self, rows, title, window_id):
|
||||
for row in rows:
|
||||
self.fout.write(f'{row[0]} {row[1]}\n')
|
||||
table_rows = [[title, '']] + rows
|
||||
table = terminaltables.SingleTable(table_rows)
|
||||
self.fout.write(table.table)
|
||||
|
||||
def log_sep(self):
|
||||
self.fout.write('-' * 90 + '\n')
|
||||
|
||||
Reference in New Issue
Block a user