1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/constraints/constraint.py
2020-07-03 16:01:19 -04:00

128 lines
5.1 KiB
Python

from abc import ABC, abstractmethod
import textattack
from textattack.shared.utils import default_class_repr
class Constraint(ABC):
"""
An abstract class that represents constraints on adversial text examples.
Constraints evaluate whether transformations from a ``AttackedText`` to another
``AttackedText`` meet certain conditions.
"""
def call_many(self, transformed_texts, current_text, original_text=None):
"""
Filters ``transformed_texts`` based on which transformations fulfill the constraint.
First checks compatibility with latest ``Transformation``, then calls
``_check_constraint_many``\.
Args:
transformed_texts: The candidate transformed ``AttackedText``\s.
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
incompatible_transformed_texts = []
compatible_transformed_texts = []
for transformed_text in transformed_texts:
try:
if self.check_compatibility(
transformed_text.attack_attrs["last_transformation"]
):
compatible_transformed_texts.append(transformed_text)
else:
incompatible_transformed_texts.append(transformed_text)
except KeyError:
raise KeyError(
"transformed_text must have `last_transformation` attack_attr to apply constraint"
)
filtered_texts = self._check_constraint_many(
compatible_transformed_texts, current_text, original_text=original_text
)
return list(filtered_texts) + incompatible_transformed_texts
def _check_constraint_many(
self, transformed_texts, current_text, original_text=None
):
"""
Filters ``transformed_texts`` based on which transformations fulfill the constraint.
Calls ``check_constraint``\.
Args:
transformed_texts: The candidate transformed ``AttackedText``\s.
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
return [
transformed_text
for transformed_text in transformed_texts
if self._check_constraint(
transformed_text, current_text, original_text=original_text
)
]
def __call__(self, transformed_text, current_text, original_text=None):
"""
Returns True if the constraint is fulfilled, False otherwise. First checks
compatibility with latest ``Transformation``, then calls ``_check_constraint``\.
Args:
transformed_text: The candidate transformed ``AttackedText``.
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
if not isinstance(transformed_text, textattack.shared.AttackedText):
raise TypeError("transformed_text must be of type AttackedText")
if not isinstance(current_text, textattack.shared.AttackedText):
raise TypeError("current_text must be of type AttackedText")
try:
if not self.check_compatibility(
transformed_text.attack_attrs["last_transformation"]
):
return True
except KeyError:
raise KeyError(
"x_adv must have `last_transformation` attack_attr to apply constraint."
)
return self._check_constraint(
transformed_text, current_text, original_text=original_text
)
@abstractmethod
def _check_constraint(self, transformed_text, current_text, original_text=None):
"""
Returns True if the constraint is fulfilled, False otherwise. Must be overridden by
the specific constraint.
Args:
transformed_text: The candidate transformed ``AttackedText``.
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
raise NotImplementedError()
def check_compatibility(self, transformation):
"""
Checks if this constraint is compatible with the given transformation.
For example, the ``WordEmbeddingDistance`` constraint compares the embedding of
the word inserted with that of the word deleted. Therefore it can only be
applied in the case of word swaps, and not for transformations which involve
only one of insertion or deletion.
Args:
transformation: The ``Transformation`` to check compatibility with.
"""
return True
def extra_repr_keys(self):
"""Set the extra representation of the constraint using these keys.
To print customized extra information, you should reimplement
this method in your own constraint. Both single-line and multi-line
strings are acceptable.
"""
return []
__str__ = __repr__ = default_class_repr