1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Raise a warning (not an error) when there are too few samples to attack

This commit is contained in:
Jack Morris
2020-07-07 22:08:51 -04:00
parent 7e3b90e979
commit 58609feac1
2 changed files with 10 additions and 8 deletions

View File

@@ -428,7 +428,7 @@ def parse_logger_from_args(args):
color_method = None if args.enable_csv == "plain" else "file" color_method = None if args.enable_csv == "plain" else "file"
csv_path = os.path.join(args.out_dir, outfile_name) csv_path = os.path.join(args.out_dir, outfile_name)
attack_log_manager.add_output_csv(csv_path, color_method) attack_log_manager.add_output_csv(csv_path, color_method)
print("Logging to CSV at path {}.".format(csv_path)) textattack.shared.logger.info(f"Logging to CSV at path {csv_path}.")
# Visdom # Visdom
if args.enable_visdom: if args.enable_visdom:

View File

@@ -205,15 +205,16 @@ class Attack:
Gets examples from a dataset and tokenizes them. Gets examples from a dataset and tokenizes them.
Args: Args:
dataset: An iterable of (text, ground_truth_output) pairs dataset: An iterable of (text_input, ground_truth_output) pairs
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset. indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
Returns: Returns:
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
""" """
indices = indices if indices else deque(range(len(dataset))) indices = indices or range(len(dataset))
if not isinstance(indices, deque): if not isinstance(indices, deque):
indices = deque(indices) indices = deque(sorted(indices))
if not indices: if not indices:
return return
yield yield
@@ -221,14 +222,14 @@ class Attack:
while indices: while indices:
i = indices.popleft() i = indices.popleft()
try: try:
text, ground_truth_output = dataset[i] text_input, ground_truth_output = dataset[i]
try: try:
# get label names from dataset, if possible # get label names from dataset, if possible
label_names = dataset.label_names label_names = dataset.label_names
except AttributeError: except AttributeError:
label_names = None label_names = None
attacked_text = AttackedText( attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names} text_input, attack_attrs={"label_names": label_names}
) )
goal_function_result, _ = self.goal_function.init_attack_example( goal_function_result, _ = self.goal_function.init_attack_example(
attacked_text, ground_truth_output attacked_text, ground_truth_output
@@ -236,9 +237,10 @@ class Attack:
yield goal_function_result yield goal_function_result
except IndexError: except IndexError:
raise IndexError( utils.logger.warn(
f"Out of bounds access of dataset. Size of data is {len(dataset)} but tried to access index {i}" f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
) )
break
def attack_dataset(self, dataset, indices=None): def attack_dataset(self, dataset, indices=None):
""" """