1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix bug with HFDataset

This commit is contained in:
Jin Yong Yoo
2021-01-02 04:01:32 -05:00
parent fccf60dfd4
commit a32263a20e
4 changed files with 18 additions and 13 deletions

View File

@@ -17,4 +17,4 @@ tqdm>=4.27,<4.50.0
word2number
num2words
more-itertools
PySocks!=1.5.7,>=1.5.6
PySocks!=1.5.7,>=1.5.6

View File

@@ -110,16 +110,21 @@ class HuggingFaceDataset(Dataset):
) = dataset_columns or get_datasets_dataset_columns(self._dataset)
self.label_map = label_map
self.output_scale_factor = output_scale_factor
try:
self.label_names = self._dataset.features["label"].names
# If labels are remapped, the label names have to be remapped as well.
if label_map:
self.label_names = [
self.label_names[self.label_map[i]] for i in self.label_map
]
except KeyError:
# This happens when the dataset doesn't have 'features' or a 'label' column.
self.label_names = None
if label_names:
self.label_names = label_names
else:
try:
self.label_names = self._dataset.features[self.output_column].names
except (KeyError, AttributeError):
# This happens when the dataset doesn't have 'features' or a 'label' column.
self.label_names = None
# If labels are remapped, the label names have to be remapped as well.
if self.label_names and label_map:
self.label_names = [
self.label_names[self.label_map[i]] for i in self.label_map
]
self.shuffled = shuffle
if shuffle:
self._dataset.shuffle()

View File

@@ -16,7 +16,7 @@ class CSVLogger(Logger):
"""Logs attack results to a CSV."""
def __init__(self, filename="results.csv", color_method="file"):
logger.info(f"Logging to CSV at path {filename}.")
logger.info(f"Logging to CSV at path {filename}")
self.filename = filename
self.color_method = color_method
self.df = pd.DataFrame()

View File

@@ -26,7 +26,7 @@ class FileLogger(Logger):
if not os.path.exists(directory):
os.makedirs(directory)
self.fout = open(filename, "w")
logger.info(f"Logging to text file at path {filename}.")
logger.info(f"Logging to text file at path {filename}")
else:
self.fout = filename
self.num_results = 0