1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/loggers/weights_and_biases_logger.py
2020-10-30 17:31:00 -04:00

75 lines
2.3 KiB
Python

"""
Attack Logs to WandB
========================
"""
from textattack.shared.utils import LazyLoader, html_table_from_rows
from .logger import Logger
class WeightsAndBiasesLogger(Logger):
    """Logs attack results to Weights & Biases.

    Summary metrics are written to the run summary and to a two-column
    table; each attack result is logged both on its own and as part of a
    cumulative HTML table that is rebuilt and re-logged on every result.
    """

    def __init__(self, filename="", stdout=False):
        """Load ``wandb`` and initialise a run.

        Args:
            filename: Not used by this logger.
            stdout: Not used by this logger.
        """
        # ``wandb`` is bound as a module-level global so the other methods
        # can reference it by name.
        # NOTE(review): LazyLoader presumably defers the real import until
        # first use -- confirm against textattack.shared.utils.
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")
        wandb.init(project="textattack", resume=True)
        # One ``[label, original_html, perturbed_html]`` entry per result.
        self._result_table_rows = []

    def __setstate__(self, state):
        """Restore instance state after unpickling and re-run ``wandb.init``.

        Args:
            state: The instance ``__dict__`` captured when pickling.
        """
        # The module-level ``wandb`` global may not exist in the process
        # doing the unpickling, so it is re-created here.
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")
        self.__dict__ = state
        wandb.init(project="textattack", resume=True)

    def log_summary_rows(self, rows, title, window_id):
        """Log summary metrics as a table and as run-summary entries.

        Args:
            rows: Iterable of two-item ``(metric_name, metric_score)`` rows.
                Each row is unpacked into exactly two values, so any other
                length raises ``ValueError``.
            title: Not used by this logger.
            window_id: Not used by this logger.
        """
        table = wandb.Table(columns=["Attack Results", ""])
        for row in rows:
            table.add_data(*row)
            metric_name, metric_score = row
            wandb.run.summary[metric_name] = metric_score
        wandb.log({"attack_params": table})

    def _log_result_table(self):
        """Log every result collected so far as a single HTML table.

        Weights & Biases doesn't have a feature to automatically aggregate
        results across timesteps and display the full table, so the whole
        table is rebuilt from ``self._result_table_rows`` and logged under
        the ``"results"`` key on each call.
        """
        result_table = html_table_from_rows(
            self._result_table_rows, header=["", "Original Input", "Perturbed Input"]
        )
        wandb.log({"results": wandb.Html(result_table)})

    def log_attack_result(self, result):
        """Log one attack result and refresh the cumulative results table.

        Args:
            result: Object providing ``diff_color(color_method=...)``, which
                is unpacked into an (original, perturbed) pair, and
                ``original_result.output`` / ``perturbed_result.output``.
        """
        original_text_colored, perturbed_text_colored = result.diff_color(
            color_method="html"
        )
        # Numbering is zero-based: the index is taken before the append.
        result_num = len(self._result_table_rows)
        self._result_table_rows.append(
            [
                f"<b>Result {result_num}</b>",
                original_text_colored,
                perturbed_text_colored,
            ]
        )
        result_diff_table = html_table_from_rows(
            [[original_text_colored, perturbed_text_colored]]
        )
        result_diff_table = wandb.Html(result_diff_table)
        wandb.log(
            {
                "result": result_diff_table,
                "original_output": result.original_result.output,
                "perturbed_output": result.perturbed_result.output,
            }
        )
        self._log_result_table()

    def log_sep(self):
        """Write a 90-dash separator line to ``self.fout`` if one is set.

        Nothing in this class assigns ``fout``, so an unconditional write
        would raise ``AttributeError``. The separator is therefore skipped
        when no such stream is attached.
        """
        fout = getattr(self, "fout", None)
        if fout is not None:
            fout.write("-" * 90 + "\n")