mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge in master & delete flaky test
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from collections import deque
|
||||
import os
|
||||
import random
|
||||
|
||||
import lru
|
||||
import numpy as np
|
||||
@@ -182,85 +182,64 @@ class Attack:
|
||||
initial_result, final_result, self.goal_function.num_queries
|
||||
)
|
||||
|
||||
def _get_examples_from_dataset(
|
||||
self,
|
||||
dataset,
|
||||
num_examples=None,
|
||||
shuffle=False,
|
||||
attack_n=False,
|
||||
attack_skippable_examples=False,
|
||||
):
|
||||
def _get_examples_from_dataset(self, dataset, indices=None):
|
||||
"""
|
||||
Gets examples from a dataset and tokenizes them.
|
||||
|
||||
Args:
|
||||
dataset: An iterable of (text, ground_truth_output) pairs
|
||||
num_examples (int): the number of examples to return
|
||||
shuffle (:obj:`bool`, optional): Whether to shuffle the data
|
||||
attack_n (bool): If `True`, returns `num_examples` non-skipped
|
||||
examples. If `False`, returns `num_examples` total examples.
|
||||
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
|
||||
|
||||
Returns:
|
||||
results (Iterable[Tuple[GoalFunctionResult, Boolean]]): a list of
|
||||
objects containing (text, ground_truth_output, was_skipped)
|
||||
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
|
||||
"""
|
||||
examples = []
|
||||
n = 0
|
||||
|
||||
if shuffle:
|
||||
random.shuffle(dataset.examples)
|
||||
|
||||
num_examples = num_examples or len(dataset)
|
||||
|
||||
if num_examples <= 0:
|
||||
indices = indices if indices else deque(range(len(dataset)))
|
||||
if not isinstance(indices, deque):
|
||||
indices = deque(indices)
|
||||
if not indices:
|
||||
return
|
||||
yield
|
||||
|
||||
for text, ground_truth_output in dataset:
|
||||
attacked_text = AttackedText(text)
|
||||
self.goal_function.num_queries = 0
|
||||
goal_function_result, _ = self.goal_function.get_result(
|
||||
attacked_text, ground_truth_output
|
||||
)
|
||||
# We can skip examples for which the goal is already succeeded,
|
||||
# unless `attack_skippable_examples` is True.
|
||||
if (not attack_skippable_examples) and (goal_function_result.succeeded):
|
||||
if not attack_n:
|
||||
n += 1
|
||||
# Store the true output on the goal function so that the
|
||||
# SkippedAttackResult has the correct output, not the incorrect.
|
||||
goal_function_result.output = ground_truth_output
|
||||
yield (goal_function_result, True)
|
||||
else:
|
||||
n += 1
|
||||
yield (goal_function_result, False)
|
||||
if num_examples is not None and (n >= num_examples):
|
||||
break
|
||||
while indices:
|
||||
i = indices.popleft()
|
||||
try:
|
||||
text, ground_truth_output = dataset[i]
|
||||
tokenized_text = AttackedText(text)
|
||||
self.goal_function.num_queries = 0
|
||||
goal_function_result, _ = self.goal_function.get_result(
|
||||
tokenized_text, ground_truth_output
|
||||
)
|
||||
if goal_function_result.succeeded:
|
||||
# Store the true output on the goal function so that the
|
||||
# SkippedAttackResult has the correct output, not the incorrect.
|
||||
goal_function_result.output = ground_truth_output
|
||||
yield goal_function_result
|
||||
|
||||
def attack_dataset(self, dataset, num_examples=None, shuffle=False, attack_n=False):
|
||||
except IndexError:
|
||||
raise IndexError(
|
||||
"Out of bounds access of dataset. Size of data is {} but tried to access index {}".format(
|
||||
len(dataset), i
|
||||
)
|
||||
)
|
||||
|
||||
def attack_dataset(self, dataset, indices=None):
|
||||
"""
|
||||
Runs an attack on the given dataset and outputs the results to the
|
||||
console and the output file.
|
||||
|
||||
Args:
|
||||
dataset: An iterable of (text, ground_truth_output) pairs.
|
||||
num_examples: The number of samples to attack.
|
||||
shuffle (:obj:`bool`, optional): Whether to shuffle the data. Defaults to False.
|
||||
attack_n: Whether or not to attack ``num_examples`` examples. If false, will process
|
||||
``num_examples`` examples including ones which are skipped due to the model
|
||||
mispredicting the original sample.
|
||||
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
|
||||
"""
|
||||
|
||||
examples = self._get_examples_from_dataset(
|
||||
dataset, num_examples=num_examples, shuffle=shuffle, attack_n=attack_n
|
||||
)
|
||||
examples = self._get_examples_from_dataset(dataset, indices=indices)
|
||||
|
||||
for goal_function_result, was_skipped in examples:
|
||||
if was_skipped:
|
||||
for goal_function_result in examples:
|
||||
if goal_function_result.succeeded:
|
||||
yield SkippedAttackResult(goal_function_result)
|
||||
continue
|
||||
result = self.attack_one(goal_function_result)
|
||||
yield result
|
||||
else:
|
||||
result = self.attack_one(goal_function_result)
|
||||
yield result
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user