1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge in master & delete flaky test

This commit is contained in:
Jack Morris
2020-06-17 18:01:51 -04:00
parent 40b5640381
commit 2bbd06b0c2
19 changed files with 238 additions and 167 deletions

View File

@@ -1,5 +1,5 @@
from collections import deque
import os
import random
import lru
import numpy as np
@@ -182,85 +182,64 @@ class Attack:
initial_result, final_result, self.goal_function.num_queries
)
def _get_examples_from_dataset(
self,
dataset,
num_examples=None,
shuffle=False,
attack_n=False,
attack_skippable_examples=False,
):
def _get_examples_from_dataset(self, dataset, indices=None):
"""
Gets examples from a dataset and tokenizes them.
Args:
dataset: An iterable of (text, ground_truth_output) pairs
num_examples (int): the number of examples to return
shuffle (:obj:`bool`, optional): Whether to shuffle the data
attack_n (bool): If `True`, returns `num_examples` non-skipped
examples. If `False`, returns `num_examples` total examples.
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
Returns:
results (Iterable[Tuple[GoalFunctionResult, Boolean]]): a list of
objects containing (text, ground_truth_output, was_skipped)
results (Iterable[GoalFunctionResult]): an iterable of GoalFunctionResults of the original examples
"""
examples = []
n = 0
if shuffle:
random.shuffle(dataset.examples)
num_examples = num_examples or len(dataset)
if num_examples <= 0:
indices = indices if indices else deque(range(len(dataset)))
if not isinstance(indices, deque):
indices = deque(indices)
if not indices:
return
yield
for text, ground_truth_output in dataset:
attacked_text = AttackedText(text)
self.goal_function.num_queries = 0
goal_function_result, _ = self.goal_function.get_result(
attacked_text, ground_truth_output
)
# We can skip examples for which the goal is already succeeded,
# unless `attack_skippable_examples` is True.
if (not attack_skippable_examples) and (goal_function_result.succeeded):
if not attack_n:
n += 1
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield (goal_function_result, True)
else:
n += 1
yield (goal_function_result, False)
if num_examples is not None and (n >= num_examples):
break
while indices:
i = indices.popleft()
try:
text, ground_truth_output = dataset[i]
tokenized_text = AttackedText(text)
self.goal_function.num_queries = 0
goal_function_result, _ = self.goal_function.get_result(
tokenized_text, ground_truth_output
)
if goal_function_result.succeeded:
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield goal_function_result
def attack_dataset(self, dataset, num_examples=None, shuffle=False, attack_n=False):
except IndexError:
raise IndexError(
"Out of bounds access of dataset. Size of data is {} but tried to access index {}".format(
len(dataset), i
)
)
def attack_dataset(self, dataset, indices=None):
"""
Runs an attack on the given dataset and outputs the results to the
console and the output file.
Args:
dataset: An iterable of (text, ground_truth_output) pairs.
num_examples: The number of samples to attack.
shuffle (:obj:`bool`, optional): Whether to shuffle the data. Defaults to False.
attack_n: Whether or not to attack ``num_examples`` examples. If false, will process
``num_examples`` examples including ones which are skipped due to the model
mispredicting the original sample.
indices: An iterable of indices of the dataset that we want to attack. If None, attack all samples in dataset.
"""
examples = self._get_examples_from_dataset(
dataset, num_examples=num_examples, shuffle=shuffle, attack_n=attack_n
)
examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result, was_skipped in examples:
if was_skipped:
for goal_function_result in examples:
if goal_function_result.succeeded:
yield SkippedAttackResult(goal_function_result)
continue
result = self.attack_one(goal_function_result)
yield result
else:
result = self.attack_one(goal_function_result)
yield result
def __repr__(self):
"""