mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge with master
This commit is contained in:
@@ -29,6 +29,7 @@ class Attack:
|
||||
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
|
||||
transformation: The transformation applied at each step of the attack.
|
||||
search_method: A strategy for exploring the search space of possible perturbations
|
||||
constraint_cache_size (int): the number of items to keep in the constraints cache
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -37,6 +38,7 @@ class Attack:
|
||||
constraints=[],
|
||||
transformation=None,
|
||||
search_method=None,
|
||||
constraint_cache_size=2 ** 18,
|
||||
):
|
||||
""" Initialize an attack object. Attacks can be run multiple times. """
|
||||
self.goal_function = goal_function
|
||||
@@ -70,7 +72,8 @@ class Attack:
|
||||
else:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
self.constraints_cache = lru.LRU(utils.config("CONSTRAINT_CACHE_SIZE"))
|
||||
self.constraint_cache_size = constraint_cache_size
|
||||
self.constraints_cache = lru.LRU(constraint_cache_size)
|
||||
|
||||
# Give search method access to functions for getting transformations and evaluating them
|
||||
self.search_method.get_transformations = self.get_transformations
|
||||
@@ -126,10 +129,10 @@ class Attack:
|
||||
)
|
||||
# Default to false for all original transformations.
|
||||
for original_transformed_text in transformed_texts:
|
||||
self.constraints_cache[original_transformed_text] = False
|
||||
self.constraints_cache[(current_text, original_transformed_text)] = False
|
||||
# Set unfiltered transformations to True in the cache.
|
||||
for filtered_text in filtered_texts:
|
||||
self.constraints_cache[filtered_text] = True
|
||||
self.constraints_cache[(current_text, filtered_text)] = True
|
||||
return filtered_texts
|
||||
|
||||
def _filter_transformations(
|
||||
@@ -147,18 +150,20 @@ class Attack:
|
||||
# Populate cache with transformed_texts
|
||||
uncached_texts = []
|
||||
for transformed_text in transformed_texts:
|
||||
if transformed_text not in self.constraints_cache:
|
||||
if (current_text, transformed_text) not in self.constraints_cache:
|
||||
uncached_texts.append(transformed_text)
|
||||
else:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.constraints_cache[transformed_text] = self.constraints_cache[
|
||||
transformed_text
|
||||
]
|
||||
self.constraints_cache[
|
||||
(current_text, transformed_text)
|
||||
] = self.constraints_cache[(current_text, transformed_text)]
|
||||
self._filter_transformations_uncached(
|
||||
uncached_texts, current_text, original_text=original_text
|
||||
)
|
||||
# Return transformed_texts from cache
|
||||
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
|
||||
filtered_texts = [
|
||||
t for t in transformed_texts if self.constraints_cache[(current_text, t)]
|
||||
]
|
||||
# Sort transformations to ensure order is preserved between runs
|
||||
filtered_texts.sort(key=lambda t: t.text)
|
||||
return filtered_texts
|
||||
|
||||
Reference in New Issue
Block a user