mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
"""
|
|
Attack Logs to Visdom
|
|
========================
|
|
"""
|
|
|
|
|
|
import socket
|
|
|
|
from textattack.shared.utils import LazyLoader, html_table_from_rows
|
|
|
|
from .logger import Logger
|
|
|
|
# Bind the name ``visdom`` through LazyLoader instead of a plain ``import``.
# NOTE(review): presumably this defers loading the optional ``visdom`` package
# until its first attribute access — confirm against
# textattack.shared.utils.LazyLoader.
visdom = LazyLoader("visdom", globals(), "visdom")
|
|
|
|
|
|
def port_is_open(port_num, hostname="127.0.0.1"):
    """Return whether a TCP connection to ``hostname:port_num`` succeeds.

    Args:
        port_num: TCP port number to probe.
        hostname: Host to connect to; defaults to the IPv4 loopback address.

    Returns:
        ``True`` if ``connect_ex`` reports success (error code 0), otherwise
        ``False``.
    """
    # The context manager closes the socket even when ``connect_ex`` raises
    # (e.g. the hostname cannot be resolved), so the descriptor never leaks.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        return sock.connect_ex((hostname, port_num)) == 0
|
|
|
|
|
|
class VisdomLogger(Logger):
    """Logs attack results to Visdom.

    Args:
        env: Name of the Visdom environment handed to the client.
        port: Port of the Visdom server.
        hostname: Host of the Visdom server.

    Raises:
        socket.error: If ``port_is_open`` reports that nothing accepts
            connections on ``hostname:port``.
    """

    def __init__(self, env="main", port=8097, hostname="localhost"):
        # Fail with an explicit message when no server is listening.
        if not port_is_open(port, hostname=hostname):
            raise socket.error(f"Visdom not running on {hostname}:{port}")
        self.vis = visdom.Visdom(port=port, server=hostname, env=env)
        self.env = env
        self.port = port
        self.hostname = hostname
        # Maps a caller-chosen window id to the window handle Visdom returned.
        self.windows = {}
        # One ``[result_str, text_a, text_b]`` row per logged attack result.
        self.sample_rows = []

    def __getstate__(self):
        """Return the instance state for pickling, leaving out the ``vis`` client."""
        return {key: value for key, value in self.__dict__.items() if key != "vis"}

    def __setstate__(self, state):
        """Restore pickled state and open a new client from the saved settings."""
        self.__dict__ = state
        self.vis = visdom.Visdom(port=self.port, server=self.hostname, env=self.env)

    def log_attack_result(self, result):
        """Queue one attack result as a row of the sample-level table.

        The row holds the HTML-colored goal-function string followed by the
        two HTML-colored texts returned by ``result.diff_color``. Nothing is
        sent to Visdom until ``flush`` is called.
        """
        text_a, text_b = result.diff_color(color_method="html")
        result_str = result.goal_function_result_str(color_method="html")
        self.sample_rows.append([result_str, text_a, text_b])

    def log_summary_rows(self, rows, title, window_id):
        """Show ``rows`` as a table titled ``title`` in window ``window_id``."""
        self.table(rows, title=title, window_id=window_id)

    def flush(self):
        """Render every queued sample row into the sample-level results table.

        The queued rows are kept, so each call redraws the full table.
        """
        self.table(
            self.sample_rows,
            title="Sample-Level Results",
            window_id="sample_level_results",
        )

    def log_hist(self, arr, numbins, title, window_id):
        """Plot ``arr`` through ``bar``.

        NOTE(review): this draws a bar chart rather than calling ``hist``;
        presumably ``arr`` already holds per-bin counts — confirm with callers.
        """
        self.bar(arr, numbins=numbins, title=title, window_id=window_id)

    def text(self, text_data, title=None, window_id="default"):
        """Show ``text_data`` in a text window.

        Writes into the window registered under ``window_id`` when there is
        one; otherwise opens a new window titled ``title`` and registers it.
        """
        if window_id and window_id in self.windows:
            self.vis.text(text_data, win=self.windows[window_id])
            return
        new_window = self.vis.text(text_data, opts=dict(title=title))
        # Register only under a usable id, as ``bar`` and ``hist`` do; a falsy
        # id is never looked up, so storing it would only leave a stale entry.
        if window_id:
            self.windows[window_id] = new_window

    def table(self, rows, window_id=None, title=None, header=None, style=None):
        """Generates an HTML table.

        The table is built by ``html_table_from_rows`` (``style`` is passed as
        its ``style_dict``) and displayed through ``text``. When only one of
        ``window_id`` and ``title`` is given, it is used for both.
        """
        if not window_id:
            window_id = title  # Can provide either of these,
        if not title:
            title = window_id  # or both.
        table = html_table_from_rows(rows, title=title, header=header, style_dict=style)
        self.text(table, title=title, window_id=window_id)

    def bar(self, X_data, numbins=10, title=None, window_id=None):
        """Draw ``X_data`` as a bar chart, reusing window ``window_id`` if known."""
        self._plot(self.vis.bar, X_data, numbins, title, window_id)

    def hist(self, X_data, numbins=10, title=None, window_id=None):
        """Draw ``X_data`` as a histogram, reusing window ``window_id`` if known."""
        self._plot(self.vis.histogram, X_data, numbins, title, window_id)

    def _plot(self, draw, X_data, numbins, title, window_id):
        """Call ``draw`` into the registered window, or open and register a new one."""
        opts = dict(title=title, numbins=numbins)
        if window_id and window_id in self.windows:
            draw(X=X_data, win=self.windows[window_id], opts=opts)
            return
        new_window = draw(X=X_data, opts=opts)
        if window_id:
            self.windows[window_id] = new_window
|