1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

update attackedtext references, need to update tokenization

This commit is contained in:
Jack Morris
2020-06-16 21:56:33 -04:00
parent 9bf213f7fd
commit d25bf44f52
22 changed files with 189 additions and 166 deletions

View File

@@ -8,7 +8,7 @@ lint: FORCE ## Run black (in check mode)
isort --check-only --recursive tests textattack
test: FORCE ## Run tests using pytest
python -m pytest
python -m pytest -x
docs: FORCE ## Build docs using Sphinx.
sphinx-build -b html docs docs/_build/html

View File

@@ -32,18 +32,18 @@ class AttackResult:
# We don't want the AttackedText `ids` sticking around clogging up
# space on our devices. Delete them here, if they're still present,
# because we won't need them anymore anyway.
self.original_result.tokenized_text.free_memory()
self.perturbed_result.tokenized_text.free_memory()
self.original_result.attacked_text.free_memory()
self.perturbed_result.attacked_text.free_memory()
def original_text(self):
""" Returns the text portion of `self.original_result`. Helper method.
"""
return self.original_result.tokenized_text.clean_text()
return self.original_result.attacked_text.printable_text
def perturbed_text(self):
""" Returns the text portion of `self.perturbed_result`. Helper method.
"""
return self.original_result.tokenized_text.clean_text()
return self.original_result.attacked_text.printable_text
def str_lines(self, color_method=None):
""" A list of the lines to be printed for this result's string
@@ -67,11 +67,11 @@ class AttackResult:
def diff_color(self, color_method=None):
""" Highlights the difference between two texts using color.
"""
t1 = self.original_result.tokenized_text
t2 = self.perturbed_result.tokenized_text
t1 = self.original_result.attacked_text
t2 = self.perturbed_result.attacked_text
if color_method is None:
return t1.clean_text(), t2.clean_text()
return t1.printable_text, t2.printable_text
color_1 = self.original_result.get_text_color_input()
color_2 = self.perturbed_result.get_text_color_perturbed()
@@ -107,11 +107,11 @@ class AttackResult:
i1 += 1
i2 += 1
t1 = self.original_result.tokenized_text.replace_words_at_indices(
t1 = self.original_result.attacked_text.replace_words_at_indices(
words_1_idxs, words_1
)
t2 = self.perturbed_result.tokenized_text.replace_words_at_indices(
t2 = self.perturbed_result.attacked_text.replace_words_at_indices(
words_2_idxs, words_2
)
return t1.clean_text(), t2.clean_text()
return t1.printable_text, '\n', t2.printable_text

View File

@@ -60,13 +60,13 @@ class Augmenter:
Returns all possible augmentations of ``text`` according to
``self.transformation``.
"""
tokenized_text = AttackedText(text, DummyTokenizer())
original_text = tokenized_text
attacked_text = AttackedText(text)
original_text = attacked_text
all_transformed_texts = set()
for _ in range(self.transformations_per_example):
index_order = list(range(len(tokenized_text.words)))
index_order = list(range(len(attacked_text.words)))
random.shuffle(index_order)
current_text = tokenized_text
current_text = attacked_text
words_swapped = 0
for i in index_order:
transformed_texts = self.transformation(
@@ -87,7 +87,7 @@ class Augmenter:
if words_swapped == self.num_words_to_swap:
break
all_transformed_texts.add(current_text)
return [t.clean_text() for t in all_transformed_texts]
return [t.printable_text for t in all_transformed_texts]
def augment_many(self, text_list, show_progress=False):
"""
@@ -124,14 +124,3 @@ class Augmenter:
all_text_list.extend([text] + augmented_texts)
all_id_list.extend([_id] * (1 + len(augmented_texts)))
return all_text_list, all_id_list
class DummyTokenizer:
"""
A dummy tokenizer class. Data augmentation applies a transformation
without querying a model, which means that tokenization is unnecessary.
In this case, we pass a dummy tokenizer to `AttackedText`.
"""
def encode(self, _):
return []

View File

@@ -43,8 +43,8 @@ class GPT2(LanguageModelConstraint):
predictions = outputs[0]
probs = []
for tokenized_text in text_list:
nxt_word_ids = self.tokenizer.encode(tokenized_text.words[word_index])
for attacked_text in text_list:
nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
next_word_prob = predictions[0, -1, next_word_ids[0]]
probs.append(next_word_prob)

View File

@@ -19,8 +19,8 @@ class LanguageTool(Constraint):
self.grammar_error_threshold = grammar_error_threshold
self.grammar_error_cache = {}
def get_errors(self, tokenized_text, use_cache=False):
text = tokenized_text.clean_text()
def get_errors(self, attacked_text, use_cache=False):
text = attacked_text.printable_text
if use_cache:
if text not in self.grammar_error_cache:
self.grammar_error_cache[text] = len(self.lang_tool.check(text))

View File

@@ -1,3 +1,5 @@
import collections
from textattack.datasets import TextAttackDataset
from textattack.shared import AttackedText
@@ -30,8 +32,7 @@ class EntailmentDataset(TextAttackDataset):
except ValueError:
# If the label is not an integer, it's a label description.
label = self._label_str_to_int(label)
text_input = collections.OrderedDict([
('premise', premise),
('hypothesis', hypothesis),
])
text_input = collections.OrderedDict(
[("premise", premise), ("hypothesis", hypothesis),]
)
return (text_input, label)

View File

@@ -1,5 +1,6 @@
import collections
import random
import nlp
from textattack.datasets import TextAttackDataset
@@ -86,7 +87,9 @@ class HuggingFaceNLPDataset(TextAttackDataset):
# Convert `raw_example` to an OrderedDict, so that we know which order
# in which to pass examples to the model.
input_dict = collections.OrderedDict([raw_example[c] for c in raw_example])
input_dict = collections.OrderedDict(
[(c, raw_example[c]) for c in self.input_columns]
)
output = raw_example[self.output_column]
if self.label_map:

View File

@@ -6,14 +6,14 @@ class GoalFunctionResult:
Represents the result of a goal function evaluating an AttackedText object.
Args:
tokenized_text: The sequence that was evaluated.
attacked_text: The sequence that was evaluated.
output: The display-friendly output.
succeeded: Whether the goal has been achieved.
score: A score representing how close the model is to achieving its goal.
"""
def __init__(self, tokenized_text, output, succeeded, score):
self.tokenized_text = tokenized_text
def __init__(self, attacked_text, output, succeeded, score):
self.attacked_text = attacked_text
self.output = output
self.score = score
self.succeeded = succeeded

View File

@@ -10,7 +10,7 @@ from textattack.shared.utils import batch_model_predict, default_class_repr
class GoalFunction:
"""
Evaluates how well a perturbed tokenized_text object is achieving a specified goal.
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args:
model: The PyTorch or TensorFlow model used for evaluation.
@@ -40,32 +40,32 @@ class GoalFunction:
else:
self._call_model_cache = None
def should_skip(self, tokenized_text, ground_truth_output):
def should_skip(self, attacked_text, ground_truth_output):
"""
Returns whether or not the goal has already been completed for ``tokenized_text``\,
Returns whether or not the goal has already been completed for ``attacked_text``\,
due to misprediction by the model.
"""
model_outputs = self._call_model([tokenized_text])
model_outputs = self._call_model([attacked_text])
return self._is_goal_complete(model_outputs[0], ground_truth_output)
def get_output(self, tokenized_text):
def get_output(self, attacked_text):
"""
Returns output for display based on the result of calling the model.
"""
return self._get_displayed_output(self._call_model([tokenized_text])[0])
return self._get_displayed_output(self._call_model([attacked_text])[0])
def get_result(self, tokenized_text, ground_truth_output):
def get_result(self, attacked_text, ground_truth_output):
"""
A helper method that queries `self.get_results` with a single
``AttackedText`` object.
"""
results, search_over = self.get_results([tokenized_text], ground_truth_output)
results, search_over = self.get_results([attacked_text], ground_truth_output)
result = results[0] if len(results) else None
return result, search_over
def get_results(self, tokenized_text_list, ground_truth_output):
def get_results(self, attacked_text_list, ground_truth_output):
"""
For each tokenized_text object in tokenized_text_list, returns a result
For each attacked_text object in attacked_text_list, returns a result
consisting of whether or not the goal has been achieved, the output for
display purposes, and a score. Additionally returns whether the search
is over due to the query budget.
@@ -73,16 +73,16 @@ class GoalFunction:
results = []
if self.query_budget < float("inf"):
queries_left = self.query_budget - self.num_queries
tokenized_text_list = tokenized_text_list[:queries_left]
self.num_queries += len(tokenized_text_list)
model_outputs = self._call_model(tokenized_text_list)
for tokenized_text, raw_output in zip(tokenized_text_list, model_outputs):
attacked_text_list = attacked_text_list[:queries_left]
self.num_queries += len(attacked_text_list)
model_outputs = self._call_model(attacked_text_list)
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
displayed_output = self._get_displayed_output(raw_output)
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
goal_function_score = self._get_score(raw_output, ground_truth_output)
results.append(
self._goal_function_result_type()(
tokenized_text, displayed_output, succeeded, goal_function_score
attacked_text, displayed_output, succeeded, goal_function_score
)
)
return results, self.num_queries == self.query_budget
@@ -111,31 +111,31 @@ class GoalFunction:
"""
raise NotImplementedError()
def _call_model_uncached(self, tokenized_text_list):
def _call_model_uncached(self, attacked_text_list):
"""
Queries model and returns outputs for a list of AttackedText
objects.
"""
if not len(tokenized_text_list):
if not len(attacked_text_list):
return []
ids = utils.batch_tokenize(self.tokenizer, tokenized_text_list)
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
with torch.no_grad():
outputs = batch_model_predict(self.model, ids)
return self._process_model_outputs(tokenized_text_list, outputs)
return self._process_model_outputs(attacked_text_list, outputs)
def _call_model(self, tokenized_text_list):
def _call_model(self, attacked_text_list):
""" Gets predictions for a list of `AttackedText` objects.
Gets prediction from cache if possible. If prediction is not in the
cache, queries model and stores prediction in cache.
"""
if not self.use_cache:
return self._call_model_uncached(tokenized_text_list)
return self._call_model_uncached(attacked_text_list)
else:
uncached_list = []
for text in tokenized_text_list:
for text in attacked_text_list:
if text in self._call_model_cache:
# Re-write value in cache. This moves the key to the top of the
# LRU cache and prevents the unlikely event that the text
@@ -145,13 +145,13 @@ class GoalFunction:
uncached_list.append(text)
uncached_list = [
text
for text in tokenized_text_list
for text in attacked_text_list
if text not in self._call_model_cache
]
outputs = self._call_model_uncached(uncached_list)
for text, output in zip(uncached_list, outputs):
self._call_model_cache[text] = output
all_outputs = [self._call_model_cache[text] for text in tokenized_text_list]
all_outputs = [self._call_model_cache[text] for text in attacked_text_list]
return all_outputs
def extra_repr_keys(self):

View File

@@ -77,7 +77,7 @@ class AttackLogManager:
successful_attacks = 0
max_words_changed = None
for i, result in enumerate(self.results):
all_num_words[i] = len(result.original_result.tokenized_text.words)
all_num_words[i] = len(result.original_result.attacked_text.words)
if isinstance(result, FailedAttackResult):
failed_attacks += 1
continue
@@ -87,19 +87,19 @@ class AttackLogManager:
else:
successful_attacks += 1
num_words_changed = len(
result.original_result.tokenized_text.all_words_diff(
result.perturbed_result.tokenized_text
result.original_result.attacked_text.all_words_diff(
result.perturbed_result.attacked_text
)
)
num_words_changed_until_success[num_words_changed - 1] += 1
max_words_changed = max(
max_words_changed or num_words_changed, num_words_changed
)
if len(result.original_result.tokenized_text.words) > 0:
if len(result.original_result.attacked_text.words) > 0:
perturbed_word_percentage = (
num_words_changed
* 100.0
/ len(result.original_result.tokenized_text.words)
/ len(result.original_result.attacked_text.words)
)
else:
perturbed_word_percentage = 0

View File

@@ -1,3 +1,4 @@
import torch
import transformers
from textattack.models.tokenizers import Tokenizer
@@ -41,12 +42,16 @@ class AutoTokenizer(Tokenizer):
add_special_tokens=True,
pad_to_max_length=True,
)
return encoded_text
return dict(encoded_text)
def encode_batch(self, input_text_list):
""" The batch equivalent of ``encode``."""
return self.tokenizer.encode_batch(input_text_list,
if hasattr(self.tokenizer, "encode_batch"):
return self.tokenizer.encode_batch(
input_text_list,
max_length=self.max_length,
add_special_tokens=True,
pad_to_max_length=True,
)
else:
return [dict(self.encode(input_text)) for input_text in input_text_list]

View File

@@ -19,13 +19,13 @@ class BeamSearch(SearchMethod):
self.beam_width = beam_width
def _perform_search(self, initial_result):
beam = [initial_result.tokenized_text]
beam = [initial_result.attacked_text]
best_result = initial_result
while not best_result.succeeded:
potential_next_beam = []
for text in beam:
transformations = self.get_transformations(
text, original_text=initial_result.tokenized_text
text, original_text=initial_result.attacked_text
)
for next_text in transformations:
potential_next_beam.append(next_text)

View File

@@ -45,14 +45,14 @@ class GeneticAlgorithm(SearchMethod):
Whether a replacement which increased the score was found.
"""
transformations = self.get_transformations(
pop_member.tokenized_text,
original_text=self.original_tokenized_text,
pop_member.attacked_text,
original_text=self.original_attacked_text,
indices_to_modify=[idx],
)
if not len(transformations):
return False
orig_result, self.search_over = self.get_goal_results(
[pop_member.tokenized_text], self.correct_output
[pop_member.attacked_text], self.correct_output
)
if self.search_over:
return False
@@ -62,7 +62,7 @@ class GeneticAlgorithm(SearchMethod):
new_x_scores = torch.Tensor([r.score for r in new_x_results])
new_x_scores = new_x_scores - orig_result[0].score
if len(new_x_scores) and new_x_scores.max() > 0:
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
pop_member.attacked_text = transformations[new_x_scores.argmax()]
return True
return False
@@ -99,7 +99,7 @@ class GeneticAlgorithm(SearchMethod):
pop = []
for _ in range(self.pop_size):
pop_member = PopulationMember(
self.original_tokenized_text, deepcopy(neighbors_len), initial_result
self.original_attacked_text, deepcopy(neighbors_len), initial_result
)
self._perturb(pop_member)
pop.append(pop_member)
@@ -116,8 +116,8 @@ class GeneticAlgorithm(SearchMethod):
"""
indices_to_replace = []
words_to_replace = []
x1_text = pop_member1.tokenized_text
x2_words = pop_member2.tokenized_text.words
x1_text = pop_member1.attacked_text
x2_words = pop_member2.attacked_text.words
new_neighbors_len = deepcopy(pop_member1.neighbors_len)
for i in range(len(x1_text.words)):
if np.random.uniform() < 0.5:
@@ -129,35 +129,35 @@ class GeneticAlgorithm(SearchMethod):
)
return PopulationMember(new_text, deepcopy(new_neighbors_len))
def _get_neighbors_len(self, tokenized_text):
def _get_neighbors_len(self, attacked_text):
"""
Generates this neighbors_len list
Args:
tokenized_text: The original text
attacked_text: The original text
Returns:
A list of number of candidate neighbors for each word
"""
words = tokenized_text.words
words = attacked_text.words
neighbors_list = [[] for _ in range(len(words))]
transformations = self.get_transformations(
tokenized_text, original_text=self.original_tokenized_text
attacked_text, original_text=self.original_attacked_text
)
for transformed_text in transformations:
diff_idx = tokenized_text.first_word_diff_index(transformed_text)
diff_idx = attacked_text.first_word_diff_index(transformed_text)
neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
neighbors_list = [np.array(x) for x in neighbors_list]
neighbors_len = np.array([len(x) for x in neighbors_list])
return neighbors_len
def _perform_search(self, initial_result):
self.original_tokenized_text = initial_result.tokenized_text
self.original_attacked_text = initial_result.attacked_text
self.correct_output = initial_result.output
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
neighbors_len = self._get_neighbors_len(self.original_attacked_text)
pop = self._generate_population(neighbors_len, initial_result)
cur_score = initial_result.score
for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results(
[pm.tokenized_text for pm in pop], self.correct_output
[pm.attacked_text for pm in pop], self.correct_output
)
if self.search_over:
if not len(pop_results):
@@ -213,11 +213,11 @@ class PopulationMember:
A member of the population during the course of the genetic algorithm.
Args:
tokenized_text: The ``AttackedText`` of the population member.
attacked_text: The ``AttackedText`` of the population member.
neighbors_len: A list of the number of candidate neighbors list for each word.
"""
def __init__(self, tokenized_text, neighbors_len, result=None):
self.tokenized_text = tokenized_text
def __init__(self, attacked_text, neighbors_len, result=None):
self.attacked_text = attacked_text
self.neighbors_len = neighbors_len
self.result = result

View File

@@ -39,23 +39,22 @@ class GreedyWordSwapWIR(SearchMethod):
return leave_one_scores, search_over
def _perform_search(self, initial_result):
tokenized_text = initial_result.tokenized_text
attacked_text = initial_result.attacked_text
cur_result = initial_result
# Sort words by order of importance
len_text = len(tokenized_text.words)
len_text = len(attacked_text.words)
if self.wir_method == "unk":
leave_one_texts = [
tokenized_text.replace_word_at_index(i, "[UNK]")
for i in range(len_text)
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
]
leave_one_scores, search_over = self._get_index_order(
initial_result, leave_one_texts
)
elif self.wir_method == "delete":
leave_one_texts = [
tokenized_text.delete_word_at_index(i) for i in range(len_text)
attacked_text.delete_word_at_index(i) for i in range(len_text)
]
leave_one_scores = self._get_index_order(initial_result, leave_one_texts)
elif self.wir_method == "random":
@@ -71,8 +70,8 @@ class GreedyWordSwapWIR(SearchMethod):
results = None
while i < len(index_order) and not search_over:
transformed_text_candidates = self.get_transformations(
cur_result.tokenized_text,
original_text=initial_result.tokenized_text,
cur_result.attacked_text,
original_text=initial_result.attacked_text,
indices_to_modify=[index_order[i]],
)
i += 1
@@ -95,7 +94,7 @@ class GreedyWordSwapWIR(SearchMethod):
for result in results:
if not result.succeeded:
break
candidate = result.tokenized_text
candidate = result.attacked_text
try:
similarity_score = candidate.attack_attrs["similarity_score"]
except KeyError:

View File

@@ -24,7 +24,7 @@ class SearchMethod:
def _perform_search(self, initial_result):
"""
Perturbs `tokenized_text` from ``initial_result`` until goal is reached or search is
Perturbs `attacked_text` from ``initial_result`` until goal is reached or search is
exhausted. Must be overridden by specific search methods.
"""
raise NotImplementedError()

View File

@@ -217,10 +217,10 @@ class Attack:
yield
for text, ground_truth_output in dataset:
tokenized_text = AttackedText(text, self.goal_function.tokenizer)
attacked_text = AttackedText(text)
self.goal_function.num_queries = 0
goal_function_result, _ = self.goal_function.get_result(
tokenized_text, ground_truth_output
attacked_text, ground_truth_output
)
# We can skip examples for which the goal is already succeeded,
# unless `attack_skippable_examples` is True.

View File

@@ -1,4 +1,5 @@
import collections
from collections import OrderedDict
import numpy as np
import torch
@@ -29,16 +30,24 @@ class AttackedText:
SPLIT_TOKEN = ">>>>"
def __init__(self, text_input, attack_attrs=None):
# Read in ``text_input`` as a string or OrderedDict.
if isinstance(text_input, str):
self._text_input = collections.OrderedDict(('text', text_input))
elif isinstance(text_input, collections.OrderedDict):
self._text_input = OrderedDict([("text", text_input)])
elif isinstance(text_input, OrderedDict):
self._text_input = text_input
else:
raise TypeError(f'Invalid text_input type {type(text_input)} (required str or OrderedDict)')
text = text.strip()
self.words = words_from_text(text, words_to_ignore=[AttackedText.SPLIT_TOKEN])
self.text = text
self.attack_attrs = attack_attrs or dict()
raise TypeError(
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
)
# Format text inputs.
self._text_input = {k: v.strip() for k, v in self._text_input.items()}
self.words = words_from_text(self.text)
if attack_attrs is None:
self.attack_attrs = dict()
elif isinstance(attack_attrs, dict):
self.attack_attrs = attack_attrs
else:
raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
# Indices of words from the *original* text. Allows us to map
# indices between original text and this text, and vice-versa.
self.attack_attrs.setdefault("original_index_map", np.arange(len(self.words)))
@@ -75,7 +84,6 @@ class AttackedText:
Can be called once the AttackedText is only needed to display.
"""
self.tokenizer = None
if "previous_attacked_text" in self.attack_attrs:
self.attack_attrs["previous_attacked_text"].free_memory()
if "last_transformation" in self.attack_attrs:
@@ -122,42 +130,42 @@ class AttackedText:
look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
return self.text[look_after_index:]
def first_word_diff(self, other_tokenized_text):
def first_word_diff(self, other_attacked_text):
""" Returns the first word in self.words that differs from
other_tokenized_text. Useful for word swap strategies. """
other_attacked_text. Useful for word swap strategies. """
w1 = self.words
w2 = other_tokenized_text.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
return w1
return None
def first_word_diff_index(self, other_tokenized_text):
def first_word_diff_index(self, other_attacked_text):
""" Returns the index of the first word in self.words that differs
from other_tokenized_text. Useful for word swap strategies. """
from other_attacked_text. Useful for word swap strategies. """
w1 = self.words
w2 = other_tokenized_text.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
return i
return None
def all_words_diff(self, other_tokenized_text):
""" Returns the set of indices for which this and other_tokenized_text
def all_words_diff(self, other_attacked_text):
""" Returns the set of indices for which this and other_attacked_text
have different words. """
indices = set()
w1 = self.words
w2 = other_tokenized_text.words
w2 = other_attacked_text.words
for i in range(min(len(w1), len(w2))):
if w1[i] != w2[i]:
indices.add(i)
return indices
def ith_word_diff(self, other_tokenized_text, i):
""" Returns whether the word at index i differs from other_tokenized_text
def ith_word_diff(self, other_attacked_text, i):
""" Returns whether the word at index i differs from other_attacked_text
"""
w1 = self.words
w2 = other_tokenized_text.words
w2 = other_attacked_text.words
if len(w1) - 1 < i or len(w2) - 1 < i:
return True
return w1[i] != w2[i]
@@ -244,7 +252,7 @@ class AttackedText:
of one or more words.
"""
perturbed_text = ""
original_text = self._input_text.join(AttackedText.SPLIT_TOKEN)
original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
new_attack_attrs = dict()
new_attack_attrs["newly_modified_indices"] = set()
# Point to previously monitored text.
@@ -261,9 +269,9 @@ class AttackedText:
# Create the new attacked text by swapping out words from the original
# text with a sequence of 0+ words in the new text.
for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
word_start = text.index(input_word)
word_start = original_text.index(input_word)
word_end = word_start + len(input_word)
perturbed_text += text[:word_start]
perturbed_text += original_text[:word_start]
original_text = original_text[word_end:]
adv_num_words = len(words_from_text(adv_word_seq))
num_words_diff = adv_num_words - len(words_from_text(input_word))
@@ -300,32 +308,47 @@ class AttackedText:
# @TODO What to do with punctuation in this case? This behavior is undefined.
if i == 0:
# If the first word was deleted, take a subsequent space.
if text[0] == " ":
text = text[1:]
if original_text[0] == " ":
original_text = original_text[1:]
else:
# If a word other than the first was deleted, take a preceding space.
if perturbed_text[-1] == " ":
perturbed_text = perturbed_text[:-1]
# Add substitute word(s) to new sentence.
perturbed_text += adv_word_seq
perturbed_text += text # Add all of the ending punctuation.
perturbed_text += original_text # Add all of the ending punctuation.
# Reform perturbed_text into an OrderedDict.
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
perturbed_input = OrderedDict(zip(self._text_input.keys(), perturbed_input_texts))
return AttackedText(
perturbed_text, self.tokenizer, attack_attrs=new_attack_attrs
perturbed_input = OrderedDict(
zip(self._text_input.keys(), perturbed_input_texts)
)
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
@property
def text(self):
""" Represents full text input. Multiple inputs are joined with a line
break.
"""
return '\n\n'.join(self._text_input.value())
return "\n".join(self._text_input.values())
@property
def printable_text(self):
return '\n\n'.join(f'{key}: {value}' for key, value in enumerate(self._text_input))
""" Represents full text input. Adds field descriptions.
For example, entailment inputs look like:
```
premise: ...
hypothesis: ...
```
"""
# For single-sequence inputs, don't show a prefix.
if len(self._text_input) == 1:
return next(iter(self._text_input.values()))
# For multiple-sequence inputs, show a prefix and a colon.
else:
return "\n\n".join(
f"{key}: {value}" for key, value in self._text_input.items()
)
def __repr__(self):
return f'<AttackedText "{self.text}">'

View File

@@ -1,4 +1,5 @@
import argparse
import collections
import sys
import torch
@@ -13,7 +14,7 @@ def _cb(s):
def get_num_successes(args, model, ids, true_labels):
with torch.no_grad():
preds = textattack.shared.utils.model_predict(model, ids)
preds = textattack.shared.utils.model_predict(model, **ids)
true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
guess_labels = preds.argmax(dim=1)
successes = (guess_labels == true_labels).sum().item()
@@ -28,9 +29,8 @@ def test_model_on_dataset(args, model, dataset, batch_size=128):
batch_labels = []
all_true_labels = []
all_guess_labels = []
for i, (text, label) in enumerate(dataset):
if i >= num_examples:
break
for i, (text_input, label) in enumerate(dataset):
text = textattack.shared.AttackedText(text_input).text
ids = model.tokenizer.encode(text)
batch_ids.append(ids)
batch_labels.append(label)

View File

@@ -73,10 +73,8 @@ def run(args):
print("Attacking...")
tokenized_text = textattack.shared.tokenized_text.AttackedText(
text, attack.goal_function.model.tokenizer
)
initial_result = attack.goal_function.get_output(tokenized_text)
attacked_text = textattack.shared.attacked_text.AttackedText(text)
initial_result = attack.goal_function.get_output(attacked_text)
result = next(attack.attack_dataset([(text, initial_result)]))
print(result.__str__(color_method="ansi"))

View File

@@ -3,13 +3,15 @@ import torch
import textattack
from textattack.shared import utils
def batch_tokenize(tokenizer, inputs):
def batch_tokenize(tokenizer, attacked_text_list):
""" Tokenizes a list of inputs and returns their tokenized forms in a list. """
if hasattr(tokenizer, 'encode_batch'):
encoded_inputs = tokenizer.encode_batch(inputs)
inputs = [at.text for at in attacked_text_list]
if hasattr(tokenizer, "encode_batch"):
return tokenizer.encode_batch(inputs)
else:
encoded_inputs = [tokenizer.encode(x) for x in inputs]
return [x.ids for x in encoded_inputs]
return [tokenizer.encode(x) for x in inputs]
def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
outputs = []

View File

@@ -22,7 +22,7 @@ class CompositeTransformation(Transformation):
self.transformations = transformations
def __call__(self, *args, **kwargs):
new_tokenized_texts = set()
new_attacked_texts = set()
for transformation in self.transformations:
new_tokenized_texts.update(transformation(*args, **kwargs))
return list(new_tokenized_texts)
new_attacked_texts.update(transformation(*args, **kwargs))
return list(new_attacked_texts)

View File

@@ -43,12 +43,12 @@ class WordSwapGradientBased(Transformation):
self.top_n = top_n
self.is_black_box = False
def _get_replacement_words_by_grad(self, text, indices_to_replace):
def _get_replacement_words_by_grad(self, attacked_text, indices_to_replace):
""" Returns a list containing all possible words to replace
`word` with, based off of the model's gradient.
Arguments:
text (AttackedText): The full text input to perturb
attacked_text (AttackedText): The full text input to perturb
indices_to_replace: indices of the words to replace
"""
self.model.train()
@@ -56,11 +56,14 @@ class WordSwapGradientBased(Transformation):
lookup_table = self.model.lookup_table.to(utils.device)
lookup_table_transpose = lookup_table.transpose(0, 1)
# get word IDs
text_ids = self.model.tokenizer.encode(attacked_text.text)
# set backward hook on the word embeddings for input x
emb_hook = Hook(self.model.word_embeddings, backward=True)
self.model.zero_grad()
predictions = self._call_model(text)
predictions = self._call_model(text_ids)
original_label = predictions.argmax()
y_true = torch.Tensor([original_label]).long().to(utils.device)
loss = self.loss(predictions, y_true)
@@ -78,7 +81,7 @@ class WordSwapGradientBased(Transformation):
b_grads = (
emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
)
a_grad = b_grads[text.ids[word_idx]]
a_grad = b_grads[text_ids[word_idx]]
diffs[j] = b_grads - a_grad
# Don't change to the pad token.
@@ -104,12 +107,12 @@ class WordSwapGradientBased(Transformation):
self.model.eval()
return candidates
def _call_model(self, text):
def _call_model(self, text_ids):
""" A helper function to query `self.model` with AttackedText `text`.
"""
return utils.model_predict(self.model, [text.ids])
return utils.model_predict(self.model, [text_ids])
def _get_transformations(self, tokenized_text, indices_to_replace):
def _get_transformations(self, attacked_text, indices_to_replace):
"""
Returns a list of all possible transformations for `text`.
@@ -118,9 +121,9 @@ class WordSwapGradientBased(Transformation):
"""
transformations = []
for word, idx in self._get_replacement_words_by_grad(
tokenized_text, indices_to_replace
attacked_text, indices_to_replace
):
transformations.append(tokenized_text.replace_word_at_index(idx, word))
transformations.append(attacked_text.replace_word_at_index(idx, word))
return transformations
def extra_repr_keys(self):