1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/commands/peek_dataset.py
2020-06-23 21:45:06 -04:00

74 lines
2.6 KiB
Python

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
import numpy as np
import re
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args_helpers import add_dataset_args, parse_dataset_from_args
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
logger = textattack.shared.logger
import re
class PeekDatasetCommand(TextAttackCommand):
"""
The peek dataset module:
Takes a peek into a dataset in textattack.
"""
def run(self, args):
UPPERCASE_LETTERS_REGEX = re.compile('[A-Z]')
args.model = None # set model to None for parse_dataset_from_args to work
dataset = parse_dataset_from_args(args)
num_words = []
attacked_texts = []
data_all_lowercased = True
outputs = []
for inputs, output in dataset:
at = textattack.shared.AttackedText(inputs)
if data_all_lowercased:
# Test if any of the letters in the string are lowercase.
if re.search(UPPERCASE_LETTERS_REGEX, at.text):
data_all_lowercased = False
attacked_texts.append(at)
num_words.append(len(at.words))
outputs.append(output)
logger.info(f'Number of samples: {_cb(len(attacked_texts))}')
logger.info(f'Number of words per input:')
num_words = np.array(num_words)
logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
mean_words = f'{num_words.mean():.2f}'
logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
std_words = f'{num_words.std():.2f}'
logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
logger.info(f'Dataset lowercased: {_cb(data_all_lowercased)}')
logger.info('First sample:')
print(attacked_texts[0].printable_text(), '\n')
logger.info('Last sample:')
print(attacked_texts[-1].printable_text(), '\n')
outputs = set(outputs)
logger.info(f'Found {len(outputs)} distinct outputs:')
print(sorted(outputs))
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"peek-dataset", help="show main statistics about a dataset",
formatter_class=ArgumentDefaultsHelpFormatter,
)
add_dataset_args(parser)
parser.set_defaults(func=PeekDatasetCommand())