1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

rename modification constraint to pre-transformation constraint

This commit is contained in:
uvafan
2020-05-19 15:11:55 -04:00
parent 30cd452ed1
commit a36e4c7bbe
6 changed files with 21 additions and 17 deletions

View File

@@ -1,5 +1,5 @@
from .constraint import Constraint from .constraint import Constraint
from .modification_constraint import ModificationConstraint from .pre_transformation_constraint import PreTransformationConstraint
from . import grammaticality from . import grammaticality
from . import semantics from . import semantics

View File

@@ -4,10 +4,13 @@
from textattack.shared.utils import default_class_repr from textattack.shared.utils import default_class_repr
from textattack.constraints import Constraint from textattack.constraints import Constraint
class ModificationConstraint(Constraint):
class PreTransformationConstraint(Constraint):
""" """
An abstract class that represents constraints which apply only An abstract class that represents constraints which are applied before
to which words can be modified. the transformation. These restrict which words are allowed to be modified
during the transformation. For example, we might not allow stopwords to be
modified.
""" """
def __call__(self, x, transformation): def __call__(self, x, transformation):

View File

@@ -2,9 +2,9 @@
""" """
from textattack.shared.utils import default_class_repr from textattack.shared.utils import default_class_repr
from textattack.constraints import ModificationConstraint from textattack.constraints import PreTransformationConstraint
class RepeatModification(ModificationConstraint): class RepeatModification(PreTransformationConstraint):
""" """
A constraint disallowing the modification of words which have already been modified. A constraint disallowing the modification of words which have already been modified.
""" """

View File

@@ -2,11 +2,11 @@
""" """
from textattack.shared.utils import default_class_repr from textattack.shared.utils import default_class_repr
from textattack.constraints import ModificationConstraint from textattack.constraints import PreTransformationConstraint
from textattack.shared.validators import transformation_consists_of_word_swaps from textattack.shared.validators import transformation_consists_of_word_swaps
import nltk import nltk
class StopwordModification(ModificationConstraint): class StopwordModification(PreTransformationConstraint):
""" """
A constraint disallowing the modification of stopwords A constraint disallowing the modification of stopwords
""" """

View File

@@ -4,7 +4,7 @@ import os
import random import random
from textattack.shared import utils from textattack.shared import utils
from textattack.constraints import Constraint, ModificationConstraint from textattack.constraints import Constraint, PreTransformationConstraint
from textattack.shared import TokenizedText from textattack.shared import TokenizedText
from textattack.attack_results import SkippedAttackResult, SuccessfulAttackResult, FailedAttackResult from textattack.attack_results import SkippedAttackResult, SuccessfulAttackResult, FailedAttackResult
@@ -45,10 +45,10 @@ class Attack:
raise ValueError('SearchMethod {self.search_method} incompatible with transformation {self.transformation}') raise ValueError('SearchMethod {self.search_method} incompatible with transformation {self.transformation}')
self.constraints = [] self.constraints = []
self.modification_constraints = [] self.pre_transformation_constraints = []
for constraint in constraints: for constraint in constraints:
if isinstance(constraint, ModificationConstraint): if isinstance(constraint, PreTransformationConstraint):
self.modification_constraints.append(constraint) self.pre_transformation_constraints.append(constraint)
else: else:
self.constraints.append(constraint) self.constraints.append(constraint)
@@ -67,7 +67,7 @@ class Attack:
transformation: transformation:
text: text:
original text (:obj:`type`, optional): Defaults to None. original text (:obj:`type`, optional): Defaults to None.
apply_constraints: Whether or not to apply non-modification constraints apply_constraints: Whether or not to apply post-transformation constraints
**kwargs: **kwargs:
Returns: Returns:
@@ -78,7 +78,8 @@ class Attack:
raise RuntimeError('Cannot call `get_transformations` without a transformation.') raise RuntimeError('Cannot call `get_transformations` without a transformation.')
transformations = np.array(self.transformation(text, transformations = np.array(self.transformation(text,
modification_constraints=self.modification_constraints, **kwargs)) pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs))
if apply_constraints: if apply_constraints:
return self._filter_transformations(transformations, text, original_text) return self._filter_transformations(transformations, text, original_text)
return transformations return transformations
@@ -230,7 +231,7 @@ class Attack:
) )
# self.constraints # self.constraints
constraints_lines = [] constraints_lines = []
constraints = self.constraints + self.modification_constraints constraints = self.constraints + self.pre_transformation_constraints
if len(constraints): if len(constraints):
for i, constraint in enumerate(constraints): for i, constraint in enumerate(constraints):
constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2)) constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))

View File

@@ -7,13 +7,13 @@ class Transformation:
""" """
def __call__(self, tokenized_text, modification_constraints=[], indices_to_modify=None): def __call__(self, tokenized_text, pre_transformation_constraints=[], indices_to_modify=None):
""" Returns a list of all possible transformations for `tokenized_text`.""" """ Returns a list of all possible transformations for `tokenized_text`."""
if indices_to_modify is None: if indices_to_modify is None:
indices_to_modify = set(range(len(tokenized_text.words))) indices_to_modify = set(range(len(tokenized_text.words)))
else: else:
indices_to_modify = set(indices_to_modify) indices_to_modify = set(indices_to_modify)
for constraint in modification_constraints: for constraint in pre_transformation_constraints:
indices_to_modify = indices_to_modify & constraint(tokenized_text, self) indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
transformed_texts = self._get_transformations(tokenized_text, indices_to_modify) transformed_texts = self._get_transformations(tokenized_text, indices_to_modify)
for text in transformed_texts: for text in transformed_texts: