# textattack-nlp-transformer/textattack/shared/attacked_text.py
# Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00).
# Original listing: 412 lines, 17 KiB, Python.

from collections import OrderedDict
import math
import numpy as np
import torch
import textattack
from .utils import words_from_text
class AttackedText:
    """A helper class that represents a string that can be attacked.

    Models that take multiple sentences as input separate them by
    ``SPLIT_TOKEN``. Attacks "see" the entire input, joined into one string,
    without the split token.

    ``AttackedText`` instances that were perturbed from other
    ``AttackedText`` objects contain a pointer to the previous text
    (``attack_attrs["previous_attacked_text"]``), so that the full chain of
    perturbations might be reconstructed by using this key to form a linked
    list.

    Args:
        text_input (str or OrderedDict): The string that this AttackedText
            represents, or an ordered mapping from column name to string for
            multi-input texts.
        attack_attrs (dict): Dictionary of various attributes stored during
            the course of an attack.
    """

    SPLIT_TOKEN = ">>>>"

    def __init__(self, text_input, attack_attrs=None):
        # Read in ``text_input`` as a string or OrderedDict.
        if isinstance(text_input, str):
            self._text_input = OrderedDict([("text", text_input)])
        elif isinstance(text_input, OrderedDict):
            self._text_input = text_input
        else:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
        # Find words in input lazily.
        self._words = None
        self._words_per_input = None
        # Format text inputs: take a private copy of the mapping so later
        # mutation of the caller's OrderedDict does not affect this object.
        self._text_input = OrderedDict(self._text_input.items())
        if attack_attrs is None:
            self.attack_attrs = dict()
        elif isinstance(attack_attrs, dict):
            self.attack_attrs = attack_attrs
        else:
            raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
        # Indices of words from the *original* text. Allows us to map
        # indices between original text and this text, and vice-versa.
        self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
        # A set of all indices in *this* text that have been modified.
        self.attack_attrs.setdefault("modified_indices", set())

    def __eq__(self, other):
        """Compares two text instances to make sure they have the same text
        and the same attack attributes.

        Since some elements stored in ``self.attack_attrs`` may be numpy
        arrays, we have to take special care when comparing them.
        """
        if not isinstance(other, AttackedText):
            return NotImplemented
        if self.text != other.text:
            return False
        # Both sides must carry the same number of keys; otherwise a key that
        # exists only on ``other`` would go unnoticed by the loop below.
        if len(self.attack_attrs) != len(other.attack_attrs):
            return False
        for key, value in self.attack_attrs.items():
            if key not in other.attack_attrs:
                return False
            other_value = other.attack_attrs[key]
            if isinstance(value, np.ndarray):
                # ``==`` on arrays is elementwise; require equal shape and
                # all elements equal.
                if not np.array_equal(value, other_value):
                    return False
            elif not value == other_value:
                return False
        return True

    def __hash__(self):
        return hash(self.text)

    def free_memory(self):
        """Delete items that take up memory.

        Can be called once the AttackedText is only needed to display.
        Recursively frees the chain of previous texts, drops
        ``last_transformation``, and removes every ``torch.Tensor`` value
        from ``attack_attrs``.
        """
        if "previous_attacked_text" in self.attack_attrs:
            self.attack_attrs["previous_attacked_text"].free_memory()
        self.attack_attrs.pop("last_transformation", None)
        # Collect the keys first: deleting from a dict while iterating over
        # it raises ``RuntimeError: dictionary changed size during iteration``.
        tensor_keys = [
            key
            for key, value in self.attack_attrs.items()
            if isinstance(value, torch.Tensor)
        ]
        for key in tensor_keys:
            del self.attack_attrs[key]

    def text_window_around_index(self, index, window_size):
        """The text window of ``window_size`` words centered around ``index``."""
        length = self.num_words
        half_size = (window_size - 1) / 2.0
        if index - half_size < 0:
            # Window would run off the start: anchor it at word 0.
            start = 0
            end = min(window_size - 1, length - 1)
        elif index + half_size >= length:
            # Window would run off the end: anchor it at the last word.
            start = max(0, length - window_size)
            end = length - 1
        else:
            start = index - math.ceil(half_size)
            end = index + math.floor(half_size)
        text_idx_start = self._text_index_of_word_index(start)
        text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
        return self.text[text_idx_start:text_idx_end]

    def _text_index_of_word_index(self, i):
        """Returns the index of word ``i`` in self.text."""
        pre_words = self.words[: i + 1]
        lower_text = self.text.lower()
        # Find all words until `i` in string, each search starting where the
        # previous word was found so repeated words resolve in order.
        look_after_index = 0
        for word in pre_words:
            look_after_index = lower_text.find(word.lower(), look_after_index)
        return look_after_index

    def text_until_word_index(self, i):
        """Returns the text before the beginning of word at index ``i``."""
        look_after_index = self._text_index_of_word_index(i)
        return self.text[:look_after_index]

    def text_after_word_index(self, i):
        """Returns the text after the end of word at index ``i``."""
        # Get index of beginning of word then jump to end of word.
        look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
        return self.text[look_after_index:]

    def first_word_diff(self, other_attacked_text):
        """Returns the first word in self.words that differs from
        other_attacked_text, or ``None`` if the overlapping words all match.

        Useful for word swap strategies.
        """
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                return w1[i]
        return None

    def first_word_diff_index(self, other_attacked_text):
        """Returns the index of the first word in self.words that differs
        from other_attacked_text, or ``None`` if the overlapping words match.

        Useful for word swap strategies.
        """
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                return i
        return None

    def all_words_diff(self, other_attacked_text):
        """Returns the set of indices for which this and other_attacked_text
        have different words (only over the overlapping length)."""
        indices = set()
        w1 = self.words
        w2 = other_attacked_text.words
        for i in range(min(len(w1), len(w2))):
            if w1[i] != w2[i]:
                indices.add(i)
        return indices

    def ith_word_diff(self, other_attacked_text, i):
        """Returns whether the word at index i differs from
        other_attacked_text. An index missing from either text counts as a
        difference."""
        w1 = self.words
        w2 = other_attacked_text.words
        if len(w1) - 1 < i or len(w2) - 1 < i:
            return True
        return w1[i] != w2[i]

    def convert_from_original_idxs(self, idxs):
        """Takes indices of words from original string and converts them to
        indices of the same words in the current string.

        Uses information from ``self.attack_attrs['original_index_map']``,
        which maps word indices from the original to perturbed text (deleted
        original words map to ``-1``). If the map is empty, ``idxs`` is
        returned unchanged.

        Raises:
            TypeError: if ``idxs`` is not a set, list, ndarray or tensor.
        """
        if len(self.attack_attrs["original_index_map"]) == 0:
            return idxs
        elif isinstance(idxs, set):
            idxs = list(idxs)
        if isinstance(idxs, (list, np.ndarray)):
            idxs = torch.tensor(idxs)
        elif not isinstance(idxs, torch.Tensor):
            raise TypeError(
                f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
            )
        return [self.attack_attrs["original_index_map"][i] for i in idxs]

    def replace_words_at_indices(self, indices, new_words):
        """Returns a new AttackedText object where the word at each index in
        ``indices`` is replaced with the corresponding entry of ``new_words``.

        Raises:
            ValueError: if ``indices`` and ``new_words`` differ in length, or
                an index is outside ``[0, num_words)``.
            TypeError: if any replacement is not a ``str``.
        """
        if len(indices) != len(new_words):
            raise ValueError(
                f"Cannot replace {len(new_words)} words at {len(indices)} indices."
            )
        words = self.words[:]
        for i, new_word in zip(indices, new_words):
            if not isinstance(new_word, str):
                raise TypeError(
                    f"replace_words_at_indices requires ``str`` words, got {type(new_word)}"
                )
            # ``len(words)`` itself is out of range; reject it here instead
            # of letting ``words[i]`` raise an IndexError.
            if (i < 0) or (i >= len(words)):
                raise ValueError(f"Cannot assign word at index {i}")
            words[i] = new_word
        return self.generate_new_attacked_text(words)

    def replace_word_at_index(self, index, new_word):
        """Returns a new AttackedText object where the word at ``index`` is
        replaced with a new word."""
        if not isinstance(new_word, str):
            raise TypeError(
                f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
            )
        return self.replace_words_at_indices([index], [new_word])

    def delete_word_at_index(self, index):
        """Returns a new AttackedText object where the word at ``index`` is
        removed (replaced by the empty string)."""
        return self.replace_word_at_index(index, "")

    def insert_text_after_word_index(self, index, text):
        """Inserts a string after word at index ``index`` and attempts to add
        appropriate spacing."""
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        new_text = " ".join((word_at_index, text))
        return self.replace_word_at_index(index, new_text)

    def insert_text_before_word_index(self, index, text):
        """Inserts a string before word at index ``index`` and attempts to add
        appropriate spacing."""
        if not isinstance(text, str):
            raise TypeError(f"text must be an str, got type {type(text)}")
        word_at_index = self.words[index]
        # TODO if ``word_at_index`` is at the beginning of a sentence, we should
        # optionally capitalize ``text``.
        new_text = " ".join((text, word_at_index))
        return self.replace_word_at_index(index, new_text)

    def get_deletion_indices(self):
        """Returns the original-text word indices that have been deleted.

        These are the positions in ``attack_attrs["original_index_map"]``
        whose value is ``-1``, returned as a numpy integer array.
        """
        return np.flatnonzero(self.attack_attrs["original_index_map"] == -1)

    def generate_new_attacked_text(self, new_words):
        """Returns a new AttackedText object and replaces old list of words
        with a new list of words, but preserves the punctuation and spacing
        of the original message.

        ``self.words`` is a list of the words in the current text with
        punctuation removed. However, each "word" in ``new_words`` could be
        an empty string, representing a word deletion, or a string with
        multiple space-separated words, representing an insertion of one or
        more words.
        """
        perturbed_text = ""
        original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
        new_attack_attrs = dict()
        if "label_names" in self.attack_attrs:
            new_attack_attrs["label_names"] = self.attack_attrs["label_names"]
        new_attack_attrs["newly_modified_indices"] = set()
        # Point to previously monitored text.
        new_attack_attrs["previous_attacked_text"] = self
        # Use `new_attack_attrs` to track indices with respect to the original
        # text.
        new_attack_attrs["modified_indices"] = self.attack_attrs[
            "modified_indices"
        ].copy()
        new_attack_attrs["original_index_map"] = self.attack_attrs[
            "original_index_map"
        ].copy()
        # Index, in the text being built, of the next word to emit. It differs
        # from ``i`` once an earlier word in this call was deleted or expanded
        # into several words; the tracked index sets are kept in these
        # new-text coordinates, so shifts must be made relative to ``new_i``.
        new_i = 0
        # Create the new attacked text by swapping out words from the original
        # text with a sequence of 0+ words in the new text.
        for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
            word_start = original_text.index(input_word)
            word_end = word_start + len(input_word)
            perturbed_text += original_text[:word_start]
            original_text = original_text[word_end:]
            adv_num_words = len(words_from_text(adv_word_seq))
            num_words_diff = adv_num_words - len(words_from_text(input_word))
            # Track indices on insertions and deletions.
            if num_words_diff != 0:
                # Re-calculate modified indices. If words are inserted or
                # deleted, every index after the current word shifts.
                shifted_modified_indices = set()
                for modified_idx in new_attack_attrs["modified_indices"]:
                    if modified_idx < new_i:
                        shifted_modified_indices.add(modified_idx)
                    elif modified_idx > new_i:
                        shifted_modified_indices.add(modified_idx + num_words_diff)
                    # The replaced word's own index is dropped here; it is
                    # re-added below if the word actually changed.
                new_attack_attrs["modified_indices"] = shifted_modified_indices
                # Track insertions and deletions wrt original text.
                new_idx_map = new_attack_attrs["original_index_map"].copy()
                if num_words_diff == -1:
                    new_idx_map[new_idx_map == new_i] = -1
                new_idx_map[new_idx_map > new_i] += num_words_diff
                new_attack_attrs["original_index_map"] = new_idx_map
            # Save indices of new modified words, then move the pointer past
            # the word(s) that replace ``input_word``.
            if input_word != adv_word_seq:
                replaced_indices = range(new_i, new_i + adv_num_words)
                new_attack_attrs["modified_indices"].update(replaced_indices)
                new_attack_attrs["newly_modified_indices"].update(replaced_indices)
            new_i += adv_num_words
            # Check spaces for deleted text.
            if adv_num_words == 0:
                # Remove extra space (or else there would be two spaces for each
                # deleted word).
                # @TODO What to do with punctuation in this case? This behavior is undefined.
                if i == 0 or not perturbed_text:
                    # Nothing precedes the deleted word (it is the first word,
                    # or every earlier word was deleted as well), so take a
                    # subsequent space.
                    if original_text.startswith(" "):
                        original_text = original_text[1:]
                elif perturbed_text.endswith(" "):
                    # If a word other than the first was deleted, take a
                    # preceding space.
                    perturbed_text = perturbed_text[:-1]
            # Add substitute word(s) to new sentence.
            perturbed_text += adv_word_seq
        perturbed_text += original_text  # Add all of the ending punctuation.
        # Reform perturbed_text into an OrderedDict.
        perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
        perturbed_input = OrderedDict(
            zip(self._text_input.keys(), perturbed_input_texts)
        )
        return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)

    @property
    def tokenizer_input(self):
        """The tuple of inputs to be passed to the tokenizer."""
        return tuple(self._text_input.values())

    @property
    def column_labels(self):
        """Returns the labels for this text's columns.

        For single-sequence inputs, this simply returns ['text'].
        """
        return list(self._text_input.keys())

    @property
    def words_per_input(self):
        """Returns a list of lists of words corresponding to each input.

        Computed on first access and cached.
        """
        if self._words_per_input is None:
            self._words_per_input = [
                words_from_text(_input) for _input in self._text_input.values()
            ]
        return self._words_per_input

    @property
    def words(self):
        """The list of words in the full text, computed on first access and
        cached (an empty list is cached too)."""
        if self._words is None:
            self._words = words_from_text(self.text)
        return self._words

    @property
    def text(self):
        """Represents full text input.

        Multiple inputs are joined with a line break.
        """
        return "\n".join(self._text_input.values())

    @property
    def num_words(self):
        """Returns the number of words in the sequence."""
        return len(self.words)

    def printable_text(self, key_color="bold", key_color_method=None):
        """Represents full text input. Adds field descriptions.

        For example, entailment inputs look like:
        ```
        premise: ...
        hypothesis: ...
        ```
        Keys are capitalized; if ``key_color_method`` is given they are
        colored with ``textattack.shared.utils.color_text``.
        """
        # For single-sequence inputs, don't show a prefix.
        if len(self._text_input) == 1:
            return next(iter(self._text_input.values()))
        # For multiple-sequence inputs, show a prefix and a colon. Optionally,
        # color the key.
        if key_color_method:

            def ck(k):
                return textattack.shared.utils.color_text(
                    k, key_color, key_color_method
                )

        else:

            def ck(k):
                return k

        return "\n".join(
            f"{ck(key.capitalize())}: {value}"
            for key, value in self._text_input.items()
        )

    def __repr__(self):
        return f'<AttackedText "{self.text}">'