Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
makefile and setup; need to fix imports
@@ -1,13 +1,18 @@
-import lru
-import numpy as np
 import os
 import random
 
-from textattack.shared import utils
+import lru
+import numpy as np
+
+from textattack.attack_results import (
+    FailedAttackResult,
+    SkippedAttackResult,
+    SuccessfulAttackResult,
+)
+from textattack.constraints import Constraint
+from textattack.constraints.pre_transformation import PreTransformationConstraint
 from textattack.shared import TokenizedText
-from textattack.attack_results import SkippedAttackResult, SuccessfulAttackResult, FailedAttackResult
-from textattack.shared import TokenizedText, utils
 
 
 class Attack:
     """
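A note on the commit title ("need to fix imports"): in the reordered import block above, utils is no longer imported from textattack.shared, yet the unchanged code further down still calls utils.config(...) and utils.add_indent(...). A plausible follow-up fix, which is an assumption and not part of this commit, would be a combined import:

# assumed follow-up, not shown in this diff
from textattack.shared import TokenizedText, utils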
@@ -25,21 +30,33 @@ class Attack:
         search_method: A strategy for exploring the search space of possible perturbations
     """
 
-    def __init__(self, goal_function=None, constraints=[], transformation=None, search_method=None):
+    def __init__(
+        self,
+        goal_function=None,
+        constraints=[],
+        transformation=None,
+        search_method=None,
+    ):
         """ Initialize an attack object. Attacks can be run multiple times. """
         self.goal_function = goal_function
         if not self.goal_function:
-            raise NameError('Cannot instantiate attack without self.goal_function for predictions')
+            raise NameError(
+                "Cannot instantiate attack without self.goal_function for predictions"
+            )
         self.search_method = search_method
         if not self.search_method:
-            raise NameError('Cannot instantiate attack without search method')
+            raise NameError("Cannot instantiate attack without search method")
         self.transformation = transformation
         if not self.transformation:
-            raise NameError('Cannot instantiate attack without transformation')
-        self.is_black_box = getattr(transformation, 'is_black_box', True)
+            raise NameError("Cannot instantiate attack without transformation")
+        self.is_black_box = getattr(transformation, "is_black_box", True)
 
-        if not self.search_method.check_transformation_compatibility(self.transformation):
-            raise ValueError('SearchMethod {self.search_method} incompatible with transformation {self.transformation}')
+        if not self.search_method.check_transformation_compatibility(
+            self.transformation
+        ):
+            raise ValueError(
+                "SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
+            )
 
         self.constraints = []
         self.pre_transformation_constraints = []
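For orientation, here is a minimal usage sketch of the constructor above. The component classes are hypothetical stand-ins that only implement what __init__ touches, the import path of Attack is an assumption for this revision, and running it assumes this version of the package and its lru dependency are installed. The point is the validation order: a missing goal function, search method, or transformation raises NameError, and an incompatible search method raises ValueError.

from textattack.shared.attack import Attack  # import path assumed for this revision


class DummyGoalFunction:
    """Hypothetical stand-in; only get_results is wired up by __init__."""

    def get_results(self, *args, **kwargs):
        return []


class DummyTransformation:
    is_black_box = True  # read via getattr(transformation, "is_black_box", True)

    def __call__(self, current_text, pre_transformation_constraints=(), **kwargs):
        return []


class DummySearchMethod:
    def check_transformation_compatibility(self, transformation):
        return True  # accept any transformation

    def __call__(self, initial_result):
        return initial_result


attack = Attack(
    goal_function=DummyGoalFunction(),
    constraints=[],
    transformation=DummyTransformation(),
    search_method=DummySearchMethod(),
)
print(attack)  # uses the __repr__ defined at the bottom of this file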
@@ -48,13 +65,13 @@ class Attack:
                 self.pre_transformation_constraints.append(constraint)
             else:
                 self.constraints.append(constraint)
-
-        self.constraints_cache = lru.LRU(utils.config('CONSTRAINT_CACHE_SIZE'))
-
+
+        self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
+
         # Give search method access to functions for getting transformations and evaluating them
         self.search_method.get_transformations = self.get_transformations
-        self.search_method.get_goal_results = self.goal_function.get_results
-
+        self.search_method.get_goal_results = self.goal_function.get_results
+
     def get_transformations(self, current_text, original_text=None, **kwargs):
         """
         Applies ``self.transformation`` to ``text``, then filters the list of possible transformations
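The two assignments above hand the search method bound references to the attack's own helpers, so a search method can generate candidates and score them without holding a reference to the whole Attack (or, in the black-box case, to the model). A self-contained sketch of that wiring pattern in plain Python (generic names, not TextAttack classes):

class Searcher:
    # Filled in by the owner; the searcher only knows the call signatures.
    get_transformations = None
    get_goal_results = None

    def __call__(self, text):
        candidates = self.get_transformations(text)
        return self.get_goal_results(candidates)


class Owner:
    def get_transformations(self, text):
        return [text + "!", text + "?"]

    def get_goal_results(self, candidates):
        return [(c, len(c)) for c in candidates]


owner = Owner()
searcher = Searcher()
# Same wiring as in __init__ above: hand the searcher bound methods.
searcher.get_transformations = owner.get_transformations
searcher.get_goal_results = owner.get_goal_results
print(searcher("hi"))  # [('hi!', 3), ('hi?', 3)]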
@@ -70,14 +87,24 @@ class Attack:
 
         """
         if not self.transformation:
-            raise RuntimeError('Cannot call `get_transformations` without a transformation.')
-
-        transformed_texts = np.array(self.transformation(current_text,
-            pre_transformation_constraints=self.pre_transformation_constraints,
-            **kwargs))
-        return self._filter_transformations(transformed_texts, current_text, original_text)
-
-    def _filter_transformations_uncached(self, transformed_texts, current_text, original_text=None):
+            raise RuntimeError(
+                "Cannot call `get_transformations` without a transformation."
+            )
+
+        transformed_texts = np.array(
+            self.transformation(
+                current_text,
+                pre_transformation_constraints=self.pre_transformation_constraints,
+                **kwargs,
+            )
+        )
+        return self._filter_transformations(
+            transformed_texts, current_text, original_text
+        )
+
+    def _filter_transformations_uncached(
+        self, transformed_texts, current_text, original_text=None
+    ):
         """
         Filters a list of potential transformaed texts based on ``self.constraints``\.
 
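get_transformations treats self.transformation as a callable that takes the current text plus a pre_transformation_constraints keyword; in this design the pre-transformation constraints restrict what the transformation may touch before any candidates are generated, while the regular constraints filter candidates afterwards. A rough, self-contained sketch of that calling convention (both classes below are hypothetical, not the library's base classes):

class NoStopwordModification:
    """Hypothetical pre-transformation constraint: block a fixed set of words."""

    STOPWORDS = {"the", "a", "of"}

    def allowed_indices(self, words):
        return [i for i, w in enumerate(words) if w.lower() not in self.STOPWORDS]


class SwapWithPlaceholder:
    """Hypothetical transformation mirroring the call made in get_transformations."""

    def __call__(self, current_text, pre_transformation_constraints=(), **kwargs):
        words = current_text.split()
        allowed = set(range(len(words)))
        for constraint in pre_transformation_constraints:
            allowed &= set(constraint.allowed_indices(words))
        candidates = []
        for i in sorted(allowed):
            perturbed = words.copy()
            perturbed[i] = "<swap>"
            candidates.append(" ".join(perturbed))
        return candidates


transform = SwapWithPlaceholder()
print(transform("the cat sat", pre_transformation_constraints=[NoStopwordModification()]))
# ['the <swap> sat', 'the cat <swap>']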
@@ -88,18 +115,22 @@ class Attack:
         """
         filtered_texts = transformed_texts[:]
         for C in self.constraints:
-            if len(filtered_texts) == 0: break
-            filtered_texts = C.call_many(filtered_texts, current_text,
-                original_text=original_text)
+            if len(filtered_texts) == 0:
+                break
+            filtered_texts = C.call_many(
+                filtered_texts, current_text, original_text=original_text
+            )
         # Default to false for all original transformations.
         for original_transformed_text in transformed_texts:
             self.constraints_cache[original_transformed_text] = False
         # Set unfiltered transformations to True in the cache.
         for filtered_text in filtered_texts:
             self.constraints_cache[filtered_text] = True
-        return filtered_texts
-
-    def _filter_transformations(self, transformed_texts, current_text, original_text=None):
+        return filtered_texts
+
+    def _filter_transformations(
+        self, transformed_texts, current_text, original_text=None
+    ):
         """
         Filters a list of potential transformed texts based on ``self.constraints``\.
         Checks cache first.
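_filter_transformations_uncached narrows the candidate list constraint by constraint and bails out as soon as nothing survives; call_many is expected to return only the candidates that satisfy the constraint relative to the current (and optionally original) text. A small self-contained sketch of that batch-filtering contract, using a hypothetical constraint rather than the library's Constraint base class:

class MaxLengthDelta:
    """Hypothetical constraint: a candidate may grow by at most `delta` characters."""

    def __init__(self, delta):
        self.delta = delta

    def call_many(self, transformed_texts, current_text, original_text=None):
        reference = original_text if original_text is not None else current_text
        return [t for t in transformed_texts if len(t) - len(reference) <= self.delta]


constraints = [MaxLengthDelta(2), MaxLengthDelta(1)]
candidates = ["abcd", "abcdef", "abc!"]
current = "abc"

filtered = candidates[:]
for C in constraints:
    if len(filtered) == 0:
        break
    filtered = C.call_many(filtered, current)
print(filtered)  # ['abcd', 'abc!']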
@@ -116,9 +147,12 @@ class Attack:
                 uncached_texts.append(transformed_text)
             else:
                 # promote transformed_text to the top of the LRU cache
-                self.constraints_cache[transformed_text] = self.constraints_cache[transformed_text]
-        self._filter_transformations_uncached(uncached_texts, current_text,
-            original_text=original_text)
+                self.constraints_cache[transformed_text] = self.constraints_cache[
+                    transformed_text
+                ]
+        self._filter_transformations_uncached(
+            uncached_texts, current_text, original_text=original_text
+        )
         # Return transformed_texts from cache
         filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
         # Sort transformations to ensure order is preserved between runs
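The no-op-looking self-assignment above ("promote transformed_text to the top of the LRU cache") works because the lru package's LRU mapping treats every set (and get) as a use, so cached verdicts that keep being consulted stay away from the eviction end. A tiny sketch of that behavior, assuming the lru-dict package that provides the lru module imported at the top of this file:

from lru import LRU  # lru-dict package; the same module as `import lru` above

cache = LRU(2)           # keep at most 2 entries
cache["a"] = True
cache["b"] = False
cache["a"] = cache["a"]  # self-assignment: "a" is now the most recently used entry
cache["c"] = True        # evicts the least recently used entry, which is now "b"
print("a" in cache, "b" in cache)  # True False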
@@ -138,12 +172,22 @@ class Attack:
         """
         final_result = self.search_method(initial_result)
         if final_result.succeeded:
-            return SuccessfulAttackResult(initial_result, final_result, self.goal_function.num_queries)
+            return SuccessfulAttackResult(
+                initial_result, final_result, self.goal_function.num_queries
+            )
         else:
-            return FailedAttackResult(initial_result, final_result, self.goal_function.num_queries)
-
-    def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
-            attack_n=False, attack_skippable_examples=False):
+            return FailedAttackResult(
+                initial_result, final_result, self.goal_function.num_queries
+            )
+
+    def _get_examples_from_dataset(
+        self,
+        dataset,
+        num_examples=None,
+        shuffle=False,
+        attack_n=False,
+        attack_skippable_examples=False,
+    ):
         """
         Gets examples from a dataset and tokenizes them.
 
@@ -160,24 +204,26 @@ class Attack:
         """
         examples = []
         n = 0
-
+
         if shuffle:
            random.shuffle(dataset.examples)
-
-        num_examples = num_examples or len(dataset)
+
+        num_examples = num_examples or len(dataset)
 
         if num_examples <= 0:
-            return
+            yield
 
 
         for text, ground_truth_output in dataset:
             tokenized_text = TokenizedText(text, self.goal_function.tokenizer)
             self.goal_function.num_queries = 0
-            goal_function_result, _ = self.goal_function.get_result(tokenized_text, ground_truth_output)
+            goal_function_result, _ = self.goal_function.get_result(
+                tokenized_text, ground_truth_output
+            )
             # We can skip examples for which the goal is already succeeded,
             # unless `attack_skippable_examples` is True.
             if (not attack_skippable_examples) and (goal_function_result.succeeded):
-                if not attack_n:
+                if not attack_n:
                     n += 1
                     # Store the true output on the goal function so that the
                     # SkippedAttackResult has the correct output, not the incorrect.
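The branch above skips examples where the goal is already met before any perturbation (for example, the model already mispredicts the input) unless attack_skippable_examples is set, and attack_n changes the budget so that skipped rows do not count toward num_examples. The rest of the loop is truncated in this hunk, so the following is only a rough, self-contained sketch of that counting logic with a hypothetical helper name, not the method's actual body:

def iter_examples(dataset, num_examples, attack_n=False, attack_skippable_examples=False):
    """Yield (example, was_skipped) until enough examples have been produced."""
    n = 0
    for text, already_misclassified in dataset:
        if already_misclassified and not attack_skippable_examples:
            if not attack_n:
                # Skipped rows still count toward the budget in this mode.
                n += 1
            yield (text, True)
        else:
            n += 1
            yield (text, False)
        if n >= num_examples:
            return


data = [("good", False), ("oops", True), ("fine", False), ("more", False)]
print([x for x in iter_examples(data, 2)])
# [('good', False), ('oops', True)]                   -> the skipped row used up budget
print([x for x in iter_examples(data, 2, attack_n=True)])
# [('good', False), ('oops', True), ('fine', False)]  -> skipped rows don't count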
@@ -202,9 +248,10 @@ class Attack:
         ``num_examples`` examples including ones which are skipped due to the model
         mispredicting the original sample.
         """
 
-        examples = self._get_examples_from_dataset(dataset,
-            num_examples=num_examples, shuffle=shuffle, attack_n=attack_n)
-
+        examples = self._get_examples_from_dataset(
+            dataset, num_examples=num_examples, shuffle=shuffle, attack_n=attack_n
+        )
+
         for goal_function_result, was_skipped in examples:
             if was_skipped:
@@ -212,7 +259,7 @@ class Attack:
                 continue
             result = self.attack_one(goal_function_result)
             yield result
-
+
     def __repr__(self):
         """
         Prints attack parameters in a human-readable string.
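attack_dataset (in the two hunks above) is a generator: each attack result is yielded as soon as it is produced, so results can be streamed or logged incrementally rather than collected into a list. A self-contained consumption sketch; FakeAttack is a hypothetical stand-in that only mirrors the generator interface and keyword arguments visible in the diff:

class FakeAttack:
    """Hypothetical stand-in mirroring the streaming interface of attack_dataset."""

    def attack_dataset(self, dataset, num_examples=None, shuffle=False, attack_n=False):
        for text, label in dataset[:num_examples]:
            yield f"Attack result for {text!r} (ground truth {label})"


dataset = [("a movie review", 1), ("another review", 0), ("a third review", 1)]
for i, result in enumerate(FakeAttack().attack_dataset(dataset, num_examples=2)):
    print(f"--- Result {i} ---")
    print(result)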
@@ -220,34 +267,28 @@ class Attack:
         Inspired by the readability of printing PyTorch nn.Modules:
         https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
         """
-        main_str = 'Attack' + '('
+        main_str = "Attack" + "("
         lines = []
-
-        lines.append(
-            utils.add_indent(f'(search_method): {self.search_method}', 2)
-        )
+
+        lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
         # self.goal_function
-        lines.append(
-            utils.add_indent(f'(goal_function): {self.goal_function}', 2)
-        )
+        lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
         # self.transformation
-        lines.append(
-            utils.add_indent(f'(transformation): {self.transformation}', 2)
-        )
+        lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
         # self.constraints
         constraints_lines = []
         constraints = self.constraints + self.pre_transformation_constraints
         if len(constraints):
             for i, constraint in enumerate(constraints):
-                constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))
-            constraints_str = utils.add_indent('\n' + '\n'.join(constraints_lines), 2)
+                constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
+            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
         else:
-            constraints_str = 'None'
-        lines.append(utils.add_indent(f'(constraints): {constraints_str}', 2))
+            constraints_str = "None"
+        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
         # self.is_black_box
-        lines.append(utils.add_indent(f'(is_black_box): {self.is_black_box}', 2))
-        main_str += '\n ' + '\n '.join(lines) + '\n'
-        main_str += ')'
+        lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
+        main_str += "\n " + "\n ".join(lines) + "\n"
+        main_str += ")"
         return main_str
-
+
     __str__ = __repr__
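__repr__ assembles a PyTorch-nn.Module-style printout by indenting one line per component; utils.add_indent is the library helper doing the indenting. A generic, self-contained sketch of the same string-building idea (add_indent below is a local stand-in, not the library function, and the component values are placeholders):

def add_indent(text, num_spaces):
    # Local stand-in for utils.add_indent: indent every line of `text`.
    pad = " " * num_spaces
    return "\n".join(pad + line for line in text.split("\n"))


components = {
    "search_method": "<search method repr>",
    "goal_function": "<goal function repr>",
    "is_black_box": True,
}
lines = [add_indent(f"({name}): {value}", 2) for name, value in components.items()]
print("Attack(" + "\n  " + "\n  ".join(lines) + "\n" + ")")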