mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #292 from QData/gradient-support

Add model-agnostic gradient retrieval + gradient-based GreedyWordSwapWIR search
This commit is contained in:
Jack Morris
2020-11-06 18:16:35 -05:00
committed by GitHub
43 changed files with 440 additions and 180 deletions

View File

@@ -1,5 +1,4 @@
"""
Welcome to the API references for TextAttack!
"""Welcome to the API references for TextAttack!
What is TextAttack?

View File

@@ -1,5 +1,4 @@
"""
.. _attack_recipes:
""".. _attack_recipes:
Attack Recipes:
======================
@@ -17,7 +16,6 @@ For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`,
TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):
.. contents:: :local:
"""
from .attack_recipe import AttackRecipe

View File

@@ -24,7 +24,8 @@ from .attack_recipe import AttackRecipe
class Pruthi2019(AttackRecipe):
"""An implementation of the attack used in "Combating Adversarial Misspellings with Robust Word Recognition", Pruthi et al., 2019.
"""An implementation of the attack used in "Combating Adversarial
Misspellings with Robust Word Recognition", Pruthi et al., 2019.
This attack focuses on a small number of character-level changes that simulate common typos. It combines:
- Swapping neighboring characters

View File

@@ -1,12 +1,9 @@
"""
.. _augmentation:
""".. _augmentation:
Augmenter:
==================
Transformations and constraints can be used outside of an attack for simple NLP data augmentation with the ``Augmenter`` class, which returns all possible transformations for a given string.
"""
from .augmenter import Augmenter
from .recipes import (
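A quick usage sketch of this pattern, assuming the ``WordNetAugmenter`` recipe exported by this package:

from textattack.augmentation import WordNetAugmenter

augmenter = WordNetAugmenter()
# returns a list of augmented variants of the input string
print(augmenter.augment("the quick brown fox jumps over the lazy dog"))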

View File

@@ -312,9 +312,7 @@ def parse_model_from_args(args):
)
# Choose the appropriate model wrapper (based on whether or not this is
# a HuggingFace model).
if isinstance(
model, textattack.models.helpers.BERTForClassification
) or isinstance(model, textattack.models.helpers.T5ForTextToText):
if isinstance(model, textattack.models.helpers.T5ForTextToText):
model = textattack.models.wrappers.HuggingFaceModelWrapper(
model, model.tokenizer, batch_size=args.model_batch_size
)

View File

@@ -31,7 +31,7 @@ def run(args, checkpoint=None):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Disable tensorflow logs, except in the case of an error.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"
# Fix TensorFlow GPU memory growth
import tensorflow as tf

View File

@@ -1,5 +1,4 @@
"""
.. _constraint:
""".. _constraint:
Constraint Package
===================
@@ -17,7 +16,6 @@ We split constraints into three main categories.
A fourth type of constraint restricts the search method from exploring certain parts of the search space:
:ref:`pre_transformation <pre_transformation>`: Based on the input and index of word replacement.
"""
from .pre_transformation_constraint import PreTransformationConstraint

View File

@@ -1,13 +1,10 @@
"""
.. _grammaticality:
""".. _grammaticality:
Grammaticality:
--------------------------
Grammaticality constraints determine if a transformation is valid based on
syntactic properties of the perturbation.
"""
from . import language_models

View File

@@ -1,12 +1,9 @@
"""
.. _overlap:
""".. _overlap:
Overlap Constraints
--------------------------
Overlap constraints determine if a transformation is valid based on character-level analysis.
"""
from .bleu_score import BLEU

View File

@@ -1,13 +1,9 @@
"""
.. _pre_transformation:
""".. _pre_transformation:
Pre-Transformation:
---------------------
Pre-transformation constraints determine if a transformation is valid based on only the original input and the position of the replacement. These constraints are applied before the transformation is even called. For example, these constraints can prevent search methods from swapping words at the same index twice, or from replacing stopwords.
"""
from .stopword_modification import StopwordModification
from .repeat_modification import RepeatModification

View File

@@ -1,11 +1,8 @@
"""
.. _semantics:
""".. _semantics:
Semantic Constraints
---------------------
Semantic constraints determine if a transformation is valid based on similarity of the semantics of the original input and the transformed input.
"""
from . import sentence_encoders

View File

@@ -18,8 +18,7 @@ from textattack.shared import utils
class BERTScore(Constraint):
"""
A constraint on BERT-Score difference.
"""A constraint on BERT-Score difference.
Args:
min_bert_score (float): minimum threshold value for BERT-Score
@@ -33,7 +32,6 @@ class BERTScore(Constraint):
compare_against_original (bool):
If ``True``, compare new ``x_adv`` against the original ``x``.
Otherwise, compare it against the previous ``x_adv``.
"""
SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}
@@ -60,7 +58,8 @@ class BERTScore(Constraint):
)
def _check_constraint(self, transformed_text, reference_text):
"""Return `True` if BERT Score between `transformed_text` and `reference_text` is lower than minimum BERT Score."""
"""Return `True` if BERT Score between `transformed_text` and
`reference_text` is lower than minimum BERT Score."""
cand = transformed_text.text
ref = reference_text.text
result = self._bert_scorer.score([cand], [ref])
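For reference, a minimal standalone sketch of the same check using the ``bert_score`` package directly (model choice and threshold here are illustrative):

from bert_score import BERTScorer

scorer = BERTScorer(model_type="bert-base-uncased", idf=False)
# score() returns (precision, recall, f1) tensors; SCORE_TYPE2IDX selects one by index
precision, recall, f1 = scorer.score(["the quick brown dog"], ["the quick brown fox"])
print(f1.item() >= 0.8)  # constraint passes only if the score clears min_bert_score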

View File

@@ -1,7 +1,4 @@
"""
.. _goal_function:
""".. _goal_function:
Goal functions determine if an attack has been successful.
===========================================================

View File

@@ -15,7 +15,8 @@ from .text_to_text_goal_function import TextToTextGoalFunction
class MinimizeBleu(TextToTextGoalFunction):
"""Attempts to minimize the BLEU score between the current output translation and the reference translation.
"""Attempts to minimize the BLEU score between the current output
translation and the reference translation.
BLEU score was defined in (BLEU: a Method for Automatic Evaluation of Machine Translation).
@@ -28,8 +29,6 @@ class MinimizeBleu(TextToTextGoalFunction):
`ArxivURL2`_
.. _ArxivURL2: https://www.aclweb.org/anthology/2020.acl-main.263
"""
EPS = 1e-10
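As a rough illustration of the objective with NLTK's sentence-level BLEU (the exact scorer used internally may differ):

from nltk.translate.bleu_score import sentence_bleu

reference = "the cat sat on the mat".split()
print(sentence_bleu([reference], "the cat sat on the mat".split()))   # ~1.0, far from the goal
print(sentence_bleu([reference], "a dog slept under a rug".split()))  # ~0.0, goal reached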

View File

@@ -1,7 +1,4 @@
"""
.. _loggers:
""".. _loggers:
Misc Loggers: Loggers track, visualize, and export attack results.
===================================================================

View File

@@ -1,5 +1,4 @@
"""
.. _models:
""".. _models:
Models
=========

View File

@@ -9,7 +9,6 @@ from . import utils
from .glove_embedding_layer import GloveEmbeddingLayer
# Helper modules.
from .bert_for_classification import BERTForClassification
from .lstm_for_classification import LSTMForClassification
from .t5_for_text_to_text import T5ForTextToText
from .word_cnn_for_classification import WordCNNForClassification

View File

@@ -62,8 +62,6 @@ class LSTMForClassification(nn.Module):
def load_from_disk(self, model_path):
self.load_state_dict(load_cached_state_dict(model_path))
self.word_embeddings = self.emb_layer.embedding
self.lookup_table = self.emb_layer.embedding.weight.data
self.to(utils.device)
self.eval()
@@ -80,3 +78,6 @@ class LSTMForClassification(nn.Module):
output = self.drop(output)
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding

View File

@@ -57,3 +57,6 @@ class T5ForTextToText(torch.nn.Module):
)
# Convert ID tensor to string and return.
return [self.tokenizer.decode(ids) for ids in output_ids_list]
def get_input_embeddings(self):
return self.model.get_input_embeddings()

View File

@@ -65,6 +65,9 @@ class WordCNNForClassification(nn.Module):
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding
class CNNTextLayer(nn.Module):
def __init__(self, n_in, widths=[3, 4, 5], filters=100):

View File

@@ -86,3 +86,20 @@ class AutoTokenizer:
return list_of_dicts
else:
return [self.encode(input_text) for input_text in input_text_list]
def convert_ids_to_tokens(self, ids):
return self.tokenizer.convert_ids_to_tokens(ids)
@property
def pad_token_id(self):
if hasattr(self.tokenizer, "pad_token_id"):
return self.tokenizer.pad_token_id
else:
raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")
@property
def mask_token_id(self):
if hasattr(self.tokenizer, "mask_token_id"):
return self.tokenizer.mask_token_id
else:
raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")

View File

@@ -104,6 +104,9 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
"unicode_normalizer": unicode_normalizer,
}
self.unk_token = unk_token
self.pad_token = pad_token
super().__init__(tokenizer, parameters)
@@ -122,8 +125,8 @@ class GloveTokenizer(WordLevelTokenizer):
pad_token_id=pad_token_id,
lowercase=True,
)
self.pad_id = pad_token_id
self.oov_id = unk_token_id
self.pad_token_id = pad_token_id
self.oov_token_id = unk_token_id
self.convert_id_to_word = self.id_to_token
# Set defaults.
self.enable_padding(length=max_length, pad_id=pad_token_id)
@@ -156,3 +159,6 @@ class GloveTokenizer(WordLevelTokenizer):
add_special_tokens=False,
)
return [x.ids for x in encodings]
def convert_ids_to_tokens(self, ids):
return [self.convert_id_to_word(_id) for _id in ids]

View File

@@ -21,40 +21,110 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
self.tokenizer = tokenizer
self.batch_size = batch_size
def _model_predict(self, inputs):
"""Turn a list of dicts into a dict of lists.
Then make lists (values of dict) into tensors.
"""
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in inputs] for k in inputs[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
outputs = self.model(**input_dict)
if isinstance(outputs[0], str):
# HuggingFace sequence-to-sequence models return a list of
# string predictions as output. In this case, return the full
# list of outputs.
return outputs
else:
# HuggingFace classification models return a tuple as output
# where the first item in the tuple corresponds to the list of
# scores for each input.
return outputs[0]
def __call__(self, text_input_list):
"""Passes inputs to HuggingFace models as keyword arguments.
(Regular PyTorch ``nn.Module`` models typically take inputs as
positional arguments.)
"""
ids = self.tokenize(text_input_list)
def model_predict(inputs):
"""Turn a list of dicts into a dict of lists.
Then make lists (values of dict) into tensors.
"""
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in inputs] for k in inputs[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
outputs = self.model(**input_dict)
if isinstance(outputs[0], str):
# HuggingFace sequence-to-sequence models return a list of
# string predictions as output. In this case, return the full
# list of outputs.
return outputs
else:
# HuggingFace classification models return a tuple as output
# where the first item in the tuple corresponds to the list of
# scores for each input.
return outputs[0]
ids = self.encode(text_input_list)
with torch.no_grad():
outputs = textattack.shared.utils.batch_model_predict(
model_predict, ids, batch_size=self.batch_size
self._model_predict, ids, batch_size=self.batch_size
)
return outputs
def get_grad(self, text_input):
"""Get gradient of loss with respect to input tokens.
Args:
text_input (str): input string
Returns:
Dict of ids, tokens, and gradient as numpy array.
"""
if isinstance(self.model, textattack.models.helpers.T5ForTextToText):
raise NotImplementedError(
"`get_grads` for T5FotTextToText has not been implemented yet."
)
self.model.train()
embedding_layer = self.model.get_input_embeddings()
original_state = embedding_layer.weight.requires_grad
embedding_layer.weight.requires_grad = True
emb_grads = []
def grad_hook(module, grad_in, grad_out):
emb_grads.append(grad_out[0])
emb_hook = embedding_layer.register_backward_hook(grad_hook)
self.model.zero_grad()
model_device = next(self.model.parameters()).device
ids = self.encode([text_input])
predictions = self._model_predict(ids)
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in ids] for k in ids[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
try:
labels = predictions.argmax(dim=1)
loss = self.model(**input_dict, labels=labels)[0]
except TypeError:
raise TypeError(
f"{type(self.model)} class does not take in `labels` to calculate loss. "
"One cause for this might be if you instantiatedyour model using `transformer.AutoModel` "
"(instead of `transformers.AutoModelForSequenceClassification`)."
)
loss.backward()
# grad w.r.t to word embeddings
grad = emb_grads[0][0].cpu().numpy()
embedding_layer.weight.requires_grad = original_state
emb_hook.remove()
self.model.eval()
output = {"ids": ids[0]["input_ids"], "gradient": grad}
return output
def _tokenize(self, inputs):
"""Helper method that for `tokenize`
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
return [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x)["input_ids"])
for x in inputs
]
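A hedged usage sketch of the new `get_grad` (the checkpoint name is illustrative; any sequence classification checkpoint should work):

import transformers
import textattack

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = textattack.models.tokenizers.AutoTokenizer("textattack/bert-base-uncased-imdb")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

grad_output = model_wrapper.get_grad("this movie was terrible")
print(grad_output["ids"])             # token ids for the input
print(grad_output["gradient"].shape)  # (seq_len, embedding_dim) numpy array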

View File

@@ -15,10 +15,49 @@ class ModelWrapper(ABC):
def __call__(self, text_list):
raise NotImplementedError()
def tokenize(self, inputs):
@abstractmethod
def get_grad(self, text_input):
"""Get gradient of loss with respect to input tokens."""
raise NotImplementedError()
def encode(self, inputs):
"""Helper method that calls ``tokenizer.batch_encode`` if possible, and
if not, falls back to calling ``tokenizer.encode`` for each input."""
if not, falls back to calling ``tokenizer.encode`` for each input.
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[int]]): List of list of ids
"""
if hasattr(self.tokenizer, "batch_encode"):
return self.tokenizer.batch_encode(inputs)
else:
return [self.tokenizer.encode(x) for x in inputs]
def _tokenize(self, inputs):
"""Helper method for `tokenize`"""
raise NotImplementedError()
def tokenize(self, inputs, strip_prefix=False):
"""Helper method that tokenizes input strings
Args:
inputs (list[str]): list of input strings
strip_prefix (bool): If `True`, we strip auxiliary characters added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for RoBERTa)
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
tokens = self._tokenize(inputs)
if strip_prefix:
# `strip_chars` are known auxiliary characters that are added to tokens
strip_chars = ["##", "Ġ", "__"]
# TODO: Find a better way to identify prefixes. These depend on the model, so cannot be resolved in ModelWrapper.
def strip(s, chars):
for c in chars:
s = s.replace(c, "")
return s
tokens = [[strip(t, strip_chars) for t in x] for x in tokens]
return tokens
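A standalone sketch of the prefix-stripping behavior described above, assuming BERT-style "##" prefixes:

strip_chars = ["##", "Ġ", "__"]

def strip(s, chars):
    for c in chars:
        s = s.replace(c, "")
    return s

subword_tokens = ["em", "##bed", "##ding"]
print([strip(t, strip_chars) for t in subword_tokens])  # ['em', 'bed', 'ding']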

View File

@@ -5,6 +5,7 @@ PyTorch Model Wrapper
import torch
from torch.nn import CrossEntropyLoss
import textattack
@@ -12,7 +13,13 @@ from .model_wrapper import ModelWrapper
class PyTorchModelWrapper(ModelWrapper):
"""Loads a PyTorch model (`nn.Module`) and tokenizer."""
"""Loads a PyTorch model (`nn.Module`) and tokenizer.
Args:
model (torch.nn.Module): PyTorch model
tokenizer: tokenizer whose output can be packed as a tensor and passed to the model.
No type requirement, but it must have an `encode` method that accepts a string (or `batch_encode` for a list of strings).
"""
def __init__(self, model, tokenizer, batch_size=32):
if not isinstance(model, torch.nn.Module):
@@ -26,7 +33,7 @@ class PyTorchModelWrapper(ModelWrapper):
def __call__(self, text_input_list):
model_device = next(self.model.parameters()).device
ids = self.tokenize(text_input_list)
ids = self.encode(text_input_list)
ids = torch.tensor(ids).to(model_device)
with torch.no_grad():
@@ -35,3 +42,67 @@ class PyTorchModelWrapper(ModelWrapper):
)
return outputs
def get_grad(self, text_input, loss_fn=CrossEntropyLoss()):
"""Get gradient of loss with respect to input tokens.
Args:
text_input (str): input string
loss_fn (torch.nn.Module): loss function. Default is `torch.nn.CrossEntropyLoss`
Returns:
Dict of ids, tokens, and gradient as numpy array.
"""
if not hasattr(self.model, "get_input_embeddings"):
raise AttributeError(
f"{type(self.model)} must have method `get_input_embeddings` that returns `torch.nn.Embedding` object that represents input embedding layer"
)
if not isinstance(loss_fn, torch.nn.Module):
raise ValueError("Loss function must be of type `torch.nn.Module`.")
self.model.train()
embedding_layer = self.model.get_input_embeddings()
original_state = embedding_layer.weight.requires_grad
embedding_layer.weight.requires_grad = True
emb_grads = []
def grad_hook(module, grad_in, grad_out):
emb_grads.append(grad_out[0])
emb_hook = embedding_layer.register_backward_hook(grad_hook)
self.model.zero_grad()
model_device = next(self.model.parameters()).device
ids = self.encode([text_input])
ids = torch.tensor(ids).to(model_device)
predictions = self.model(ids)
output = predictions.argmax(dim=1)
loss = loss_fn(predictions, output)
loss.backward()
# grad w.r.t to word embeddings
grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()
embedding_layer.weight.requires_grad = original_state
emb_hook.remove()
self.model.eval()
output = {"ids": ids[0].tolist(), "gradient": grad}
return output
def _tokenize(self, inputs):
"""Helper method that for `tokenize`
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
return [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x))
for x in inputs
]
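To isolate the backward-hook pattern used above, a self-contained toy sketch (the model, names, and shapes are all illustrative):

import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=8, num_labels=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.out = nn.Linear(emb_dim, num_labels)

    def get_input_embeddings(self):
        return self.embedding

    def forward(self, ids):
        # mean-pool embeddings over the sequence, then classify
        return self.out(self.embedding(ids).mean(dim=1))

model = ToyClassifier()
emb_grads = []
hook = model.get_input_embeddings().register_backward_hook(
    lambda module, grad_in, grad_out: emb_grads.append(grad_out[0])
)
ids = torch.tensor([[5, 17, 42]])
predictions = model(ids)
# loss w.r.t. the predicted label, as in get_grad above
loss = nn.CrossEntropyLoss()(predictions, predictions.argmax(dim=1))
loss.backward()
hook.remove()
print(emb_grads[0].shape)  # (1, 3, 8): gradient w.r.t. the input embeddings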

View File

@@ -27,3 +27,6 @@ class SklearnModelWrapper(ModelWrapper):
encoded_text_matrix, columns=self.tokenizer.get_feature_names()
)
return self.model.predict_proba(tokenized_text_df)
def get_grad(self, text_input):
raise NotImplementedError()

View File

@@ -27,3 +27,6 @@ class TensorFlowModelWrapper(ModelWrapper):
text_array = np.array(text_input_list)
preds = self.model(text_array)
return preds.numpy()
def get_grad(self, text_input):
raise NotImplementedError()

View File

@@ -1,13 +1,9 @@
"""
.. _search_methods:
""".. _search_methods:
Search Methods:
===================
Search methods explore the transformation space in an attempt to find a successful attack, as determined by a :ref:`Goal Function <goal_function>` and a list of :ref:`Constraints <constraint>`.
"""
from .search_method import SearchMethod
from .beam_search import BeamSearch

View File

@@ -49,5 +49,9 @@ class BeamSearch(SearchMethod):
beam = [potential_next_beam[i] for i in best_indices]
return best_result
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return ["beam_width"]

View File

@@ -285,6 +285,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return [
"pop_size",

View File

@@ -28,6 +28,7 @@ class GreedyWordSwapWIR(SearchMethod):
Args:
wir_method: method for ranking most important words
model_wrapper: model wrapper used for gradient-based ranking
"""
def __init__(self, wir_method="unk"):
@@ -44,6 +45,7 @@ class GreedyWordSwapWIR(SearchMethod):
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
@@ -74,12 +76,30 @@ class GreedyWordSwapWIR(SearchMethod):
delta_ps.append(max_score_change)
index_scores = softmax_saliency_scores * np.array(delta_ps)
elif self.wir_method == "delete":
leave_one_texts = [
initial_text.delete_word_at_index(i) for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "gradient":
victim_model = self.get_model()
index_scores = np.zeros(initial_text.num_words)
grad_output = victim_model.get_grad(initial_text.tokenizer_input)
gradient = grad_output["gradient"]
word2token_mapping = initial_text.align_with_model_tokens(victim_model)
for i, word in enumerate(initial_text.words):
matched_tokens = word2token_mapping[word]
if not matched_tokens:
index_scores[i] = 0.0
else:
agg_grad = np.mean(gradient[matched_tokens], axis=0)
index_scores[i] = np.linalg.norm(agg_grad, ord=1)
search_over = False
elif self.wir_method == "random":
index_order = np.arange(len_text)
np.random.shuffle(index_order)
@@ -146,5 +166,12 @@ class GreedyWordSwapWIR(SearchMethod):
limited to word swap and deletion transformations."""
return transformation_consists_of_word_swaps_and_deletions(transformation)
@property
def is_black_box(self):
if self.wir_method == "gradient":
return False
else:
return True
def extra_repr_keys(self):
return ["wir_method"]

View File

@@ -329,6 +329,10 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]

View File

@@ -32,6 +32,12 @@ class SearchMethod(ABC):
raise AttributeError(
"Search Method must have access to filter_transformations method"
)
if not self.is_black_box and not hasattr(self, "get_model"):
raise AttributeError(
"Search Method must have access to get_model method if it is a white-box method"
)
return self._perform_search(initial_result)
@abstractmethod
@@ -48,6 +54,12 @@ class SearchMethod(ABC):
``transformation``."""
return True
@property
def is_black_box(self):
"""Returns `True` if search method does not require access to victim
model's internal states."""
raise NotImplementedError()
def extra_repr_keys(self):
return []

View File

@@ -67,7 +67,9 @@ class Attack:
self.transformation = transformation
if not self.transformation:
raise NameError("Cannot instantiate attack without transformation")
self.is_black_box = getattr(transformation, "is_black_box", True)
self.is_black_box = (
getattr(transformation, "is_black_box", True) and search_method.is_black_box
)
if not self.search_method.check_transformation_compatibility(
self.transformation
@@ -114,6 +116,8 @@ class Attack:
)
)
self.search_method.filter_transformations = self.filter_transformations
if not search_method.is_black_box:
self.search_method.get_model = lambda: self.goal_function.model
def clear_cache(self, recursive=True):
self.constraints_cache.clear()

View File

@@ -1,14 +1,9 @@
"""
.. _attacked_text:
""".. _attacked_text:
Attacked Text Class
=====================
A helper class that represents a string that can be attacked.
"""
from collections import OrderedDict
@@ -436,6 +431,40 @@ class AttackedText:
assert self.num_words == x.num_words
return float(np.sum(self.words != x.words)) / self.num_words
def align_with_model_tokens(self, model_wrapper):
"""Align AttackedText's `words` with target model's tokenization scheme
(e.g. word, character, subword). Specifically, we map each word to list
of indices of tokens that compose the word (e.g. embedding --> ["em",
"##bed", "##ding"])
Args:
model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model
Returns:
word2token_mapping (dict[str. list[int]]): Dictionary that maps word to list of indices.
"""
tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
word2token_mapping = {}
j = 0
last_matched = 0
for i, word in enumerate(self.words):
matched_tokens = []
while j < len(tokens) and len(word) > 0:
token = tokens[j].lower()
idx = word.find(token)
if idx == 0:
word = word[idx + len(token) :]
matched_tokens.append(j)
last_matched = j
j += 1
if not matched_tokens:
j = last_matched
else:
word2token_mapping[self.words[i]] = matched_tokens
return word2token_mapping
@property
def tokenizer_input(self):
"""The tuple of inputs to be passed to the tokenizer."""

View File

@@ -17,8 +17,6 @@ def html_style_from_dict(style_dict):
into
style: "color: red; height: 100px"
"""
style_str = ""
for key in style_dict:
@@ -100,14 +98,6 @@ def load_textattack_model_from_path(model_name, model_path):
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("bert"):
model_path, num_labels = model_path
textattack.shared.logger.info(
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
)
model = textattack.models.helpers.BERTForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("t5"):
model = textattack.models.helpers.T5ForTextToText(model_path)
else:

View File

@@ -82,7 +82,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
def validate_model_gradient_word_swap_compatibility(model):
"""Determines if ``model`` is task-compatible with
``radientBasedWordSwap``.
``GradientBasedWordSwap``.
We can only take the gradient with respect to an individual word if
the model uses a word-based tokenizer.

View File

@@ -1,13 +1,9 @@
"""
.. _transformations:
""".. _transformations:
Transformations
==========================
A transformation is a method which perturbs a text input through the insertion, deletion, and substitution of words, characters, and phrases. All transformations take a ``TokenizedText`` as input and return a list of ``TokenizedText`` that contains possible transformations. Every transformation is a subclass of the abstract ``Transformation`` class.
"""
from .transformation import Transformation

View File

@@ -34,25 +34,15 @@ class WordSwapGradientBased(Transformation):
# Make sure we know how to compute the gradient for this model.
validate_model_gradient_word_swap_compatibility(self.model)
# Make sure this model has all of the required properties.
if not hasattr(self.model, "word_embeddings"):
if not hasattr(self.model, "get_input_embeddings"):
raise ValueError(
"Model needs word embedding matrix for gradient-based word swap"
)
if not hasattr(self.model, "lookup_table"):
raise ValueError("Model needs lookup table for gradient-based word swap")
if not hasattr(self.model, "zero_grad"):
raise ValueError("Model needs `zero_grad()` for gradient-based word swap")
if not hasattr(self.tokenizer, "convert_id_to_word"):
if not hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id:
raise ValueError(
"Tokenizer needs `convert_id_to_word()` for gradient-based word swap"
"Tokenizer needs to have `pad_token_id` for gradient-based word swap"
)
if not hasattr(self.tokenizer, "pad_id"):
raise ValueError("Tokenizer needs `pad_id` for gradient-based word swap")
if not hasattr(self.tokenizer, "oov_id"):
raise ValueError("Tokenizer needs `oov_id` for gradient-based word swap")
self.loss = torch.nn.CrossEntropyLoss()
self.pad_id = self.model_wrapper.tokenizer.pad_id
self.oov_id = self.model_wrapper.tokenizer.oov_id
self.top_n = top_n
self.is_black_box = False
@@ -64,45 +54,28 @@ class WordSwapGradientBased(Transformation):
attacked_text (AttackedText): The full text input to perturb
word_index (int): index of the word to replace
"""
self.model.train()
self.model.emb_layer.embedding.weight.requires_grad = True
lookup_table = self.model.lookup_table.to(utils.device)
lookup_table_transpose = lookup_table.transpose(0, 1)
# get word IDs
text_ids = self.tokenizer.encode(attacked_text.tokenizer_input)
# set backward hook on the word embeddings for input x
emb_hook = Hook(self.model.word_embeddings, backward=True)
self.model.zero_grad()
predictions = self._call_model(text_ids)
original_label = predictions.argmax()
y_true = torch.Tensor([original_label]).long().to(utils.device)
loss = self.loss(predictions, y_true)
loss.backward()
# grad w.r.t to word embeddings
emb_grad = emb_hook.output[0].to(utils.device).squeeze()
lookup_table = self.model.get_input_embeddings().weight.data.cpu()
grad_output = self.model_wrapper.get_grad(attacked_text.tokenizer_input)
emb_grad = torch.tensor(grad_output["gradient"])
text_ids = grad_output["ids"]
# grad differences between all flips and original word (eq. 1 from paper)
vocab_size = lookup_table.size(0)
diffs = torch.zeros(len(indices_to_replace), vocab_size)
indices_to_replace = list(indices_to_replace)
for j, word_idx in enumerate(indices_to_replace):
# Make sure the word is in bounds.
if word_idx >= len(emb_grad):
continue
# Get the grad w.r.t the one-hot index of the word.
b_grads = (
emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
)
b_grads = lookup_table.mv(emb_grad[word_idx]).squeeze()
a_grad = b_grads[text_ids[word_idx]]
diffs[j] = b_grads - a_grad
# Don't change to the pad token.
diffs[:, self.tokenizer.pad_id] = float("-inf")
diffs[:, self.tokenizer.pad_token_id] = float("-inf")
# Find best indices within 2-d tensor by flattening.
word_idxs_sorted_by_grad = (-diffs).flatten().argsort()
@@ -121,17 +94,8 @@ class WordSwapGradientBased(Transformation):
if len(candidates) == self.top_n:
break
self.model.eval()
self.model.emb_layer.embedding.weight.requires_grad = (
self.model.emb_layer_trainable
)
return candidates
def _call_model(self, text_ids):
"""A helper function to query `self.model` with AttackedText `text`."""
model_input = torch.tensor([text_ids]).to(textattack.shared.utils.device)
return self.model(model_input)
def _get_transformations(self, attacked_text, indices_to_replace):
"""Returns a list of all possible transformations for `text`.
@@ -147,18 +111,3 @@ class WordSwapGradientBased(Transformation):
def extra_repr_keys(self):
return ["top_n"]
class Hook:
def __init__(self, module, backward=False):
if backward:
self.hook = module.register_backward_hook(self.hook_fn)
else:
self.hook = module.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output):
self.input = [x.to(utils.device) for x in input]
self.output = [x.to(utils.device) for x in output]
def close(self):
self.hook.remove()
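To make "eq. 1 from paper" concrete, a toy sketch with made-up shapes (vocab size 10, embedding dim 4):

import torch

lookup_table = torch.randn(10, 4)  # embedding matrix, one row per vocab word
emb_grad = torch.randn(3, 4)       # gradient per input token, from get_grad
text_ids = [7, 2, 5]               # ids of the current tokens

word_idx = 1                                   # score flips at position 1
b_grads = lookup_table.mv(emb_grad[word_idx])  # dot the grad with every vocab embedding
a_grad = b_grads[text_ids[word_idx]]           # first-order score of the current word
diffs = b_grads - a_grad                       # estimated loss change for each flip
print(diffs.argmax())  # candidate replacement that most increases the loss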

View File

@@ -21,7 +21,6 @@ class WordSwapInflections(WordSwap):
`Paper URL`_
.. _Paper URL: https://www.aclweb.org/anthology/2020.acl-main.263.pdf
"""
def __init__(self, **kwargs):