1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge with master

This commit is contained in:
uvafan
2020-06-28 20:43:07 -04:00
72 changed files with 11747 additions and 504 deletions

View File

@@ -29,6 +29,7 @@ class Attack:
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
transformation: The transformation applied at each step of the attack.
search_method: A strategy for exploring the search space of possible perturbations
constraint_cache_size (int): the number of items to keep in the constraints cache
"""
def __init__(
@@ -37,6 +38,7 @@ class Attack:
constraints=[],
transformation=None,
search_method=None,
constraint_cache_size=2 ** 18,
):
""" Initialize an attack object. Attacks can be run multiple times. """
self.goal_function = goal_function
@@ -70,7 +72,8 @@ class Attack:
else:
self.constraints.append(constraint)
self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
self.constraint_cache_size = constraint_cache_size
self.constraints_cache = lru.LRU(constraint_cache_size)
# Give search method access to functions for getting transformations and evaluating them
self.search_method.get_transformations = self.get_transformations
@@ -126,10 +129,10 @@ class Attack:
)
# Default to false for all original transformations.
for original_transformed_text in transformed_texts:
self.constraints_cache[original_transformed_text] = False
self.constraints_cache[(current_text, original_transformed_text)] = False
# Set unfiltered transformations to True in the cache.
for filtered_text in filtered_texts:
self.constraints_cache[filtered_text] = True
self.constraints_cache[(current_text, filtered_text)] = True
return filtered_texts
def _filter_transformations(
@@ -147,18 +150,20 @@ class Attack:
# Populate cache with transformed_texts
uncached_texts = []
for transformed_text in transformed_texts:
if transformed_text not in self.constraints_cache:
if (current_text, transformed_text) not in self.constraints_cache:
uncached_texts.append(transformed_text)
else:
# promote transformed_text to the top of the LRU cache
self.constraints_cache[transformed_text] = self.constraints_cache[
transformed_text
]
self.constraints_cache[
(current_text, transformed_text)
] = self.constraints_cache[(current_text, transformed_text)]
self._filter_transformations_uncached(
uncached_texts, current_text, original_text=original_text
)
# Return transformed_texts from cache
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
filtered_texts = [
t for t in transformed_texts if self.constraints_cache[(current_text, t)]
]
# Sort transformations to ensure order is preserved between runs
filtered_texts.sort(key=lambda t: t.text)
return filtered_texts