1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add outputting to csv

This commit is contained in:
uvafan
2019-12-03 19:10:35 -05:00
parent 23a695b7fa
commit f4e833f21a
5 changed files with 47 additions and 3 deletions

View File

@@ -49,6 +49,16 @@ class Attack:
"""
self.logger.add_output_file(filename)
def add_output_csv(self, filename, plain):
"""
When attack runs, it will output to this csv.
Args:
file (str): The path to the output file
"""
self.logger.add_output_csv(filename, plain)
def add_constraint(self, constraint):
"""
Adds a constraint to the attack.

View File

@@ -33,6 +33,9 @@ class AttackLogger:
def add_output_file(self, filename):
self.loggers.append(FileLogger(filename=filename))
def add_output_csv(self, filename, plain):
self.loggers.append(CSVLogger(filename=filename, plain=plain))
def log_skipped(self, tokenized_text):
self.skipped_attacks += 1

View File

@@ -1,3 +1,4 @@
from .logger import Logger
from .visdom_logger import VisdomLogger
from .file_logger import FileLogger
from .csv_logger import CSVLogger

View File

@@ -0,0 +1,20 @@
import sys
import os
import pandas as pd
import csv
from textattack.loggers import Logger
class CSVLogger(Logger):
def __init__(self, filename='results.csv', plain=False):
self.filename = filename
self.plain = plain
self.df = pd.DataFrame()
def log_attack_result(self, result, examples_completed):
color_method = None if self.plain else 'file'
s1, s2 = result.diff_color(color_method)
row = {'passage_1': s1, 'passage_2': s2}
self.df = self.df.append(row, ignore_index=True)
def flush(self):
self.df.to_csv(self.filename, quoting=csv.QUOTE_NONNUMERIC)

View File

@@ -117,7 +117,10 @@ def get_args():
parser.add_argument('--disable_stdout', action='store_true',
help='Disable logging to stdout')
parser.add_argument('--enable_csv', nargs='?', default=None, const='fancy', type=str,
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
parser.add_argument('--num_examples', '-n', type=int, required=False,
default='5', help='The number of examples to process.')
@@ -234,11 +237,18 @@ if __name__ == '__main__':
attack = parse_attack_from_args()
attack.add_constraints(parse_constraints_from_args())
# Output file
out_time = int(time.time()) # Output file
if args.out_dir is not None:
outfile_name = 'attack-{}.txt'.format(int(time.time()))
outfile_name = 'attack-{}.txt'.format(out_time)
attack.add_output_file(os.path.join(args.out_dir, outfile_name))
# csv
if args.enable_csv:
out_dir = args.out_dir if args.out_dir else 'outputs'
outfile_name = 'attack-{}.csv'.format(out_time)
plain = args.enable_csv == 'plain'
attack.add_output_csv(os.path.join(out_dir, outfile_name), plain)
# Visdom
if args.enable_visdom:
attack.enable_visdom()