mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
136 lines
4.9 KiB
Python
136 lines
4.9 KiB
Python
import collections
|
|
import random
|
|
|
|
import nlp
|
|
|
|
import textattack
|
|
from textattack.datasets import TextAttackDataset
|
|
from textattack.shared import AttackedText
|
|
|
|
|
|
def _cb(s):
    """Return *s* rendered as blue ANSI-colored text for log messages."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
|
|
|
|
|
|
def get_nlp_dataset_columns(dataset):
    """Infer the input/output columns of an ``nlp`` dataset from its schema.

    Matches ``dataset.schema.names`` against known column layouts
    (sentence-pair tasks, single-text classification, summarization).
    Layouts are checked in a fixed order and the first match wins, so a
    schema containing several layouts resolves deterministically.

    Args:
        dataset: an ``nlp`` dataset object exposing ``schema.names``.

    Returns:
        Tuple ``(input_columns, output_column)`` where ``input_columns``
        is a tuple of column names and ``output_column`` is a single name,
        e.g. ``(('premise', 'hypothesis'), 'label')``.

    Raises:
        ValueError: if the schema matches none of the known layouts.
    """
    schema = set(dataset.schema.names)
    # Known layouts in priority order. (The original elif-chain listed
    # ("question", "sentence") -> "label" twice; the duplicate was dead
    # code and has been removed.)
    known_layouts = (
        (("premise", "hypothesis"), "label"),
        (("question", "sentence"), "label"),
        (("sentence1", "sentence2"), "label"),
        (("question1", "question2"), "label"),
        (("text",), "label"),
        (("sentence",), "label"),
        (("document",), "summary"),
        (("content",), "summary"),
    )
    for input_columns, output_column in known_layouts:
        if set(input_columns) | {output_column} <= schema:
            return input_columns, output_column
    raise ValueError(
        f"Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead."
    )
|
|
|
|
|
|
class HuggingFaceNLPDataset(TextAttackDataset):
    """ Loads a dataset from HuggingFace ``nlp`` and prepares it as a
    TextAttack dataset.

    - name: the dataset name
    - subset: the subset of the main dataset. Dataset will be loaded as ``nlp.load_dataset(name, subset)``.
    - label_map: Mapping if output labels should be re-mapped. Useful
        if model was trained with a different label arrangement than
        provided in the ``nlp`` version of the dataset.
    - output_scale_factor (float): Factor to divide ground-truth outputs by.
        Generally, TextAttack goal functions require model outputs
        between 0 and 1. Some datasets test the model's \*correlation\*
        with ground-truth output, instead of its accuracy, so these
        outputs may be scaled arbitrarily.
    - shuffle (bool): Whether to shuffle the dataset on load.
    """

    def __init__(
        self,
        name,
        subset=None,
        split="train",
        label_map=None,
        output_scale_factor=None,
        dataset_columns=None,
        shuffle=False,
    ):
        self._name = name
        # NOTE(review): presumably downloads the dataset on first use
        # (network side effect) — behavior of nlp.load_dataset, not visible here.
        self._dataset = nlp.load_dataset(name, subset)[split]
        subset_print_str = f", subset {_cb(subset)}" if subset else ""
        textattack.shared.logger.info(
            f"Loading {_cb('nlp')} dataset {_cb(name)}{subset_print_str}, split {_cb(split)}."
        )
        # Input/output column order, like (('premise', 'hypothesis'), 'label').
        # An explicit ``dataset_columns`` argument overrides schema inference.
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_nlp_dataset_columns(self._dataset)
        # Iteration cursor used by __next__.
        self._i = 0
        # Materialize all rows so they can be shuffled and indexed.
        self.examples = list(self._dataset)
        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        try:
            self.label_names = self._dataset.features["label"].names
            # If labels are remapped, the label names have to be remapped as
            # well.
            if label_map:
                self.label_names = [
                    self.label_names[self.label_map[i]]
                    for i in range(len(self.label_map))
                ]
        except KeyError:
            # This happens when the dataset doesn't have 'features' or a 'label' column.
            self.label_names = None
        except AttributeError:
            # This happens when self._dataset.features["label"] exists
            # but is a single value (no ``.names`` attribute).
            self.label_names = ("label",)
        if shuffle:
            random.shuffle(self.examples)

    def _format_raw_example(self, raw_example):
        # Convert one raw dataset row to TextAttack's (inputs, output) form:
        # an OrderedDict of input columns (order matters for multi-input
        # tasks) and a single output value.
        input_dict = collections.OrderedDict(
            [(c, raw_example[c]) for c in self.input_columns]
        )

        output = raw_example[self.output_column]
        # Remap the label first, then rescale the (possibly remapped) value.
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        return (input_dict, output)

    def __next__(self):
        # Sequential iteration over self.examples via the self._i cursor.
        if self._i >= len(self.examples):
            raise StopIteration
        raw_example = self.examples[self._i]
        self._i += 1
        return self._format_raw_example(raw_example)

    def __getitem__(self, i):
        # Random access: format and return the example at index ``i``.
        return self._format_raw_example(self.examples[i])
|