mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #292 from QData/gradient-support
Add model-agnostic gradient retrieval + gradient-based GreedyWordWIR search
This commit is contained in:
@@ -67,7 +67,9 @@ class Attack:
|
||||
self.transformation = transformation
|
||||
if not self.transformation:
|
||||
raise NameError("Cannot instantiate attack without transformation")
|
||||
self.is_black_box = getattr(transformation, "is_black_box", True)
|
||||
self.is_black_box = (
|
||||
getattr(transformation, "is_black_box", True) and search_method.is_black_box
|
||||
)
|
||||
|
||||
if not self.search_method.check_transformation_compatibility(
|
||||
self.transformation
|
||||
@@ -114,6 +116,8 @@ class Attack:
|
||||
)
|
||||
)
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
if not search_method.is_black_box:
|
||||
self.search_method.get_model = lambda: self.goal_function.model
|
||||
|
||||
def clear_cache(self, recursive=True):
|
||||
self.constraints_cache.clear()
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
"""
|
||||
|
||||
.. _attacked_text:
|
||||
|
||||
""".. _attacked_text:
|
||||
|
||||
Attacked Text Class
|
||||
=====================
|
||||
|
||||
A helper class that represents a string that can be attacked.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
@@ -436,6 +431,40 @@ class AttackedText:
|
||||
assert self.num_words == x.num_words
|
||||
return float(np.sum(self.words != x.words)) / self.num_words
|
||||
|
||||
def align_with_model_tokens(self, model_wrapper):
    """Align AttackedText's `words` with the target model's tokenization
    scheme (e.g. word, character, subword). Specifically, maps each word to
    the list of indices of the tokens that compose it
    (e.g. embedding --> ["em", "##bed", "##ding"]).

    Args:
        model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper
            of the target model, whose tokenizer defines the token scheme.

    Returns:
        word2token_mapping (dict[str, list[int]]): Dictionary that maps each
            matched word to the list of token indices that compose it. Words
            for which no token could be matched are omitted.
    """
    # NOTE(review): assumes `strip_prefix=True` removes subword markers
    # (e.g. "##") so tokens can be prefix-matched against raw word text.
    tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
    word2token_mapping = {}
    j = 0
    last_matched = 0
    for i, word in enumerate(self.words):
        matched_tokens = []
        while j < len(tokens) and len(word) > 0:
            token = tokens[j].lower()
            # Compare case-insensitively: tokenizer vocabularies are commonly
            # lowercased while `self.words` keeps the original casing, so
            # matching the raw word against a lowercased token would silently
            # fail for any capitalized word (e.g. sentence-initial "Hello").
            idx = word.lower().find(token)
            if idx == 0:
                # Token is a prefix of the remaining word; consume it.
                word = word[idx + len(token) :]
                matched_tokens.append(j)
                last_matched = j
            j += 1

        if not matched_tokens:
            # Nothing matched for this word; rewind so the next word can
            # retry from the last successfully matched token.
            j = last_matched
        else:
            word2token_mapping[self.words[i]] = matched_tokens

    return word2token_mapping
|
||||
|
||||
@property
|
||||
def tokenizer_input(self):
|
||||
"""The tuple of inputs to be passed to the tokenizer."""
|
||||
|
||||
@@ -17,8 +17,6 @@ def html_style_from_dict(style_dict):
|
||||
|
||||
into
|
||||
style: "color: red; height: 100px"
|
||||
|
||||
|
||||
"""
|
||||
style_str = ""
|
||||
for key in style_dict:
|
||||
@@ -100,14 +98,6 @@ def load_textattack_model_from_path(model_name, model_path):
|
||||
model = textattack.models.helpers.WordCNNForClassification(
|
||||
model_path=model_path, num_labels=num_labels
|
||||
)
|
||||
elif model_name.startswith("bert"):
|
||||
model_path, num_labels = model_path
|
||||
textattack.shared.logger.info(
|
||||
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
|
||||
)
|
||||
model = textattack.models.helpers.BERTForClassification(
|
||||
model_path=model_path, num_labels=num_labels
|
||||
)
|
||||
elif model_name.startswith("t5"):
|
||||
model = textattack.models.helpers.T5ForTextToText(model_path)
|
||||
else:
|
||||
|
||||
@@ -82,7 +82,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
|
||||
|
||||
def validate_model_gradient_word_swap_compatibility(model):
|
||||
"""Determines if ``model`` is task-compatible with
|
||||
``radientBasedWordSwap``.
|
||||
``GradientBasedWordSwap``.
|
||||
|
||||
We can only take the gradient with respect to an individual word if
|
||||
the model uses a word-based tokenizer.
|
||||
|
||||
Reference in New Issue
Block a user