mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
raise warning (not error) when too few samples to attack
This commit is contained in:
@@ -428,7 +428,7 @@ def parse_logger_from_args(args):
|
|||||||
color_method = None if args.enable_csv == "plain" else "file"
|
color_method = None if args.enable_csv == "plain" else "file"
|
||||||
csv_path = os.path.join(args.out_dir, outfile_name)
|
csv_path = os.path.join(args.out_dir, outfile_name)
|
||||||
attack_log_manager.add_output_csv(csv_path, color_method)
|
attack_log_manager.add_output_csv(csv_path, color_method)
|
||||||
print("Logging to CSV at path {}.".format(csv_path))
|
textattack.shared.logger.info(f"Logging to CSV at path {csv_path}.")
|
||||||
|
|
||||||
# Visdom
|
# Visdom
|
||||||
if args.enable_visdom:
|
if args.enable_visdom:
|
||||||
|
|||||||
@@ -205,15 +205,16 @@ class Attack:
|
|||||||
Gets examples from a dataset and tokenizes them.
|
Gets examples from a dataset and tokenizes them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: An iterable of (text, ground_truth_output) pairs
|
dataset: An iterable of (text_input, ground_truth_output) pairs
|
||||||
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
|
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
|
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
|
||||||
"""
|
"""
|
||||||
indices = indices if indices else deque(range(len(dataset)))
|
indices = indices or range(len(dataset))
|
||||||
if not isinstance(indices, deque):
|
if not isinstance(indices, deque):
|
||||||
indices = deque(indices)
|
indices = deque(sorted(indices))
|
||||||
|
|
||||||
if not indices:
|
if not indices:
|
||||||
return
|
return
|
||||||
yield
|
yield
|
||||||
@@ -221,14 +222,14 @@ class Attack:
|
|||||||
while indices:
|
while indices:
|
||||||
i = indices.popleft()
|
i = indices.popleft()
|
||||||
try:
|
try:
|
||||||
text, ground_truth_output = dataset[i]
|
text_input, ground_truth_output = dataset[i]
|
||||||
try:
|
try:
|
||||||
# get label names from dataset, if possible
|
# get label names from dataset, if possible
|
||||||
label_names = dataset.label_names
|
label_names = dataset.label_names
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
label_names = None
|
label_names = None
|
||||||
attacked_text = AttackedText(
|
attacked_text = AttackedText(
|
||||||
text, attack_attrs={"label_names": label_names}
|
text_input, attack_attrs={"label_names": label_names}
|
||||||
)
|
)
|
||||||
goal_function_result, _ = self.goal_function.init_attack_example(
|
goal_function_result, _ = self.goal_function.init_attack_example(
|
||||||
attacked_text, ground_truth_output
|
attacked_text, ground_truth_output
|
||||||
@@ -236,9 +237,10 @@ class Attack:
|
|||||||
yield goal_function_result
|
yield goal_function_result
|
||||||
|
|
||||||
except IndexError:
|
except IndexError:
|
||||||
raise IndexError(
|
utils.logger.warn(
|
||||||
f"Out of bounds access of dataset. Size of data is {len(dataset)} but tried to access index {i}"
|
f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
|
||||||
)
|
)
|
||||||
|
break
|
||||||
|
|
||||||
def attack_dataset(self, dataset, indices=None):
|
def attack_dataset(self, dataset, indices=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user