1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #292 from QData/gradient-support

Add model-agnostic gradient retrieval + gradient-based GreedyWordWIR search
This commit is contained in:
Jack Morris
2020-11-06 18:16:35 -05:00
committed by GitHub
43 changed files with 440 additions and 180 deletions

View File

@@ -67,7 +67,9 @@ class Attack:
self.transformation = transformation
if not self.transformation:
raise NameError("Cannot instantiate attack without transformation")
self.is_black_box = getattr(transformation, "is_black_box", True)
self.is_black_box = (
getattr(transformation, "is_black_box", True) and search_method.is_black_box
)
if not self.search_method.check_transformation_compatibility(
self.transformation
@@ -114,6 +116,8 @@ class Attack:
)
)
self.search_method.filter_transformations = self.filter_transformations
if not search_method.is_black_box:
self.search_method.get_model = lambda: self.goal_function.model
def clear_cache(self, recursive=True):
self.constraints_cache.clear()

View File

@@ -1,14 +1,9 @@
"""
.. _attacked_text:
""".. _attacked_text:
Attacked Text Class
=====================
A helper class that represents a string that can be attacked.
"""
from collections import OrderedDict
@@ -436,6 +431,40 @@ class AttackedText:
assert self.num_words == x.num_words
return float(np.sum(self.words != x.words)) / self.num_words
def align_with_model_tokens(self, model_wrapper):
    """Align this text's ``words`` with the target model's tokenization
    scheme (word, character, or subword level).

    Each word is greedily matched against consecutive model tokens, e.g.
    "embedding" --> ["em", "##bed", "##ding"] (prefixes such as "##" are
    stripped by the wrapper's tokenizer before matching).

    Args:
        model_wrapper (textattack.models.wrappers.ModelWrapper): wrapper
            around the target model, used only for its ``tokenize`` method.

    Returns:
        dict[str, list[int]]: maps each word to the indices of the tokens
        that compose it. Words that match no tokens are omitted.
    """
    tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
    num_tokens = len(tokens)
    mapping = {}
    cursor = 0  # index of the next unconsumed model token
    last_hit = 0  # token index of the most recent successful match
    for word in self.words:
        remainder = word  # portion of the word not yet covered by tokens
        hits = []
        while cursor < num_tokens and remainder:
            piece = tokens[cursor].lower()
            # Consume the token only when it is a prefix of what's left;
            # otherwise skip it (e.g. special tokens like [CLS]).
            if remainder.startswith(piece):
                remainder = remainder[len(piece):]
                hits.append(cursor)
                last_hit = cursor
            cursor += 1
        if hits:
            mapping[word] = hits
        else:
            # Nothing matched: rewind so later words can retry from the
            # last token that matched anything.
            cursor = last_hit
    return mapping
@property
def tokenizer_input(self):
"""The tuple of inputs to be passed to the tokenizer."""

View File

@@ -17,8 +17,6 @@ def html_style_from_dict(style_dict):
into
style: "color: red; height: 100px"
"""
style_str = ""
for key in style_dict:
@@ -100,14 +98,6 @@ def load_textattack_model_from_path(model_name, model_path):
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("bert"):
model_path, num_labels = model_path
textattack.shared.logger.info(
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
)
model = textattack.models.helpers.BERTForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("t5"):
model = textattack.models.helpers.T5ForTextToText(model_path)
else:

View File

@@ -82,7 +82,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
def validate_model_gradient_word_swap_compatibility(model):
"""Determines if ``model`` is task-compatible with
``radientBasedWordSwap``.
``GradientBasedWordSwap``.
We can only take the gradient with respect to an individual word if
the model uses a word-based tokenizer.