# Mirror of https://github.com/QData/TextAttack.git
# (synced 2021-10-13 00:05:06 +03:00) — 51 lines, 1.9 KiB, Python.
from textattack.datasets import TextAttackDataset
|
|
from textattack.shared import TokenizedText
|
|
|
|
import nlp
|
|
import random
|
|
|
|
def get_nlp_dataset_columns(dataset):
    """Infer the input and output column names of a HuggingFace ``nlp`` dataset.

    The column names in ``dataset.schema.names`` are matched against a list
    of known column layouts. Layouts with more input columns are checked
    before those they contain, so that e.g. a ``(question, sentence, label)``
    dataset is not mistaken for a single-``sentence`` dataset.

    Args:
        dataset: a dataset split exposing ``schema.names`` (its column names).

    Returns:
        A tuple ``(input_columns, output_column)``. ``input_columns`` is an
        ordered tuple of the column names whose values form the model input.
        ``output_column`` is the name of the label column.

    Raises:
        ValueError: if the schema matches none of the known layouts.
    """
    schema = set(dataset.schema.names)

    # Ordered from most to least specific: a layout must be tested before any
    # layout whose columns are a subset of it, or it becomes unreachable.
    # Input columns are tuples (not sets) so the order in which multi-part
    # inputs are later joined is deterministic.
    known_layouts = (
        (('premise', 'hypothesis'), 'label'),
        (('question1', 'question2'), 'label'),
        (('question', 'sentence'), 'label'),
        (('sentence1', 'sentence2'), 'label'),
        (('sentence',), 'label'),
        (('text',), 'label'),
    )
    for input_columns, output_column in known_layouts:
        if (set(input_columns) | {output_column}) <= schema:
            return input_columns, output_column

    raise ValueError(f'Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead.')
|
|
|
|
class HuggingFaceNLPDataset(TextAttackDataset):
    """ Loads a dataset from HuggingFace ``nlp`` and prepares it as a
    TextAttack dataset.

    Args:
        dataset_args: arguments forwarded positionally to
            ``nlp.load_dataset``. Either a single dataset name (``str``) or a
            sequence of positional arguments, e.g. ``('glue', 'sst2')``.
        split: key used to select one split from the loaded dataset
            (default ``'train'``).
        shuffle: if ``True``, shuffle the materialized examples in place with
            the global ``random`` module state.
    """

    def __init__(self, dataset_args, split='train', shuffle=False):
        # A bare string would otherwise be star-unpacked into one positional
        # argument per character.
        if isinstance(dataset_args, str):
            dataset_args = (dataset_args,)
        dataset = nlp.load_dataset(*dataset_args)
        split_data = dataset[split]
        # NOTE: despite the plural name, ``output_columns`` holds the single
        # label column name returned by ``get_nlp_dataset_columns``.
        self.input_columns, self.output_columns = get_nlp_dataset_columns(split_data)
        # Position of the next example handed out by ``__next__``.
        # NOTE(review): nothing in this class resets ``_i``; whether a second
        # pass works depends on the base class — confirm.
        self._i = 0
        self.examples = list(split_data)
        if shuffle:
            random.shuffle(self.examples)

    def __next__(self):
        """Return the next ``(joined_input, label)`` pair.

        The values of ``self.input_columns`` are joined with
        ``TokenizedText.SPLIT_TOKEN`` in column order. The label is read from
        ``self.output_columns``.

        Raises:
            StopIteration: once every example has been returned.
        """
        if self._i >= len(self.examples):
            raise StopIteration
        raw_example = self.examples[self._i]
        self._i += 1
        joined_input = TokenizedText.SPLIT_TOKEN.join(raw_example[c] for c in self.input_columns)
        output = raw_example[self.output_columns]
        return (joined_input, output)