mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
refactor pso and transformation caching
This commit is contained in:
@@ -11,7 +11,7 @@ from textattack.attack_results import (
|
||||
)
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.shared import AttackedText, utils
|
||||
from textattack.transformations import CompositeTransformation, RandomTransformation
|
||||
from textattack.transformations import CompositeTransformation
|
||||
|
||||
|
||||
class Attack:
|
||||
@@ -75,17 +75,12 @@ class Attack:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
# Check if we can use transformation cache for our transformation.
|
||||
if isinstance(self.transformation, RandomTransformation) or (
|
||||
hasattr(self.transformation, "random_one")
|
||||
and self.transformation.random_one
|
||||
):
|
||||
if not self.transformation.deterministic:
|
||||
self.use_transformation_cache = False
|
||||
elif isinstance(self.transformation, CompositeTransformation):
|
||||
self.use_transformation_cache = True
|
||||
for t in self.transformation.transformations:
|
||||
if isinstance(t, RandomTransformation) or (
|
||||
hasattr(t, "random_one") and t.random_one
|
||||
):
|
||||
if not t.deterministic:
|
||||
self.use_transformation_cache = False
|
||||
break
|
||||
else:
|
||||
@@ -140,17 +135,17 @@ class Attack:
|
||||
|
||||
if self.use_transformation_cache:
|
||||
cache_key = tuple([current_text] + sorted(kwargs.items()))
|
||||
if hashable(cache_key) and cache_key in self.transformation_cache:
|
||||
if utils.hashable(cache_key) and cache_key in self.transformation_cache:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.transformation_cache[cache_key] = self.transformation_cache[
|
||||
cache_key
|
||||
]
|
||||
transformed_texts = list(self.transformation_cache[current_text])
|
||||
transformed_texts = list(self.transformation_cache[cache_key])
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
)
|
||||
if hashable(cache_key):
|
||||
if utils.hashable(cache_key):
|
||||
self.transformation_cache[cache_key] = tuple(transformed_texts)
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
@@ -337,11 +332,3 @@ class Attack:
|
||||
return main_str
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
def hashable(key):
|
||||
try:
|
||||
hash(key)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user