mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix bugs and refactor
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from collections import deque
|
||||
|
||||
import lru
|
||||
import numpy as np
|
||||
|
||||
import textattack
|
||||
from textattack.attack_results import (
|
||||
@@ -29,6 +28,7 @@ class Attack:
|
||||
constraints: A list of constraints to add to the attack, defining which perturbations are valid.
|
||||
transformation: The transformation applied at each step of the attack.
|
||||
search_method: A strategy for exploring the search space of possible perturbations
|
||||
transformation_cache_size (int): the number of items to keep in the transformations cache
|
||||
constraint_cache_size (int): the number of items to keep in the constraints cache
|
||||
"""
|
||||
|
||||
@@ -38,6 +38,7 @@ class Attack:
|
||||
constraints=[],
|
||||
transformation=None,
|
||||
search_method=None,
|
||||
transformation_cache_size=2 ** 20,
|
||||
constraint_cache_size=2 ** 20,
|
||||
):
|
||||
"""Initialize an attack object.
|
||||
@@ -74,6 +75,9 @@ class Attack:
|
||||
else:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
self.transformation_cache_size = transformation_cache_size
|
||||
self.transformation_cache = lru.LRU(transformation_cache_size)
|
||||
|
||||
self.constraint_cache_size = constraint_cache_size
|
||||
self.constraints_cache = lru.LRU(constraint_cache_size)
|
||||
|
||||
@@ -86,6 +90,24 @@ class Attack:
|
||||
)
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
|
||||
def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
|
||||
"""Applies ``self.transformation`` to ``text``, then filters the list
|
||||
of possible transformations through the applicable constraints.
|
||||
|
||||
Args:
|
||||
current_text: The current ``AttackedText`` on which to perform the transformations.
|
||||
original_text: The original ``AttackedText`` from which the attack started.
|
||||
Returns:
|
||||
A filtered list of transformations where each transformation matches the constraints
|
||||
"""
|
||||
transformed_texts = self.transformation(
|
||||
current_text,
|
||||
pre_transformation_constraints=self.pre_transformation_constraints,
|
||||
**kwargs,
|
||||
)
|
||||
self.transformation_cache[current_text] = tuple(transformed_texts)
|
||||
return transformed_texts
|
||||
|
||||
def get_transformations(self, current_text, original_text=None, **kwargs):
|
||||
"""Applies ``self.transformation`` to ``text``, then filters the list
|
||||
of possible transformations through the applicable constraints.
|
||||
@@ -93,8 +115,6 @@ class Attack:
|
||||
Args:
|
||||
current_text: The current ``AttackedText`` on which to perform the transformations.
|
||||
original_text: The original ``AttackedText`` from which the attack started.
|
||||
apply_constraints: Whether or not to apply post-transformation constraints.
|
||||
|
||||
Returns:
|
||||
A filtered list of transformations where each transformation matches the constraints
|
||||
"""
|
||||
@@ -103,13 +123,17 @@ class Attack:
|
||||
"Cannot call `get_transformations` without a transformation."
|
||||
)
|
||||
|
||||
transformed_texts = np.array(
|
||||
self.transformation(
|
||||
current_text,
|
||||
pre_transformation_constraints=self.pre_transformation_constraints,
|
||||
**kwargs,
|
||||
if current_text in self.transformation_cache:
|
||||
# promote transformed_text to the top of the LRU cache
|
||||
self.transformation_cache[current_text] = self.transformation_cache[
|
||||
current_text
|
||||
]
|
||||
transformed_texts = list(self.transformation_cache[current_text])
|
||||
else:
|
||||
transformed_texts = self._get_transformations_uncached(
|
||||
current_text, original_text, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
return self.filter_transformations(
|
||||
transformed_texts, current_text, original_text
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user