1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

rename things, update README

This commit is contained in:
uvafan
2020-05-21 14:26:56 -04:00
parent 231788575d
commit d32b685dc3
25 changed files with 296 additions and 279 deletions

View File

@@ -56,13 +56,13 @@ class Attack:
self.search_method.get_transformations = self.get_transformations
self.search_method.get_goal_results = self.goal_function.get_results
def get_transformations(self, text, original_text=None, **kwargs):
def get_transformations(self, current_text, original_text=None, **kwargs):
"""
Applies ``self.transformation`` to ``current_text``, then filters the list of possible transformations
through the applicable constraints.
Args:
text: The current ``TokenizedText`` on which to perform the transformations.
current_text: The current ``TokenizedText`` on which to perform the transformations.
original_text: The original ``TokenizedText`` from which the attack started.
apply_constraints: Whether or not to apply post-transformation constraints.
@@ -73,58 +73,58 @@ class Attack:
if not self.transformation:
raise RuntimeError('Cannot call `get_transformations` without a transformation.')
transformations = np.array(self.transformation(text,
transformed_texts = np.array(self.transformation(current_text,
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs))
return self._filter_transformations(transformations, text, original_text)
return self._filter_transformations(transformed_texts, current_text, original_text)
def _filter_transformations_uncached(self, original_transformations, text, original_text=None):
""" Filters a list of potential perturbations based on a list of
transformations. Checks cache first.
Args:
transformations (list: function): a list of transformations
that filter a list of candidate perturbations
text (list: TokenizedText): a list of TokenizedText objects
representation potential perturbations
"""
transformations = original_transformations[:]
for C in self.constraints:
if len(transformations) == 0: break
transformations = C.call_many(text, transformations, original_text=original_text)
# Default to false for all original transformations.
for original_transformation in original_transformations:
self.constraints_cache[original_transformation] = False
# Set unfiltered transformations to True in the cache.
for successful_transformation in transformations:
self.constraints_cache[successful_transformation] = True
return transformations
def _filter_transformations(self, transformations, text, original_text=None):
def _filter_transformations_uncached(self, transformed_texts, current_text, original_text=None):
"""
Filters a list of potential perturbations based on a list of
transformations. Checks cache first.
Filters a list of potential transformed texts based on ``self.constraints``\.
Args:
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
current_text: The current ``TokenizedText`` on which the transformation was applied.
original_text: The original ``TokenizedText`` from which the attack started.
"""
filtered_texts = transformed_texts[:]
for C in self.constraints:
if len(filtered_texts) == 0: break
filtered_texts = C.call_many(filtered_texts, current_text,
original_text=original_text)
# Default to false for all original transformations.
for original_transformed_text in transformed_texts:
self.constraints_cache[original_transformed_text] = False
# Set transformed texts that passed every constraint to True in the cache.
for filtered_text in filtered_texts:
self.constraints_cache[filtered_text] = True
return filtered_texts
def _filter_transformations(self, transformed_texts, current_text, original_text=None):
"""
Filters a list of potential transformed texts based on ``self.constraints``\.
Checks cache first.
Args:
transformations (list: function): a list of transformations
that filter a list of candidate perturbations
text (list: TokenizedText): a list of TokenizedText objects
representation potential perturbations
transformed_texts: A list of candidate transformed ``TokenizedText``\s to filter.
current_text: The current ``TokenizedText`` on which the transformation was applied.
original_text: The original ``TokenizedText`` from which the attack started.
"""
# Populate cache with transformations.
uncached_transformations = []
for t in transformations:
if t not in self.constraints_cache:
uncached_transformations.append(t)
# Populate cache with transformed_texts
uncached_texts = []
for transformed_text in transformed_texts:
if transformed_text not in self.constraints_cache:
uncached_texts.append(transformed_text)
else:
# promote t to the top of the LRU cache
self.constraints_cache[t] = self.constraints_cache[t]
self._filter_transformations_uncached(uncached_transformations, text, original_text=original_text)
# Return transformations from cache.
filtered_transformations = [t for t in transformations if self.constraints_cache[t]]
# Sort transformations to ensure order is preserved between runs.
filtered_transformations.sort(key=lambda t: t.text)
return filtered_transformations
# promote transformed_text to the top of the LRU cache
self.constraints_cache[transformed_text] = self.constraints_cache[transformed_text]
self._filter_transformations_uncached(uncached_texts, current_text,
original_text=original_text)
# Return transformed_texts from cache
filtered_texts = [t for t in transformed_texts if self.constraints_cache[t]]
# Sort transformations to ensure order is preserved between runs
filtered_texts.sort(key=lambda t: t.text)
return filtered_texts
def attack_one(self, initial_result):
"""