1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

rename modification constraint to pre-transformation constraint

This commit is contained in:
uvafan
2020-05-19 15:11:55 -04:00
parent 30cd452ed1
commit a36e4c7bbe
6 changed files with 21 additions and 17 deletions

View File

@@ -1,5 +1,5 @@
from .constraint import Constraint
from .modification_constraint import ModificationConstraint
from .pre_transformation_constraint import PreTransformationConstraint
from . import grammaticality
from . import semantics

View File

@@ -4,10 +4,13 @@
from textattack.shared.utils import default_class_repr
from textattack.constraints import Constraint
class ModificationConstraint(Constraint):
class PreTransformationConstraint(Constraint):
"""
An abstract class that represents constraints which apply only
to which words can be modified.
An abstract class that represents constraints which are applied before
the transformation. These restrict which words are allowed to be modified
during the transformation. For example, we might not allow stopwords to be
modified.
"""
def __call__(self, x, transformation):

View File

@@ -2,9 +2,9 @@
"""
from textattack.shared.utils import default_class_repr
from textattack.constraints import ModificationConstraint
from textattack.constraints import PreTransformationConstraint
class RepeatModification(ModificationConstraint):
class RepeatModification(PreTransformationConstraint):
"""
A constraint disallowing the modification of words which have already been modified.
"""

View File

@@ -2,11 +2,11 @@
"""
from textattack.shared.utils import default_class_repr
from textattack.constraints import ModificationConstraint
from textattack.constraints import PreTransformationConstraint
from textattack.shared.validators import transformation_consists_of_word_swaps
import nltk
class StopwordModification(ModificationConstraint):
class StopwordModification(PreTransformationConstraint):
"""
A constraint disallowing the modification of stopwords
"""

View File

@@ -4,7 +4,7 @@ import os
import random
from textattack.shared import utils
from textattack.constraints import Constraint, ModificationConstraint
from textattack.constraints import Constraint, PreTransformationConstraint
from textattack.shared import TokenizedText
from textattack.attack_results import SkippedAttackResult, SuccessfulAttackResult, FailedAttackResult
@@ -45,10 +45,10 @@ class Attack:
raise ValueError('SearchMethod {self.search_method} incompatible with transformation {self.transformation}')
self.constraints = []
self.modification_constraints = []
self.pre_transformation_constraints = []
for constraint in constraints:
if isinstance(constraint, ModificationConstraint):
self.modification_constraints.append(constraint)
if isinstance(constraint, PreTransformationConstraint):
self.pre_transformation_constraints.append(constraint)
else:
self.constraints.append(constraint)
@@ -67,7 +67,7 @@ class Attack:
transformation:
text:
original text (:obj:`type`, optional): Defaults to None.
apply_constraints: Whether or not to apply non-modification constraints
apply_constraints: Whether or not to apply post-transformation constraints
**kwargs:
Returns:
@@ -78,7 +78,8 @@ class Attack:
raise RuntimeError('Cannot call `get_transformations` without a transformation.')
transformations = np.array(self.transformation(text,
modification_constraints=self.modification_constraints, **kwargs))
pre_transformation_constraints=self.pre_transformation_constraints,
**kwargs))
if apply_constraints:
return self._filter_transformations(transformations, text, original_text)
return transformations
@@ -230,7 +231,7 @@ class Attack:
)
# self.constraints
constraints_lines = []
constraints = self.constraints + self.modification_constraints
constraints = self.constraints + self.pre_transformation_constraints
if len(constraints):
for i, constraint in enumerate(constraints):
constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))

View File

@@ -7,13 +7,13 @@ class Transformation:
"""
def __call__(self, tokenized_text, modification_constraints=[], indices_to_modify=None):
def __call__(self, tokenized_text, pre_transformation_constraints=[], indices_to_modify=None):
""" Returns a list of all possible transformations for `tokenized_text`."""
if indices_to_modify is None:
indices_to_modify = set(range(len(tokenized_text.words)))
else:
indices_to_modify = set(indices_to_modify)
for constraint in modification_constraints:
for constraint in pre_transformation_constraints:
indices_to_modify = indices_to_modify & constraint(tokenized_text, self)
transformed_texts = self._get_transformations(tokenized_text, indices_to_modify)
for text in transformed_texts: