mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add outputting to csv
This commit is contained in:
@@ -49,6 +49,16 @@ class Attack:
|
|||||||
"""
|
"""
|
||||||
self.logger.add_output_file(filename)
|
self.logger.add_output_file(filename)
|
||||||
|
|
||||||
|
def add_output_csv(self, filename, plain):
|
||||||
|
"""
|
||||||
|
When attack runs, it will output to this csv.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str): The path to the output file
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.logger.add_output_csv(filename, plain)
|
||||||
|
|
||||||
def add_constraint(self, constraint):
|
def add_constraint(self, constraint):
|
||||||
"""
|
"""
|
||||||
Adds a constraint to the attack.
|
Adds a constraint to the attack.
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ class AttackLogger:
|
|||||||
def add_output_file(self, filename):
|
def add_output_file(self, filename):
|
||||||
self.loggers.append(FileLogger(filename=filename))
|
self.loggers.append(FileLogger(filename=filename))
|
||||||
|
|
||||||
|
def add_output_csv(self, filename, plain):
|
||||||
|
self.loggers.append(CSVLogger(filename=filename, plain=plain))
|
||||||
|
|
||||||
def log_skipped(self, tokenized_text):
|
def log_skipped(self, tokenized_text):
|
||||||
self.skipped_attacks += 1
|
self.skipped_attacks += 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
from .logger import Logger
|
from .logger import Logger
|
||||||
from .visdom_logger import VisdomLogger
|
from .visdom_logger import VisdomLogger
|
||||||
from .file_logger import FileLogger
|
from .file_logger import FileLogger
|
||||||
|
from .csv_logger import CSVLogger
|
||||||
|
|||||||
20
textattack/loggers/csv_logger.py
Normal file
20
textattack/loggers/csv_logger.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import csv
|
||||||
|
from textattack.loggers import Logger
|
||||||
|
|
||||||
|
class CSVLogger(Logger):
|
||||||
|
def __init__(self, filename='results.csv', plain=False):
|
||||||
|
self.filename = filename
|
||||||
|
self.plain = plain
|
||||||
|
self.df = pd.DataFrame()
|
||||||
|
|
||||||
|
def log_attack_result(self, result, examples_completed):
|
||||||
|
color_method = None if self.plain else 'file'
|
||||||
|
s1, s2 = result.diff_color(color_method)
|
||||||
|
row = {'passage_1': s1, 'passage_2': s2}
|
||||||
|
self.df = self.df.append(row, ignore_index=True)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.df.to_csv(self.filename, quoting=csv.QUOTE_NONNUMERIC)
|
||||||
@@ -117,7 +117,10 @@ def get_args():
|
|||||||
|
|
||||||
parser.add_argument('--disable_stdout', action='store_true',
|
parser.add_argument('--disable_stdout', action='store_true',
|
||||||
help='Disable logging to stdout')
|
help='Disable logging to stdout')
|
||||||
|
|
||||||
|
parser.add_argument('--enable_csv', nargs='?', default=None, const='fancy', type=str,
|
||||||
|
help='Enable logging to csv. Use --enable_csv plain to remove [[]] around words.')
|
||||||
|
|
||||||
parser.add_argument('--num_examples', '-n', type=int, required=False,
|
parser.add_argument('--num_examples', '-n', type=int, required=False,
|
||||||
default='5', help='The number of examples to process.')
|
default='5', help='The number of examples to process.')
|
||||||
|
|
||||||
@@ -234,11 +237,18 @@ if __name__ == '__main__':
|
|||||||
attack = parse_attack_from_args()
|
attack = parse_attack_from_args()
|
||||||
attack.add_constraints(parse_constraints_from_args())
|
attack.add_constraints(parse_constraints_from_args())
|
||||||
|
|
||||||
# Output file
|
out_time = int(time.time()) # Output file
|
||||||
if args.out_dir is not None:
|
if args.out_dir is not None:
|
||||||
outfile_name = 'attack-{}.txt'.format(int(time.time()))
|
outfile_name = 'attack-{}.txt'.format(out_time)
|
||||||
attack.add_output_file(os.path.join(args.out_dir, outfile_name))
|
attack.add_output_file(os.path.join(args.out_dir, outfile_name))
|
||||||
|
|
||||||
|
# csv
|
||||||
|
if args.enable_csv:
|
||||||
|
out_dir = args.out_dir if args.out_dir else 'outputs'
|
||||||
|
outfile_name = 'attack-{}.csv'.format(out_time)
|
||||||
|
plain = args.enable_csv == 'plain'
|
||||||
|
attack.add_output_csv(os.path.join(out_dir, outfile_name), plain)
|
||||||
|
|
||||||
# Visdom
|
# Visdom
|
||||||
if args.enable_visdom:
|
if args.enable_visdom:
|
||||||
attack.enable_visdom()
|
attack.enable_visdom()
|
||||||
|
|||||||
Reference in New Issue
Block a user