mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix bug with HFDataset
This commit is contained in:
@@ -17,4 +17,4 @@ tqdm>=4.27,<4.50.0
|
||||
word2number
|
||||
num2words
|
||||
more-itertools
|
||||
PySocks!=1.5.7,>=1.5.6
|
||||
PySocks!=1.5.7,>=1.5.6
|
||||
@@ -110,16 +110,21 @@ class HuggingFaceDataset(Dataset):
|
||||
) = dataset_columns or get_datasets_dataset_columns(self._dataset)
|
||||
self.label_map = label_map
|
||||
self.output_scale_factor = output_scale_factor
|
||||
try:
|
||||
self.label_names = self._dataset.features["label"].names
|
||||
# If labels are remapped, the label names have to be remapped as well.
|
||||
if label_map:
|
||||
self.label_names = [
|
||||
self.label_names[self.label_map[i]] for i in self.label_map
|
||||
]
|
||||
except KeyError:
|
||||
# This happens when the dataset doesn't have 'features' or a 'label' column.
|
||||
self.label_names = None
|
||||
if label_names:
|
||||
self.label_names = label_names
|
||||
else:
|
||||
try:
|
||||
self.label_names = self._dataset.features[self.output_column].names
|
||||
except (KeyError, AttributeError):
|
||||
# This happens when the dataset doesn't have 'features' or a 'label' column.
|
||||
self.label_names = None
|
||||
|
||||
# If labels are remapped, the label names have to be remapped as well.
|
||||
if self.label_names and label_map:
|
||||
self.label_names = [
|
||||
self.label_names[self.label_map[i]] for i in self.label_map
|
||||
]
|
||||
|
||||
self.shuffled = shuffle
|
||||
if shuffle:
|
||||
self._dataset.shuffle()
|
||||
|
||||
@@ -16,7 +16,7 @@ class CSVLogger(Logger):
|
||||
"""Logs attack results to a CSV."""
|
||||
|
||||
def __init__(self, filename="results.csv", color_method="file"):
|
||||
logger.info(f"Logging to CSV at path {filename}.")
|
||||
logger.info(f"Logging to CSV at path {filename}")
|
||||
self.filename = filename
|
||||
self.color_method = color_method
|
||||
self.df = pd.DataFrame()
|
||||
|
||||
@@ -26,7 +26,7 @@ class FileLogger(Logger):
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
self.fout = open(filename, "w")
|
||||
logger.info(f"Logging to text file at path {filename}.")
|
||||
logger.info(f"Logging to text file at path {filename}")
|
||||
else:
|
||||
self.fout = filename
|
||||
self.num_results = 0
|
||||
|
||||
Reference in New Issue
Block a user