1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add check for random transformations

This commit is contained in:
Jin Yong Yoo
2020-07-16 09:57:51 -04:00
parent d65b5963ed
commit ebe050c864
7 changed files with 46 additions and 16 deletions

View File

@@ -11,8 +11,7 @@ from textattack.attack_results import (
)
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils
# import os
from textattack.transformations import CompositeTransformation, RandomTransformation
class Attack:
@@ -75,6 +74,22 @@ class Attack:
else:
self.constraints.append(constraint)
# Check if we can use transformation cache for our transformation.
if isinstance(self.transformation, RandomTransformation) or (
hasattr(self.transformation, "random_one")
and self.transformation.random_one
):
self.use_transformation_cache = False
elif isinstance(self.transformation, CompositeTransformation):
self.use_transformation_cache = True
for t in self.transformation.transformations:
if isinstance(t, RandomTransformation) or (
hasattr(t, "random_one") and t.random_one
):
self.use_transformation_cache = False
break
else:
self.use_transformation_cache = True
self.transformation_cache_size = transformation_cache_size
self.transformation_cache = lru.LRU(transformation_cache_size)
@@ -105,7 +120,8 @@ class Attack:
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs,
)
self.transformation_cache[current_text] = tuple(transformed_texts)
if self.use_transformation_cache:
self.transformation_cache[current_text] = tuple(transformed_texts)
return transformed_texts
def get_transformations(self, current_text, original_text=None, **kwargs):
@@ -123,12 +139,17 @@ class Attack:
"Cannot call `get_transformations` without a transformation."
)
if current_text in self.transformation_cache:
# promote transformed_text to the top of the LRU cache
self.transformation_cache[current_text] = self.transformation_cache[
current_text
]
transformed_texts = list(self.transformation_cache[current_text])
if self.use_transformation_cache:
if current_text in self.transformation_cache:
# promote transformed_text to the top of the LRU cache
self.transformation_cache[current_text] = self.transformation_cache[
current_text
]
transformed_texts = list(self.transformation_cache[current_text])
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs
)
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs