mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
from textattack.constraints.pre_transformation import PreTransformationConstraint
|
|
|
|
|
|
class InputColumnModification(PreTransformationConstraint):
    """Pre-transformation constraint that protects whole input columns.

    When an input's column labels equal ``matching_column_labels``, every
    word belonging to a column listed in ``columns_to_ignore`` is excluded
    from the set of modifiable indices. Inputs with any other column
    layout are left unconstrained.

    For example, this can keep the 'premise' of an entailment pair fixed
    while the 'hypothesis' remains open to modification.
    """

    def __init__(self, matching_column_labels, columns_to_ignore):
        """Store the column layout to match and the columns to protect.

        Args:
            matching_column_labels: Column labels an input must have for
                this constraint to apply; compared with ``!=`` against
                ``current_text.column_labels``.
            columns_to_ignore: Container of column labels whose words may
                not be modified (only membership tests are performed).
        """
        self.matching_column_labels = matching_column_labels
        self.columns_to_ignore = columns_to_ignore

    def _get_modifiable_indices(self, current_text):
        """Return the set of word indices in ``current_text`` that may be
        modified.

        If ``current_text.column_labels`` differs from
        ``self.matching_column_labels``, the constraint does not apply and
        every index in ``range(len(current_text.words))`` is returned.

        Otherwise, the columns are walked in order together with
        ``current_text.words_per_input``, and only the index ranges of
        columns that are not in ``self.columns_to_ignore`` are returned.
        """
        if current_text.column_labels != self.matching_column_labels:
            # Different column layout: nothing is protected.
            return set(range(len(current_text.words)))

        allowed = set()
        start = 0  # running offset of the current column's first word
        for label, column_words in zip(
            current_text.column_labels, current_text.words_per_input
        ):
            stop = start + len(column_words)
            if label not in self.columns_to_ignore:
                allowed.update(range(start, stop))
            # Advance past this column whether or not it was protected so
            # later columns keep their correct offsets.
            start = stop
        return allowed

    def extra_repr_keys(self):
        """Return the attribute names to include in this object's repr."""
        return ["matching_column_labels", "columns_to_ignore"]
|