mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
import collections
|
|
|
|
from textattack.datasets import TextAttackDataset
|
|
from textattack.shared import AttackedText
|
|
|
|
|
|
class EntailmentDataset(TextAttackDataset):
    """A generic class for loading entailment data.

    Labels
        0: Entailment
        1: Neutral
        2: Contradiction
    """

    # Textual label descriptions and the integer ids they map to.
    _LABEL_STR_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}

    def _label_str_to_int(self, label_str):
        """Convert a textual entailment label into its integer id.

        Args:
            label_str: One of ``"entailment"``, ``"neutral"`` or
                ``"contradiction"`` (exact, case-sensitive match).

        Returns:
            ``0``, ``1`` or ``2`` respectively.

        Raises:
            ValueError: If ``label_str`` is not one of the known labels.
        """
        try:
            return self._LABEL_STR_TO_INT[label_str]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs so that every unknown label
            # surfaces as the same ValueError.
            raise ValueError(f"Unknown entailment label {label_str}") from None

    def _process_example_from_file(self, raw_line):
        """Parse one ``label<TAB>premise<TAB>hypothesis`` line.

        Surrounding whitespace is stripped from the line before it is split
        on tab characters. The label may be written either as an integer or
        as one of the textual descriptions understood by
        ``_label_str_to_int``.

        Args:
            raw_line: A single line of text holding exactly three
                tab-separated fields.

        Returns:
            A ``(text_input, label)`` tuple where ``text_input`` is an
            ``OrderedDict`` with the keys ``"premise"`` and ``"hypothesis"``
            (in that order) and ``label`` is an ``int``.

        Raises:
            ValueError: If the line does not contain exactly three
                tab-separated fields, or if the label is neither an integer
                nor a known label description.
        """
        fields = raw_line.strip().split("\t")
        if len(fields) != 3:
            # Fail with a message that identifies the bad line instead of the
            # generic tuple-unpacking error.
            raise ValueError(
                "Expected 3 tab-separated fields (label, premise, hypothesis) "
                f"but found {len(fields)} in line: {raw_line!r}"
            )
        label, premise, hypothesis = fields
        try:
            label = int(label)
        except ValueError:
            # If the label is not an integer, it's a label description.
            label = self._label_str_to_int(label)
        text_input = collections.OrderedDict(
            [("premise", premise), ("hypothesis", hypothesis)]
        )
        return (text_input, label)
|