1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

refactor pso and transformation caching

This commit is contained in:
Jin Yong Yoo
2020-07-25 08:57:52 -04:00
parent 613bbf0b88
commit 5a8d74d288
22 changed files with 166 additions and 116 deletions

View File

@@ -11,7 +11,7 @@ from textattack.attack_results import (
)
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils
from textattack.transformations import CompositeTransformation, RandomTransformation
from textattack.transformations import CompositeTransformation
class Attack:
@@ -75,17 +75,12 @@ class Attack:
self.constraints.append(constraint)
# Check if we can use transformation cache for our transformation.
if isinstance(self.transformation, RandomTransformation) or (
hasattr(self.transformation, "random_one")
and self.transformation.random_one
):
if not self.transformation.deterministic:
self.use_transformation_cache = False
elif isinstance(self.transformation, CompositeTransformation):
self.use_transformation_cache = True
for t in self.transformation.transformations:
if isinstance(t, RandomTransformation) or (
hasattr(t, "random_one") and t.random_one
):
if not t.deterministic:
self.use_transformation_cache = False
break
else:
@@ -140,17 +135,17 @@ class Attack:
if self.use_transformation_cache:
cache_key = tuple([current_text] + sorted(kwargs.items()))
if hashable(cache_key) and cache_key in self.transformation_cache:
if utils.hashable(cache_key) and cache_key in self.transformation_cache:
# promote transformed_text to the top of the LRU cache
self.transformation_cache[cache_key] = self.transformation_cache[
cache_key
]
transformed_texts = list(self.transformation_cache[current_text])
transformed_texts = list(self.transformation_cache[cache_key])
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs
)
if hashable(cache_key):
if utils.hashable(cache_key):
self.transformation_cache[cache_key] = tuple(transformed_texts)
else:
transformed_texts = self._get_transformations_uncached(
@@ -337,11 +332,3 @@ class Attack:
return main_str
__str__ = __repr__
def hashable(key):
try:
hash(key)
return True
except TypeError:
return False