mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
138 lines
6.0 KiB
Python
138 lines
6.0 KiB
Python
"""
|
|
Constraints determine whether a given transformation is valid. Since transformations do not perfectly preserve semantics or grammaticality, constraints can increase the likelihood that the resulting transformation preserves these qualities. All constraints are subclasses of the ``Constraint`` abstract class, and must implement at least one of ``__call__`` or ``call_many``.
|
|
|
|
We split constraints into three main categories.
|
|
|
|
:ref:`Semantics`: Based on the meaning of the input and perturbation.
|
|
|
|
:ref:`Grammaticality`: Based on syntactic properties like part-of-speech and grammar.
|
|
|
|
:ref:`Overlap`: Based on character-based properties, like edit distance.
|
|
|
|
A fourth type of constraint restricts the search method from exploring certain parts of the search space:
|
|
|
|
:ref:`pre_transformation`: Based on the input and index of word replacement.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
import textattack
|
|
from textattack.shared.utils import default_class_repr
|
|
|
|
|
|
class Constraint(ABC):
    """An abstract class that represents constraints on adversarial text
    examples. Constraints evaluate whether transformations from a
    ``AttackedText`` to another ``AttackedText`` meet certain conditions.

    Args:
        compare_against_original (bool): If `True`, the reference text should be the original text under attack.
            If `False`, the reference text is the most recent text from which the transformed text was generated.
            All constraints must have this attribute.
    """

    def __init__(self, compare_against_original):
        self.compare_against_original = compare_against_original

    def call_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. First checks compatibility with latest
        ``Transformation``, then calls ``_check_constraint_many``

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list: The compatible texts that passed ``_check_constraint_many``,
            followed by every text whose last transformation is incompatible
            with this constraint (those are passed through unfiltered).

        Raises:
            KeyError: If a transformed text has no ``last_transformation``
                entry in its ``attack_attrs``.
        """
        incompatible_transformed_texts = []
        compatible_transformed_texts = []
        for transformed_text in transformed_texts:
            # Only the lookup is guarded: a KeyError raised from inside a
            # subclass's ``check_compatibility`` must not be misreported as a
            # missing ``last_transformation`` attribute.
            try:
                last_transformation = transformed_text.attack_attrs[
                    "last_transformation"
                ]
            except KeyError as err:
                raise KeyError(
                    "transformed_text must have `last_transformation` attack_attr to apply constraint"
                ) from err
            if self.check_compatibility(last_transformation):
                compatible_transformed_texts.append(transformed_text)
            else:
                incompatible_transformed_texts.append(transformed_text)
        filtered_texts = self._check_constraint_many(
            compatible_transformed_texts, reference_text
        )
        return list(filtered_texts) + incompatible_transformed_texts

    def _check_constraint_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. Calls ``_check_constraint``

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list: The members of ``transformed_texts`` for which
            ``_check_constraint`` returned a truthy value, in input order.
        """
        return [
            transformed_text
            for transformed_text in transformed_texts
            if self._check_constraint(transformed_text, reference_text)
        ]

    def __call__(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. First
        checks compatibility with latest ``Transformation``, then calls
        ``_check_constraint``

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            ``True`` when the last transformation is incompatible with this
            constraint (the constraint is not applied); otherwise the result
            of ``_check_constraint``.

        Raises:
            TypeError: If either argument is not an ``AttackedText``.
            KeyError: If ``transformed_text`` has no ``last_transformation``
                entry in its ``attack_attrs``.
        """
        if not isinstance(transformed_text, textattack.shared.AttackedText):
            raise TypeError("transformed_text must be of type AttackedText")
        if not isinstance(reference_text, textattack.shared.AttackedText):
            raise TypeError("reference_text must be of type AttackedText")

        # Only the lookup is guarded, so a KeyError raised from inside
        # ``check_compatibility`` propagates unchanged instead of being
        # relabelled as a missing attribute.
        try:
            last_transformation = transformed_text.attack_attrs["last_transformation"]
        except KeyError as err:
            raise KeyError(
                "`transformed_text` must have `last_transformation` attack_attr to apply constraint."
            ) from err
        if not self.check_compatibility(last_transformation):
            # An incompatible transformation is not subject to this constraint.
            return True
        return self._check_constraint(transformed_text, reference_text)

    @abstractmethod
    def _check_constraint(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. Must
        be overridden by the specific constraint.

        Args:
            transformed_text: The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.
        """
        raise NotImplementedError()

    def check_compatibility(self, transformation):
        """Checks if this constraint is compatible with the given
        transformation. For example, the ``WordEmbeddingDistance`` constraint
        compares the embedding of the word inserted with that of the word
        deleted. Therefore it can only be applied in the case of word swaps,
        and not for transformations which involve only one of insertion or
        deletion.

        The base implementation accepts every transformation.

        Args:
            transformation: The ``Transformation`` to check compatibility with.
        """
        return True

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.
        """
        return ["compare_against_original"]

    __str__ = __repr__ = default_class_repr
|