mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
include kwargs for transformation cache key
This commit is contained in:
@@ -120,8 +120,7 @@ class Attack:
|
||||
pre_transformation_constraints=self.pre_transformation_constraints,
|
||||
**kwargs,
|
||||
)
|
||||
if self.use_transformation_cache:
|
||||
self.transformation_cache[current_text] = tuple(transformed_texts)
|
||||
|
||||
return transformed_texts
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
@@ -140,16 +139,19 @@ class Attack:
|
||||
)
|
||||
|
||||
if self.use_transformation_cache:
|
||||
if current_text in self.transformation_cache:
|
||||
cache_key = tuple([current_text] + sorted(kwargs.items()))
|
||||
if hashable(cache_key) and cache_key in self.transformation_cache:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.transformation_cache[current_text] = self.transformation_cache[
|
||||
current_text
|
||||
self.transformation_cache[cache_key] = self.transformation_cache[
|
||||
cache_key
|
||||
]
|
||||
transformed_texts = list(self.transformation_cache[current_text])
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
)
|
||||
if hashable(cache_key):
|
||||
self.transformation_cache[cache_key] = tuple(transformed_texts)
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
@@ -335,3 +337,11 @@ class Attack:
|
||||
return main_str
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
def hashable(key):
|
||||
try:
|
||||
hash(key)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user