mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add check for random transformations
This commit is contained in:
@@ -11,8 +11,7 @@ from textattack.attack_results import (
|
||||
)
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.shared import AttackedText, utils
|
||||
|
||||
# import os
|
||||
from textattack.transformations import CompositeTransformation, RandomTransformation
|
||||
|
||||
|
||||
class Attack:
|
||||
@@ -75,6 +74,22 @@ class Attack:
|
||||
else:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
# Check if we can use transformation cache for our transformation.
|
||||
if isinstance(self.transformation, RandomTransformation) or (
|
||||
hasattr(self.transformation, "random_one")
|
||||
and self.transformation.random_one
|
||||
):
|
||||
self.use_transformation_cache = False
|
||||
elif isinstance(self.transformation, CompositeTransformation):
|
||||
self.use_transformation_cache = True
|
||||
for t in self.transformation.transformations:
|
||||
if isinstance(t, RandomTransformation) or (
|
||||
hasattr(t, "random_one") and t.random_one
|
||||
):
|
||||
self.use_transformation_cache = False
|
||||
break
|
||||
else:
|
||||
self.use_transformation_cache = True
|
||||
self.transformation_cache_size = transformation_cache_size
|
||||
self.transformation_cache = lru.LRU(transformation_cache_size)
|
||||
|
||||
@@ -105,7 +120,8 @@ class Attack:
|
||||
pre_transformation_constraints=self.pre_transformation_constraints,
|
||||
**kwargs,
|
||||
)
|
||||
self.transformation_cache[current_text] = tuple(transformed_texts)
|
||||
if self.use_transformation_cache:
|
||||
self.transformation_cache[current_text] = tuple(transformed_texts)
|
||||
return transformed_texts
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
@@ -123,12 +139,17 @@ class Attack:
|
||||
"Cannot call `get_transformations` without a transformation."
|
||||
)
|
||||
|
||||
if current_text in self.transformation_cache:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.transformation_cache[current_text] = self.transformation_cache[
|
||||
current_text
|
||||
]
|
||||
transformed_texts = list(self.transformation_cache[current_text])
|
||||
if self.use_transformation_cache:
|
||||
if current_text in self.transformation_cache:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.transformation_cache[current_text] = self.transformation_cache[
|
||||
current_text
|
||||
]
|
||||
transformed_texts = list(self.transformation_cache[current_text])
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
)
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
|
||||
Reference in New Issue
Block a user