mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
rename things, update README
This commit is contained in:
@@ -56,13 +56,13 @@ class Attack:
|
||||
self.search_method.get_transformations = self.get_transformations
|
||||
self.search_method.get_goal_results = self.goal_function.get_results
|
||||
|
||||
def get_transformations(self, text, original_text=None, **kwargs):
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
"""
|
||||
Applies ``self.transformation`` to ``text``, then filters the list of possible transformations
|
||||
through the applicable constraints.
|
||||
|
||||
Args:
|
||||
text: The current ``TokenizedText`` on which to perform the transformations.
|
||||
current_text: The current ``TokenizedText`` on which to perform the transformations.
|
||||
original_text: The original ``TokenizedText`` from which the attack started.
|
||||
apply_constraints: Whether or not to apply post-transformation constraints.
|
||||
|
||||
@@ -73,58 +73,58 @@ class Attack:
|
||||
if not self.transformation:
|
||||
raise RuntimeError('Cannot call `get_transformations` without a transformation.')
|
||||
|
||||
transformations = np.array(self.transformation(text,
|
||||
transformed_texts = np.array(self.transformation(current_text,
|
||||
pre_transformation_constraints=self.pre_transformation_constraints,
|
||||
**kwargs))
|
||||
return self._filter_transformations(transformations, text, original_text)
|
||||
return self._filter_transformations(transformed_texts, current_text, original_text)
|
||||
|
||||
def _filter_transformations_uncached(self, original_transformations, text, original_text=None):
|
||||
""" Filters a list of potential perturbations based on a list of
|
||||
transformations. Checks cache first.
|
||||
|
||||
Args:
|
||||
transformations (list: function): a list of transformations
|
||||
that filter a list of candidate perturbations
|
||||
text (list: TokenizedText): a list of TokenizedText objects
|
||||
representation potential perturbations
|
||||
"""
|
||||
transformations = original_transformations[:]
|
||||
for C in self.constraints:
|
||||
if len(transformations) == 0: break
|
||||
transformations = C.call_many(text, transformations, original_text=original_text)
|
||||
# Default to false for all original transformations.
|
||||
for original_transformation in original_transformations:
|
||||
self.constraints_cache[original_transformation] = False
|
||||
# Set unfiltered transformations to True in the cache.
|
||||
for successful_transformation in transformations:
|
||||
self.constraints_cache[successful_transformation] = True
|
||||
return transformations
|
||||
|
||||
def _filter_transformations(self, transformations, text, original_text=None):
|
||||
def _filter_transformations_uncached(self, transformed_texts, current_text, original_text=None):
|
||||
"""
|
||||
Filters a list of potential perturbations based on a list of
|
||||
transformations. Checks cache first.
|
||||
Filters a list of potential transformaed texts based on ``self.constraints``\.
|
||||
|
||||
Args:
|
||||
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
|
||||
current_text: The current ``TokenizedText`` on which the transformation was applied.
|
||||
original_text: The original ``TokenizedText`` from which the attack started.
|
||||
"""
|
||||
filtered_texts = transformed_texts[:]
|
||||
for C in self.constraints:
|
||||
if len(filtered_texts) == 0: break
|
||||
filtered_texts = C.call_many(filtered_texts, current_text,
|
||||
original_text=original_text)
|
||||
# Default to false for all original transformations.
|
||||
for original_transformed_text in transformed_texts:
|
||||
self.constraints_cache[original_transformed_text] = False
|
||||
# Set unfiltered transformations to True in the cache.
|
||||
for filtered_text in filtered_texts:
|
||||
self.constraints_cache[filtered_text] = True
|
||||
return filtered_texts
|
||||
|
||||
def _filter_transformations(self, transformed_texts, current_text, original_text=None):
|
||||
"""
|
||||
Filters a list of potential transformed texts based on ``self.constraints``\.
|
||||
Checks cache first.
|
||||
|
||||
Args:
|
||||
transformations (list: function): a list of transformations
|
||||
that filter a list of candidate perturbations
|
||||
text (list: TokenizedText): a list of TokenizedText objects
|
||||
representation potential perturbations
|
||||
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
|
||||
current_text: The current ``TokenizedText`` on which the transformation was applied.
|
||||
original_text: The original ``TokenizedText`` from which the attack started.
|
||||
"""
|
||||
# Populate cache with transformations.
|
||||
uncached_transformations = []
|
||||
for t in transformations:
|
||||
if t not in self.constraints_cache:
|
||||
uncached_transformations.append(t)
|
||||
# Populate cache with transformed_texts
|
||||
uncached_texts = []
|
||||
for transformed_text in transformed_texts:
|
||||
if transformed_text not in self.constraints_cache:
|
||||
uncached_texts.append(transformed_text)
|
||||
else:
|
||||
# promote t to the top of the LRU cache
|
||||
self.constraints_cache[t] = self.constraints_cache[t]
|
||||
self._filter_transformations_uncached(uncached_transformations, text, original_text=original_text)
|
||||
# Return transformations from cache.
|
||||
filtered_transformations = [t for t in transformations if self.constraints_cache[t]]
|
||||
# Sort transformations to ensure order is preserved between runs.
|
||||
filtered_transformations.sort(key=lambda t: t.text)
|
||||
return filtered_transformations
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.constraints_cache[transformed_text] = self.constraints_cache[transformed_text]
|
||||
self._filter_transformations_uncached(uncached_texts, current_text,
|
||||
original_text=original_text)
|
||||
# Return transformed_texts from cache
|
||||
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
|
||||
# Sort transformations to ensure order is preserved between runs
|
||||
filtered_texts.sort(key=lambda t: t.text)
|
||||
return filtered_texts
|
||||
|
||||
def attack_one(self, initial_result):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user