mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add outputting to csv
This commit is contained in:
20
textattack/loggers/csv_logger.py
Normal file
20
textattack/loggers/csv_logger.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import sys
|
||||
import os
|
||||
import pandas as pd
|
||||
import csv
|
||||
from textattack.loggers import Logger
|
||||
|
||||
class CSVLogger(Logger):
|
||||
def __init__(self, filename='results.csv', plain=False):
|
||||
self.filename = filename
|
||||
self.plain = plain
|
||||
self.df = pd.DataFrame()
|
||||
|
||||
def log_attack_result(self, result, examples_completed):
|
||||
color_method = None if self.plain else 'file'
|
||||
s1, s2 = result.diff_color(color_method)
|
||||
row = {'passage_1': s1, 'passage_2': s2}
|
||||
self.df = self.df.append(row, ignore_index=True)
|
||||
|
||||
def flush(self):
|
||||
self.df.to_csv(self.filename, quoting=csv.QUOTE_NONNUMERIC)
|
||||
Reference in New Issue
Block a user