mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update attackedtext references, need to update tokenization
This commit is contained in:
2
Makefile
2
Makefile
@@ -8,7 +8,7 @@ lint: FORCE ## Run black (in check mode)
|
||||
isort --check-only --recursive tests textattack
|
||||
|
||||
test: FORCE ## Run tests using pytest
|
||||
python -m pytest
|
||||
python -m pytest -x
|
||||
|
||||
docs: FORCE ## Build docs using Sphinx.
|
||||
sphinx-build -b html docs docs/_build/html
|
||||
|
||||
@@ -32,18 +32,18 @@ class AttackResult:
|
||||
# We don't want the AttackedText `ids` sticking around clogging up
|
||||
# space on our devices. Delete them here, if they're still present,
|
||||
# because we won't need them anymore anyway.
|
||||
self.original_result.tokenized_text.free_memory()
|
||||
self.perturbed_result.tokenized_text.free_memory()
|
||||
self.original_result.attacked_text.free_memory()
|
||||
self.perturbed_result.attacked_text.free_memory()
|
||||
|
||||
def original_text(self):
|
||||
""" Returns the text portion of `self.original_result`. Helper method.
|
||||
"""
|
||||
return self.original_result.tokenized_text.clean_text()
|
||||
return self.original_result.attacked_text.printable_text
|
||||
|
||||
def perturbed_text(self):
|
||||
""" Returns the text portion of `self.perturbed_result`. Helper method.
|
||||
"""
|
||||
return self.original_result.tokenized_text.clean_text()
|
||||
return self.original_result.attacked_text.printable_text
|
||||
|
||||
def str_lines(self, color_method=None):
|
||||
""" A list of the lines to be printed for this result's string
|
||||
@@ -67,11 +67,11 @@ class AttackResult:
|
||||
def diff_color(self, color_method=None):
|
||||
""" Highlights the difference between two texts using color.
|
||||
"""
|
||||
t1 = self.original_result.tokenized_text
|
||||
t2 = self.perturbed_result.tokenized_text
|
||||
t1 = self.original_result.attacked_text
|
||||
t2 = self.perturbed_result.attacked_text
|
||||
|
||||
if color_method is None:
|
||||
return t1.clean_text(), t2.clean_text()
|
||||
return t1.printable_text, t2.printable_text
|
||||
|
||||
color_1 = self.original_result.get_text_color_input()
|
||||
color_2 = self.perturbed_result.get_text_color_perturbed()
|
||||
@@ -107,11 +107,11 @@ class AttackResult:
|
||||
i1 += 1
|
||||
i2 += 1
|
||||
|
||||
t1 = self.original_result.tokenized_text.replace_words_at_indices(
|
||||
t1 = self.original_result.attacked_text.replace_words_at_indices(
|
||||
words_1_idxs, words_1
|
||||
)
|
||||
t2 = self.perturbed_result.tokenized_text.replace_words_at_indices(
|
||||
t2 = self.perturbed_result.attacked_text.replace_words_at_indices(
|
||||
words_2_idxs, words_2
|
||||
)
|
||||
|
||||
return t1.clean_text(), t2.clean_text()
|
||||
return t1.printable_text, '\n', t2.printable_text
|
||||
|
||||
@@ -60,13 +60,13 @@ class Augmenter:
|
||||
Returns all possible augmentations of ``text`` according to
|
||||
``self.transformation``.
|
||||
"""
|
||||
tokenized_text = AttackedText(text, DummyTokenizer())
|
||||
original_text = tokenized_text
|
||||
attacked_text = AttackedText(text)
|
||||
original_text = attacked_text
|
||||
all_transformed_texts = set()
|
||||
for _ in range(self.transformations_per_example):
|
||||
index_order = list(range(len(tokenized_text.words)))
|
||||
index_order = list(range(len(attacked_text.words)))
|
||||
random.shuffle(index_order)
|
||||
current_text = tokenized_text
|
||||
current_text = attacked_text
|
||||
words_swapped = 0
|
||||
for i in index_order:
|
||||
transformed_texts = self.transformation(
|
||||
@@ -87,7 +87,7 @@ class Augmenter:
|
||||
if words_swapped == self.num_words_to_swap:
|
||||
break
|
||||
all_transformed_texts.add(current_text)
|
||||
return [t.clean_text() for t in all_transformed_texts]
|
||||
return [t.printable_text for t in all_transformed_texts]
|
||||
|
||||
def augment_many(self, text_list, show_progress=False):
|
||||
"""
|
||||
@@ -124,14 +124,3 @@ class Augmenter:
|
||||
all_text_list.extend([text] + augmented_texts)
|
||||
all_id_list.extend([_id] * (1 + len(augmented_texts)))
|
||||
return all_text_list, all_id_list
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
"""
|
||||
A dummy tokenizer class. Data augmentation applies a transformation
|
||||
without querying a model, which means that tokenization is unnecessary.
|
||||
In this case, we pass a dummy tokenizer to `AttackedText`.
|
||||
"""
|
||||
|
||||
def encode(self, _):
|
||||
return []
|
||||
|
||||
@@ -43,8 +43,8 @@ class GPT2(LanguageModelConstraint):
|
||||
predictions = outputs[0]
|
||||
|
||||
probs = []
|
||||
for tokenized_text in text_list:
|
||||
nxt_word_ids = self.tokenizer.encode(tokenized_text.words[word_index])
|
||||
for attacked_text in text_list:
|
||||
nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
|
||||
next_word_prob = predictions[0, -1, next_word_ids[0]]
|
||||
probs.append(next_word_prob)
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ class LanguageTool(Constraint):
|
||||
self.grammar_error_threshold = grammar_error_threshold
|
||||
self.grammar_error_cache = {}
|
||||
|
||||
def get_errors(self, tokenized_text, use_cache=False):
|
||||
text = tokenized_text.clean_text()
|
||||
def get_errors(self, attacked_text, use_cache=False):
|
||||
text = attacked_text.printable_text
|
||||
if use_cache:
|
||||
if text not in self.grammar_error_cache:
|
||||
self.grammar_error_cache[text] = len(self.lang_tool.check(text))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import collections
|
||||
|
||||
from textattack.datasets import TextAttackDataset
|
||||
from textattack.shared import AttackedText
|
||||
|
||||
@@ -30,8 +32,7 @@ class EntailmentDataset(TextAttackDataset):
|
||||
except ValueError:
|
||||
# If the label is not an integer, it's a label description.
|
||||
label = self._label_str_to_int(label)
|
||||
text_input = collections.OrderedDict([
|
||||
('premise', premise),
|
||||
('hypothesis', hypothesis),
|
||||
])
|
||||
text_input = collections.OrderedDict(
|
||||
[("premise", premise), ("hypothesis", hypothesis),]
|
||||
)
|
||||
return (text_input, label)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import collections
|
||||
import random
|
||||
|
||||
import nlp
|
||||
|
||||
from textattack.datasets import TextAttackDataset
|
||||
@@ -86,7 +87,9 @@ class HuggingFaceNLPDataset(TextAttackDataset):
|
||||
|
||||
# Convert `raw_example` to an OrderedDict, so that we know which order
|
||||
# in which to pass examples to the model.
|
||||
input_dict = collections.OrderedDict([raw_example[c] for c in raw_example])
|
||||
input_dict = collections.OrderedDict(
|
||||
[(c, raw_example[c]) for c in self.input_columns]
|
||||
)
|
||||
|
||||
output = raw_example[self.output_column]
|
||||
if self.label_map:
|
||||
|
||||
@@ -6,14 +6,14 @@ class GoalFunctionResult:
|
||||
Represents the result of a goal function evaluating a AttackedText object.
|
||||
|
||||
Args:
|
||||
tokenized_text: The sequence that was evaluated.
|
||||
attacked_text: The sequence that was evaluated.
|
||||
output: The display-friendly output.
|
||||
succeeded: Whether the goal has been achieved.
|
||||
score: A score representing how close the model is to achieving its goal.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenized_text, output, succeeded, score):
|
||||
self.tokenized_text = tokenized_text
|
||||
def __init__(self, attacked_text, output, succeeded, score):
|
||||
self.attacked_text = attacked_text
|
||||
self.output = output
|
||||
self.score = score
|
||||
self.succeeded = succeeded
|
||||
|
||||
@@ -10,7 +10,7 @@ from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||
|
||||
class GoalFunction:
|
||||
"""
|
||||
Evaluates how well a perturbed tokenized_text object is achieving a specified goal.
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
@@ -40,32 +40,32 @@ class GoalFunction:
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, tokenized_text, ground_truth_output):
|
||||
def should_skip(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
Returns whether or not the goal has already been completed for ``tokenized_text``\,
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``\,
|
||||
due to misprediction by the model.
|
||||
"""
|
||||
model_outputs = self._call_model([tokenized_text])
|
||||
model_outputs = self._call_model([attacked_text])
|
||||
return self._is_goal_complete(model_outputs[0], ground_truth_output)
|
||||
|
||||
def get_output(self, tokenized_text):
|
||||
def get_output(self, attacked_text):
|
||||
"""
|
||||
Returns output for display based on the result of calling the model.
|
||||
"""
|
||||
return self._get_displayed_output(self._call_model([tokenized_text])[0])
|
||||
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
||||
|
||||
def get_result(self, tokenized_text, ground_truth_output):
|
||||
def get_result(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
A helper method that queries `self.get_results` with a single
|
||||
``AttackedText`` object.
|
||||
"""
|
||||
results, search_over = self.get_results([tokenized_text], ground_truth_output)
|
||||
results, search_over = self.get_results([attacked_text], ground_truth_output)
|
||||
result = results[0] if len(results) else None
|
||||
return result, search_over
|
||||
|
||||
def get_results(self, tokenized_text_list, ground_truth_output):
|
||||
def get_results(self, attacked_text_list, ground_truth_output):
|
||||
"""
|
||||
For each tokenized_text object in tokenized_text_list, returns a result
|
||||
For each attacked_text object in attacked_text_list, returns a result
|
||||
consisting of whether or not the goal has been achieved, the output for
|
||||
display purposes, and a score. Additionally returns whether the search
|
||||
is over due to the query budget.
|
||||
@@ -73,16 +73,16 @@ class GoalFunction:
|
||||
results = []
|
||||
if self.query_budget < float("inf"):
|
||||
queries_left = self.query_budget - self.num_queries
|
||||
tokenized_text_list = tokenized_text_list[:queries_left]
|
||||
self.num_queries += len(tokenized_text_list)
|
||||
model_outputs = self._call_model(tokenized_text_list)
|
||||
for tokenized_text, raw_output in zip(tokenized_text_list, model_outputs):
|
||||
attacked_text_list = attacked_text_list[:queries_left]
|
||||
self.num_queries += len(attacked_text_list)
|
||||
model_outputs = self._call_model(attacked_text_list)
|
||||
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
||||
displayed_output = self._get_displayed_output(raw_output)
|
||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
||||
goal_function_score = self._get_score(raw_output, ground_truth_output)
|
||||
results.append(
|
||||
self._goal_function_result_type()(
|
||||
tokenized_text, displayed_output, succeeded, goal_function_score
|
||||
attacked_text, displayed_output, succeeded, goal_function_score
|
||||
)
|
||||
)
|
||||
return results, self.num_queries == self.query_budget
|
||||
@@ -111,31 +111,31 @@ class GoalFunction:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _call_model_uncached(self, tokenized_text_list):
|
||||
def _call_model_uncached(self, attacked_text_list):
|
||||
"""
|
||||
Queries model and returns outputs for a list of AttackedText
|
||||
objects.
|
||||
"""
|
||||
if not len(tokenized_text_list):
|
||||
if not len(attacked_text_list):
|
||||
return []
|
||||
ids = utils.batch_tokenize(self.tokenizer, tokenized_text_list)
|
||||
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = batch_model_predict(self.model, ids)
|
||||
|
||||
return self._process_model_outputs(tokenized_text_list, outputs)
|
||||
return self._process_model_outputs(attacked_text_list, outputs)
|
||||
|
||||
def _call_model(self, tokenized_text_list):
|
||||
def _call_model(self, attacked_text_list):
|
||||
""" Gets predictions for a list of `AttackedText` objects.
|
||||
|
||||
Gets prediction from cache if possible. If prediction is not in the
|
||||
cache, queries model and stores prediction in cache.
|
||||
"""
|
||||
if not self.use_cache:
|
||||
return self._call_model_uncached(tokenized_text_list)
|
||||
return self._call_model_uncached(attacked_text_list)
|
||||
else:
|
||||
uncached_list = []
|
||||
for text in tokenized_text_list:
|
||||
for text in attacked_text_list:
|
||||
if text in self._call_model_cache:
|
||||
# Re-write value in cache. This moves the key to the top of the
|
||||
# LRU cache and prevents the unlikely event that the text
|
||||
@@ -145,13 +145,13 @@ class GoalFunction:
|
||||
uncached_list.append(text)
|
||||
uncached_list = [
|
||||
text
|
||||
for text in tokenized_text_list
|
||||
for text in attacked_text_list
|
||||
if text not in self._call_model_cache
|
||||
]
|
||||
outputs = self._call_model_uncached(uncached_list)
|
||||
for text, output in zip(uncached_list, outputs):
|
||||
self._call_model_cache[text] = output
|
||||
all_outputs = [self._call_model_cache[text] for text in tokenized_text_list]
|
||||
all_outputs = [self._call_model_cache[text] for text in attacked_text_list]
|
||||
return all_outputs
|
||||
|
||||
def extra_repr_keys(self):
|
||||
|
||||
@@ -77,7 +77,7 @@ class AttackLogManager:
|
||||
successful_attacks = 0
|
||||
max_words_changed = None
|
||||
for i, result in enumerate(self.results):
|
||||
all_num_words[i] = len(result.original_result.tokenized_text.words)
|
||||
all_num_words[i] = len(result.original_result.attacked_text.words)
|
||||
if isinstance(result, FailedAttackResult):
|
||||
failed_attacks += 1
|
||||
continue
|
||||
@@ -87,19 +87,19 @@ class AttackLogManager:
|
||||
else:
|
||||
successful_attacks += 1
|
||||
num_words_changed = len(
|
||||
result.original_result.tokenized_text.all_words_diff(
|
||||
result.perturbed_result.tokenized_text
|
||||
result.original_result.attacked_text.all_words_diff(
|
||||
result.perturbed_result.attacked_text
|
||||
)
|
||||
)
|
||||
num_words_changed_until_success[num_words_changed - 1] += 1
|
||||
max_words_changed = max(
|
||||
max_words_changed or num_words_changed, num_words_changed
|
||||
)
|
||||
if len(result.original_result.tokenized_text.words) > 0:
|
||||
if len(result.original_result.attacked_text.words) > 0:
|
||||
perturbed_word_percentage = (
|
||||
num_words_changed
|
||||
* 100.0
|
||||
/ len(result.original_result.tokenized_text.words)
|
||||
/ len(result.original_result.attacked_text.words)
|
||||
)
|
||||
else:
|
||||
perturbed_word_percentage = 0
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from textattack.models.tokenizers import Tokenizer
|
||||
@@ -41,12 +42,16 @@ class AutoTokenizer(Tokenizer):
|
||||
add_special_tokens=True,
|
||||
pad_to_max_length=True,
|
||||
)
|
||||
return encoded_text
|
||||
return dict(encoded_text)
|
||||
|
||||
def encode_batch(self, input_text_list):
|
||||
""" The batch equivalent of ``encode``."""
|
||||
return self.tokenizer.encode_batch(input_text_list,
|
||||
if hasattr(self.tokenizer, "encode_batch"):
|
||||
return self.tokenizer.encode_batch(
|
||||
input_text_list,
|
||||
max_length=self.max_length,
|
||||
add_special_tokens=True,
|
||||
pad_to_max_length=True,
|
||||
)
|
||||
else:
|
||||
return [dict(self.encode(input_text)) for input_text in input_text_list]
|
||||
|
||||
@@ -19,13 +19,13 @@ class BeamSearch(SearchMethod):
|
||||
self.beam_width = beam_width
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
beam = [initial_result.tokenized_text]
|
||||
beam = [initial_result.attacked_text]
|
||||
best_result = initial_result
|
||||
while not best_result.succeeded:
|
||||
potential_next_beam = []
|
||||
for text in beam:
|
||||
transformations = self.get_transformations(
|
||||
text, original_text=initial_result.tokenized_text
|
||||
text, original_text=initial_result.attacked_text
|
||||
)
|
||||
for next_text in transformations:
|
||||
potential_next_beam.append(next_text)
|
||||
|
||||
@@ -45,14 +45,14 @@ class GeneticAlgorithm(SearchMethod):
|
||||
Whether a replacement which increased the score was found.
|
||||
"""
|
||||
transformations = self.get_transformations(
|
||||
pop_member.tokenized_text,
|
||||
original_text=self.original_tokenized_text,
|
||||
pop_member.attacked_text,
|
||||
original_text=self.original_attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
)
|
||||
if not len(transformations):
|
||||
return False
|
||||
orig_result, self.search_over = self.get_goal_results(
|
||||
[pop_member.tokenized_text], self.correct_output
|
||||
[pop_member.attacked_text], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
return False
|
||||
@@ -62,7 +62,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
if len(new_x_scores) and new_x_scores.max() > 0:
|
||||
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
|
||||
pop_member.attacked_text = transformations[new_x_scores.argmax()]
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -99,7 +99,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
pop = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(
|
||||
self.original_tokenized_text, deepcopy(neighbors_len), initial_result
|
||||
self.original_attacked_text, deepcopy(neighbors_len), initial_result
|
||||
)
|
||||
self._perturb(pop_member)
|
||||
pop.append(pop_member)
|
||||
@@ -116,8 +116,8 @@ class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
x1_text = pop_member1.tokenized_text
|
||||
x2_words = pop_member2.tokenized_text.words
|
||||
x1_text = pop_member1.attacked_text
|
||||
x2_words = pop_member2.attacked_text.words
|
||||
new_neighbors_len = deepcopy(pop_member1.neighbors_len)
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
@@ -129,35 +129,35 @@ class GeneticAlgorithm(SearchMethod):
|
||||
)
|
||||
return PopulationMember(new_text, deepcopy(new_neighbors_len))
|
||||
|
||||
def _get_neighbors_len(self, tokenized_text):
|
||||
def _get_neighbors_len(self, attacked_text):
|
||||
"""
|
||||
Generates this neighbors_len list
|
||||
Args:
|
||||
tokenized_text: The original text
|
||||
attacked_text: The original text
|
||||
Returns:
|
||||
A list of number of candidate neighbors for each word
|
||||
"""
|
||||
words = tokenized_text.words
|
||||
words = attacked_text.words
|
||||
neighbors_list = [[] for _ in range(len(words))]
|
||||
transformations = self.get_transformations(
|
||||
tokenized_text, original_text=self.original_tokenized_text
|
||||
attacked_text, original_text=self.original_attacked_text
|
||||
)
|
||||
for transformed_text in transformations:
|
||||
diff_idx = tokenized_text.first_word_diff_index(transformed_text)
|
||||
diff_idx = attacked_text.first_word_diff_index(transformed_text)
|
||||
neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
|
||||
neighbors_list = [np.array(x) for x in neighbors_list]
|
||||
neighbors_len = np.array([len(x) for x in neighbors_list])
|
||||
return neighbors_len
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.original_attacked_text = initial_result.attacked_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
neighbors_len = self._get_neighbors_len(self.original_attacked_text)
|
||||
pop = self._generate_population(neighbors_len, initial_result)
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results(
|
||||
[pm.tokenized_text for pm in pop], self.correct_output
|
||||
[pm.attacked_text for pm in pop], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
@@ -213,11 +213,11 @@ class PopulationMember:
|
||||
A member of the population during the course of the genetic algorithm.
|
||||
|
||||
Args:
|
||||
tokenized_text: The ``AttackedText`` of the population member.
|
||||
attacked_text: The ``AttackedText`` of the population member.
|
||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenized_text, neighbors_len, result=None):
|
||||
self.tokenized_text = tokenized_text
|
||||
def __init__(self, attacked_text, neighbors_len, result=None):
|
||||
self.attacked_text = attacked_text
|
||||
self.neighbors_len = neighbors_len
|
||||
self.result = result
|
||||
|
||||
@@ -39,23 +39,22 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
return leave_one_scores, search_over
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
tokenized_text = initial_result.tokenized_text
|
||||
attacked_text = initial_result.attacked_text
|
||||
cur_result = initial_result
|
||||
|
||||
# Sort words by order of importance
|
||||
len_text = len(tokenized_text.words)
|
||||
len_text = len(attacked_text.words)
|
||||
|
||||
if self.wir_method == "unk":
|
||||
leave_one_texts = [
|
||||
tokenized_text.replace_word_at_index(i, "[UNK]")
|
||||
for i in range(len_text)
|
||||
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
|
||||
]
|
||||
leave_one_scores, search_over = self._get_index_order(
|
||||
initial_result, leave_one_texts
|
||||
)
|
||||
elif self.wir_method == "delete":
|
||||
leave_one_texts = [
|
||||
tokenized_text.delete_word_at_index(i) for i in range(len_text)
|
||||
attacked_text.delete_word_at_index(i) for i in range(len_text)
|
||||
]
|
||||
leave_one_scores = self._get_index_order(initial_result, leave_one_texts)
|
||||
elif self.wir_method == "random":
|
||||
@@ -71,8 +70,8 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
results = None
|
||||
while i < len(index_order) and not search_over:
|
||||
transformed_text_candidates = self.get_transformations(
|
||||
cur_result.tokenized_text,
|
||||
original_text=initial_result.tokenized_text,
|
||||
cur_result.attacked_text,
|
||||
original_text=initial_result.attacked_text,
|
||||
indices_to_modify=[index_order[i]],
|
||||
)
|
||||
i += 1
|
||||
@@ -95,7 +94,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
for result in results:
|
||||
if not result.succeeded:
|
||||
break
|
||||
candidate = result.tokenized_text
|
||||
candidate = result.attacked_text
|
||||
try:
|
||||
similarity_score = candidate.attack_attrs["similarity_score"]
|
||||
except KeyError:
|
||||
|
||||
@@ -24,7 +24,7 @@ class SearchMethod:
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
"""
|
||||
Perturbs `tokenized_text` from ``initial_result`` until goal is reached or search is
|
||||
Perturbs `attacked_text` from ``initial_result`` until goal is reached or search is
|
||||
exhausted. Must be overridden by specific search methods.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -217,10 +217,10 @@ class Attack:
|
||||
yield
|
||||
|
||||
for text, ground_truth_output in dataset:
|
||||
tokenized_text = AttackedText(text, self.goal_function.tokenizer)
|
||||
attacked_text = AttackedText(text)
|
||||
self.goal_function.num_queries = 0
|
||||
goal_function_result, _ = self.goal_function.get_result(
|
||||
tokenized_text, ground_truth_output
|
||||
attacked_text, ground_truth_output
|
||||
)
|
||||
# We can skip examples for which the goal is already succeeded,
|
||||
# unless `attack_skippable_examples` is True.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import collections
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -29,16 +30,24 @@ class AttackedText:
|
||||
SPLIT_TOKEN = ">>>>"
|
||||
|
||||
def __init__(self, text_input, attack_attrs=None):
|
||||
# Read in ``text_input`` as a string or OrderedDict.
|
||||
if isinstance(text_input, str):
|
||||
self._text_input = collections.OrderedDict(('text', text_input))
|
||||
elif isinstance(text_input, collections.OrderedDict):
|
||||
self._text_input = OrderedDict([("text", text_input)])
|
||||
elif isinstance(text_input, OrderedDict):
|
||||
self._text_input = text_input
|
||||
else:
|
||||
raise TypeError(f'Invalid text_input type {type(text_input)} (required str or OrderedDict)')
|
||||
text = text.strip()
|
||||
self.words = words_from_text(text, words_to_ignore=[AttackedText.SPLIT_TOKEN])
|
||||
self.text = text
|
||||
self.attack_attrs = attack_attrs or dict()
|
||||
raise TypeError(
|
||||
f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
|
||||
)
|
||||
# Format text inputs.
|
||||
self._text_input = {k: v.strip() for k, v in self._text_input.items()}
|
||||
self.words = words_from_text(self.text)
|
||||
if attack_attrs is None:
|
||||
self.attack_attrs = dict()
|
||||
elif isinstance(attack_attrs, dict):
|
||||
self.attack_attrs = attack_attrs
|
||||
else:
|
||||
raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
|
||||
# Indices of words from the *original* text. Allows us to map
|
||||
# indices between original text and this text, and vice-versa.
|
||||
self.attack_attrs.setdefault("original_index_map", np.arange(len(self.words)))
|
||||
@@ -75,7 +84,6 @@ class AttackedText:
|
||||
|
||||
Can be called once the AttackedText is only needed to display.
|
||||
"""
|
||||
self.tokenizer = None
|
||||
if "previous_attacked_text" in self.attack_attrs:
|
||||
self.attack_attrs["previous_attacked_text"].free_memory()
|
||||
if "last_transformation" in self.attack_attrs:
|
||||
@@ -122,42 +130,42 @@ class AttackedText:
|
||||
look_after_index = self._text_index_of_word_index(i) + len(self.words[i])
|
||||
return self.text[look_after_index:]
|
||||
|
||||
def first_word_diff(self, other_tokenized_text):
|
||||
def first_word_diff(self, other_attacked_text):
|
||||
""" Returns the first word in self.words that differs from
|
||||
other_tokenized_text. Useful for word swap strategies. """
|
||||
other_attacked_text. Useful for word swap strategies. """
|
||||
w1 = self.words
|
||||
w2 = other_tokenized_text.words
|
||||
w2 = other_attacked_text.words
|
||||
for i in range(min(len(w1), len(w2))):
|
||||
if w1[i] != w2[i]:
|
||||
return w1
|
||||
return None
|
||||
|
||||
def first_word_diff_index(self, other_tokenized_text):
|
||||
def first_word_diff_index(self, other_attacked_text):
|
||||
""" Returns the index of the first word in self.words that differs
|
||||
from other_tokenized_text. Useful for word swap strategies. """
|
||||
from other_attacked_text. Useful for word swap strategies. """
|
||||
w1 = self.words
|
||||
w2 = other_tokenized_text.words
|
||||
w2 = other_attacked_text.words
|
||||
for i in range(min(len(w1), len(w2))):
|
||||
if w1[i] != w2[i]:
|
||||
return i
|
||||
return None
|
||||
|
||||
def all_words_diff(self, other_tokenized_text):
|
||||
""" Returns the set of indices for which this and other_tokenized_text
|
||||
def all_words_diff(self, other_attacked_text):
|
||||
""" Returns the set of indices for which this and other_attacked_text
|
||||
have different words. """
|
||||
indices = set()
|
||||
w1 = self.words
|
||||
w2 = other_tokenized_text.words
|
||||
w2 = other_attacked_text.words
|
||||
for i in range(min(len(w1), len(w2))):
|
||||
if w1[i] != w2[i]:
|
||||
indices.add(i)
|
||||
return indices
|
||||
|
||||
def ith_word_diff(self, other_tokenized_text, i):
|
||||
""" Returns whether the word at index i differs from other_tokenized_text
|
||||
def ith_word_diff(self, other_attacked_text, i):
|
||||
""" Returns whether the word at index i differs from other_attacked_text
|
||||
"""
|
||||
w1 = self.words
|
||||
w2 = other_tokenized_text.words
|
||||
w2 = other_attacked_text.words
|
||||
if len(w1) - 1 < i or len(w2) - 1 < i:
|
||||
return True
|
||||
return w1[i] != w2[i]
|
||||
@@ -244,7 +252,7 @@ class AttackedText:
|
||||
of one or more words.
|
||||
"""
|
||||
perturbed_text = ""
|
||||
original_text = self._input_text.join(AttackedText.SPLIT_TOKEN)
|
||||
original_text = AttackedText.SPLIT_TOKEN.join(self._text_input.values())
|
||||
new_attack_attrs = dict()
|
||||
new_attack_attrs["newly_modified_indices"] = set()
|
||||
# Point to previously monitored text.
|
||||
@@ -261,9 +269,9 @@ class AttackedText:
|
||||
# Create the new attacked text by swapping out words from the original
|
||||
# text with a sequence of 0+ words in the new text.
|
||||
for i, (input_word, adv_word_seq) in enumerate(zip(self.words, new_words)):
|
||||
word_start = text.index(input_word)
|
||||
word_start = original_text.index(input_word)
|
||||
word_end = word_start + len(input_word)
|
||||
perturbed_text += text[:word_start]
|
||||
perturbed_text += original_text[:word_start]
|
||||
original_text = original_text[word_end:]
|
||||
adv_num_words = len(words_from_text(adv_word_seq))
|
||||
num_words_diff = adv_num_words - len(words_from_text(input_word))
|
||||
@@ -300,32 +308,47 @@ class AttackedText:
|
||||
# @TODO What to do with punctuation in this case? This behavior is undefined.
|
||||
if i == 0:
|
||||
# If the first word was deleted, take a subsequent space.
|
||||
if text[0] == " ":
|
||||
text = text[1:]
|
||||
if original_text[0] == " ":
|
||||
original_text = original_text[1:]
|
||||
else:
|
||||
# If a word other than the first was deleted, take a preceding space.
|
||||
if perturbed_text[-1] == " ":
|
||||
perturbed_text = perturbed_text[:-1]
|
||||
# Add substitute word(s) to new sentence.
|
||||
perturbed_text += adv_word_seq
|
||||
perturbed_text += text # Add all of the ending punctuation.
|
||||
perturbed_text += original_text # Add all of the ending punctuation.
|
||||
# Reform perturbed_text into an OrderedDict.
|
||||
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
|
||||
perturbed_input = OrderedDict(zip(self._text_input.keys(), perturbed_input_texts))
|
||||
return AttackedText(
|
||||
perturbed_text, self.tokenizer, attack_attrs=new_attack_attrs
|
||||
perturbed_input = OrderedDict(
|
||||
zip(self._text_input.keys(), perturbed_input_texts)
|
||||
)
|
||||
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
""" Represents full text input. Multiply inputs are joined with a line
|
||||
break.
|
||||
"""
|
||||
return '\n\n'.join(self._text_input.value())
|
||||
return "\n".join(self._text_input.values())
|
||||
|
||||
@property
|
||||
def printable_text(self):
|
||||
return '\n\n'.join(f'{key}: {value}' for key, value in enumerate(self._text_input))
|
||||
""" Represents full text input. Adds field descriptions.
|
||||
|
||||
For example, entailment inputs look like:
|
||||
```
|
||||
premise: ...
|
||||
hypothesis: ...
|
||||
```
|
||||
"""
|
||||
# For single-sequence inputs, don't show a prefix.
|
||||
if len(self._text_input) == 1:
|
||||
return next(iter(self._text_input.values()))
|
||||
# For multiple-sequence inputs, show a prefix and a colon.
|
||||
else:
|
||||
return "\n\n".join(
|
||||
f"{key}: {value}" for key, value in self._text_input.items()
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<AttackedText "{self.text}">'
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import collections
|
||||
import sys
|
||||
|
||||
import torch
|
||||
@@ -13,7 +14,7 @@ def _cb(s):
|
||||
|
||||
def get_num_successes(args, model, ids, true_labels):
|
||||
with torch.no_grad():
|
||||
preds = textattack.shared.utils.model_predict(model, ids)
|
||||
preds = textattack.shared.utils.model_predict(model, **ids)
|
||||
true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
|
||||
guess_labels = preds.argmax(dim=1)
|
||||
successes = (guess_labels == true_labels).sum().item()
|
||||
@@ -28,9 +29,8 @@ def test_model_on_dataset(args, model, dataset, batch_size=128):
|
||||
batch_labels = []
|
||||
all_true_labels = []
|
||||
all_guess_labels = []
|
||||
for i, (text, label) in enumerate(dataset):
|
||||
if i >= num_examples:
|
||||
break
|
||||
for i, (text_input, label) in enumerate(dataset):
|
||||
text = textattack.shared.AttackedText(text_input).text
|
||||
ids = model.tokenizer.encode(text)
|
||||
batch_ids.append(ids)
|
||||
batch_labels.append(label)
|
||||
|
||||
@@ -73,10 +73,8 @@ def run(args):
|
||||
|
||||
print("Attacking...")
|
||||
|
||||
tokenized_text = textattack.shared.tokenized_text.AttackedText(
|
||||
text, attack.goal_function.model.tokenizer
|
||||
)
|
||||
initial_result = attack.goal_function.get_output(tokenized_text)
|
||||
attacked_text = textattack.shared.attacked_text.AttackedText(text)
|
||||
initial_result = attack.goal_function.get_output(attacked_text)
|
||||
result = next(attack.attack_dataset([(text, initial_result)]))
|
||||
print(result.__str__(color_method="ansi"))
|
||||
|
||||
|
||||
@@ -3,13 +3,15 @@ import torch
|
||||
import textattack
|
||||
from textattack.shared import utils
|
||||
|
||||
def batch_tokenize(tokenizer, inputs):
|
||||
|
||||
def batch_tokenize(tokenizer, attacked_text_list):
|
||||
""" Tokenizes a list of inputs and returns their tokenized forms in a list. """
|
||||
if hasattr(tokenizer, 'encode_batch'):
|
||||
encoded_inputs = tokenizer.encode_batch(inputs)
|
||||
inputs = [at.text for at in attacked_text_list]
|
||||
if hasattr(tokenizer, "encode_batch"):
|
||||
return tokenizer.encode_batch(inputs)
|
||||
else:
|
||||
encoded_inputs = [tokenizer.encode(x) for x in inputs]
|
||||
return [x.ids for x in encoded_inputs]
|
||||
return [tokenizer.encode(x) for x in inputs]
|
||||
|
||||
|
||||
def batch_model_predict(model, inputs, batch_size=utils.config("MODEL_BATCH_SIZE")):
|
||||
outputs = []
|
||||
|
||||
@@ -22,7 +22,7 @@ class CompositeTransformation(Transformation):
|
||||
self.transformations = transformations
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
new_tokenized_texts = set()
|
||||
new_attacked_texts = set()
|
||||
for transformation in self.transformations:
|
||||
new_tokenized_texts.update(transformation(*args, **kwargs))
|
||||
return list(new_tokenized_texts)
|
||||
new_attacked_texts.update(transformation(*args, **kwargs))
|
||||
return list(new_attacked_texts)
|
||||
|
||||
@@ -43,12 +43,12 @@ class WordSwapGradientBased(Transformation):
|
||||
self.top_n = top_n
|
||||
self.is_black_box = False
|
||||
|
||||
def _get_replacement_words_by_grad(self, text, indices_to_replace):
|
||||
def _get_replacement_words_by_grad(self, attacked_text, indices_to_replace):
|
||||
""" Returns a list containing all possible words to replace
|
||||
`word` with, based off of the model's gradient.
|
||||
|
||||
Arguments:
|
||||
text (AttackedText): The full text input to perturb
|
||||
attacked_text (AttackedText): The full text input to perturb
|
||||
indices_to_replace: indices of the words eligible for replacement
|
||||
"""
|
||||
self.model.train()
|
||||
@@ -56,11 +56,14 @@ class WordSwapGradientBased(Transformation):
|
||||
lookup_table = self.model.lookup_table.to(utils.device)
|
||||
lookup_table_transpose = lookup_table.transpose(0, 1)
|
||||
|
||||
# get word IDs
|
||||
text_ids = self.model.tokenizer.encode(attacked_text.text)
|
||||
|
||||
# set backward hook on the word embeddings for input x
|
||||
emb_hook = Hook(self.model.word_embeddings, backward=True)
|
||||
|
||||
self.model.zero_grad()
|
||||
predictions = self._call_model(text)
|
||||
predictions = self._call_model(text_ids)
|
||||
original_label = predictions.argmax()
|
||||
y_true = torch.Tensor([original_label]).long().to(utils.device)
|
||||
loss = self.loss(predictions, y_true)
|
||||
@@ -78,7 +81,7 @@ class WordSwapGradientBased(Transformation):
|
||||
b_grads = (
|
||||
emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
|
||||
)
|
||||
a_grad = b_grads[text.ids[word_idx]]
|
||||
a_grad = b_grads[text_ids[word_idx]]
|
||||
diffs[j] = b_grads - a_grad
|
||||
|
||||
# Don't change to the pad token.
|
||||
@@ -104,12 +107,12 @@ class WordSwapGradientBased(Transformation):
|
||||
self.model.eval()
|
||||
return candidates
|
||||
|
||||
def _call_model(self, text):
|
||||
def _call_model(self, text_ids):
|
||||
""" A helper function to query `self.model` with AttackedText `text`.
|
||||
"""
|
||||
return utils.model_predict(self.model, [text.ids])
|
||||
return utils.model_predict(self.model, [text_ids])
|
||||
|
||||
def _get_transformations(self, tokenized_text, indices_to_replace):
|
||||
def _get_transformations(self, attacked_text, indices_to_replace):
|
||||
"""
|
||||
Returns a list of all possible transformations for `text`.
|
||||
|
||||
@@ -118,9 +121,9 @@ class WordSwapGradientBased(Transformation):
|
||||
"""
|
||||
transformations = []
|
||||
for word, idx in self._get_replacement_words_by_grad(
|
||||
tokenized_text, indices_to_replace
|
||||
attacked_text, indices_to_replace
|
||||
):
|
||||
transformations.append(tokenized_text.replace_word_at_index(idx, word))
|
||||
transformations.append(attacked_text.replace_word_at_index(idx, word))
|
||||
return transformations
|
||||
|
||||
def extra_repr_keys(self):
|
||||
|
||||
Reference in New Issue
Block a user