1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

include kwargs for transformation cache key

This commit is contained in:
Jin Yong Yoo
2020-07-16 14:18:22 -04:00
parent 995c2d7c25
commit 613bbf0b88

View File

@@ -120,8 +120,7 @@ class Attack:
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs,
)
if self.use_transformation_cache:
self.transformation_cache[current_text] = tuple(transformed_texts)
return transformed_texts
def get_transformations(self, current_text, original_text=None, **kwargs):
@@ -140,16 +139,19 @@ class Attack:
)
if self.use_transformation_cache:
if current_text in self.transformation_cache:
cache_key = tuple([current_text] + sorted(kwargs.items()))
if hashable(cache_key) and cache_key in self.transformation_cache:
# promote transformed_text to the top of the LRU cache
self.transformation_cache[current_text] = self.transformation_cache[
current_text
self.transformation_cache[cache_key] = self.transformation_cache[
cache_key
]
transformed_texts = list(self.transformation_cache[current_text])
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs
)
if hashable(cache_key):
self.transformation_cache[cache_key] = tuple(transformed_texts)
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs
@@ -335,3 +337,11 @@ class Attack:
return main_str
__str__ = __repr__
def hashable(key):
try:
hash(key)
return True
except TypeError:
return False