1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix bugs and refactor

This commit is contained in:
Jin Yong Yoo
2020-07-16 09:12:59 -04:00
parent 8f0c443e66
commit d65b5963ed
4 changed files with 73 additions and 31 deletions

View File

@@ -1,7 +1,6 @@
from collections import deque
import lru
import numpy as np
import textattack
from textattack.attack_results import (
@@ -29,6 +28,7 @@ class Attack:
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
transformation: The transformation applied at each step of the attack.
search_method: A strategy for exploring the search space of possible perturbations
transformation_cache_size (int): the number of items to keep in the transformations cache
constraint_cache_size (int): the number of items to keep in the constraints cache
"""
@@ -38,6 +38,7 @@ class Attack:
constraints=[],
transformation=None,
search_method=None,
transformation_cache_size=2 ** 20,
constraint_cache_size=2 ** 20,
):
"""Initialize an attack object.
@@ -74,6 +75,9 @@ class Attack:
else:
self.constraints.append(constraint)
self.transformation_cache_size = transformation_cache_size
self.transformation_cache = lru.LRU(transformation_cache_size)
self.constraint_cache_size = constraint_cache_size
self.constraints_cache = lru.LRU(constraint_cache_size)
@@ -86,6 +90,24 @@ class Attack:
)
self.search_method.filter_transformations = self.filter_transformations
def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
"""Applies ``self.transformation`` to ``text``, then filters the list
of possible transformations through the applicable constraints.
Args:
current_text: The current ``AttackedText`` on which to perform the transformations.
original_text: The original ``AttackedText`` from which the attack started.
Returns:
A filtered list of transformations where each transformation matches the constraints
"""
transformed_texts = self.transformation(
current_text,
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs,
)
self.transformation_cache[current_text] = tuple(transformed_texts)
return transformed_texts
def get_transformations(self, current_text, original_text=None, **kwargs):
"""Applies ``self.transformation`` to ``text``, then filters the list
of possible transformations through the applicable constraints.
@@ -93,8 +115,6 @@ class Attack:
Args:
current_text: The current ``AttackedText`` on which to perform the transformations.
original_text: The original ``AttackedText`` from which the attack started.
apply_constraints: Whether or not to apply post-transformation constraints.
Returns:
A filtered list of transformations where each transformation matches the constraints
"""
@@ -103,13 +123,17 @@ class Attack:
"Cannot call `get_transformations` without a transformation."
)
transformed_texts = np.array(
self.transformation(
current_text,
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs,
if current_text in self.transformation_cache:
# promote transformed_text to the top of the LRU cache
self.transformation_cache[current_text] = self.transformation_cache[
current_text
]
transformed_texts = list(self.transformation_cache[current_text])
else:
transformed_texts = self._get_transformations_uncached(
current_text, original_text, **kwargs
)
)
return self.filter_transformations(
transformed_texts, current_text, original_text
)