
Merge branch 'master' into lazy-loading

This commit is contained in:
Jin Yong Yoo
2020-11-10 19:21:17 -05:00
committed by GitHub
43 changed files with 527 additions and 136 deletions

View File

@@ -1,7 +1,7 @@
"""
TextAttack Command Args for Attack
------------------------------------------
TextAttack Command Args Designed for Attack
----------------------------------------------
"""

View File

@@ -312,9 +312,7 @@ def parse_model_from_args(args):
)
# Choose the appropriate model wrapper (based on whether or not this is
# a HuggingFace model).
if isinstance(
model, textattack.models.helpers.BERTForClassification
) or isinstance(model, textattack.models.helpers.T5ForTextToText):
if isinstance(model, textattack.models.helpers.T5ForTextToText):
model = textattack.models.wrappers.HuggingFaceModelWrapper(
model, model.tokenizer, batch_size=args.model_batch_size
)
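For reference, a minimal sketch of how a model ends up wrapped this way outside the CLI (the model name and setup here are illustrative assumptions, not part of this diff):

import textattack
import transformers

# Hypothetical setup: a HuggingFace sequence classification model is wrapped
# the same way the CLI code above wraps T5.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"  # placeholder model name
)
tokenizer = textattack.models.tokenizers.AutoTokenizer(
    "textattack/bert-base-uncased-imdb"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
    model, tokenizer, batch_size=32
)
print(model_wrapper(["this movie was great!"]))  # class scores per input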

View File

@@ -1,7 +1,7 @@
"""
TextAttack Command Arg Parsing
=====================================
TextAttack Command Arg Parsing Main Function
=============================================
"""
#!/usr/bin/env python
@@ -17,6 +17,16 @@ from textattack.commands.train_model import TrainModelCommand
def main():
"""This is the main command line parer and entry function to use TextAttack via command lines
texattack <command> [<args>]
Args:
command (string): augment, attack, train, eval-model, attack-resume, list, peek-dataset
[<args>] (string): depending on the command string
"""
parser = argparse.ArgumentParser(
"TextAttack CLI",
usage="[python -m] texattack <command> [<args>]",

View File

@@ -36,7 +36,7 @@ class GoalFunction(ABC):
model_cache_size=2 ** 20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model_wrapper.__class__
self.__class__, model_wrapper.model.__class__
)
self.model = model_wrapper
self.maximizable = maximizable

View File

@@ -9,7 +9,6 @@ from . import utils
from .glove_embedding_layer import GloveEmbeddingLayer
# Helper modules.
from .bert_for_classification import BERTForClassification
from .lstm_for_classification import LSTMForClassification
from .t5_for_text_to_text import T5ForTextToText
from .word_cnn_for_classification import WordCNNForClassification

View File

@@ -62,8 +62,6 @@ class LSTMForClassification(nn.Module):
def load_from_disk(self, model_path):
self.load_state_dict(load_cached_state_dict(model_path))
self.word_embeddings = self.emb_layer.embedding
self.lookup_table = self.emb_layer.embedding.weight.data
self.to(utils.device)
self.eval()
@@ -80,3 +78,6 @@ class LSTMForClassification(nn.Module):
output = self.drop(output)
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding

View File

@@ -57,3 +57,6 @@ class T5ForTextToText(torch.nn.Module):
)
# Convert ID tensor to string and return.
return [self.tokenizer.decode(ids) for ids in output_ids_list]
def get_input_embeddings(self):
return self.model.get_input_embeddings()

View File

@@ -65,6 +65,9 @@ class WordCNNForClassification(nn.Module):
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding
class CNNTextLayer(nn.Module):
def __init__(self, n_in, widths=[3, 4, 5], filters=100):
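The `get_input_embeddings` methods added above (to the LSTM, T5, and word-CNN models) give every built-in model a uniform hook point for the gradient code later in this diff. A minimal sketch of the convention on a toy model (assumed example, not TextAttack code):

import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=16, num_labels=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.out = nn.Linear(emb_dim, num_labels)

    def forward(self, ids):
        # Mean-pool token embeddings, then classify.
        return self.out(self.embedding(ids).mean(dim=1))

    def get_input_embeddings(self):
        # Same contract as the models above: return the nn.Embedding
        # that maps input ids to vectors.
        return self.embedding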

View File

@@ -86,3 +86,20 @@ class AutoTokenizer:
return list_of_dicts
else:
return [self.encode(input_text) for input_text in input_text_list]
def convert_ids_to_tokens(self, ids):
return self.tokenizer.convert_ids_to_tokens(ids)
@property
def pad_token_id(self):
if hasattr(self.tokenizer, "pad_token_id"):
return self.tokenizer.pad_token_id
else:
raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")
@property
def mask_token_id(self):
if hasattr(self.tokenizer, "mask_token_id"):
return self.tokenizer.mask_token_id
else:
raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")

View File

@@ -104,6 +104,9 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
"unicode_normalizer": unicode_normalizer,
}
self.unk_token = unk_token
self.pad_token = pad_token
super().__init__(tokenizer, parameters)
@@ -122,8 +125,8 @@ class GloveTokenizer(WordLevelTokenizer):
pad_token_id=pad_token_id,
lowercase=True,
)
self.pad_id = pad_token_id
self.oov_id = unk_token_id
self.pad_token_id = pad_token_id
self.oov_token_id = unk_token_id
self.convert_id_to_word = self.id_to_token
# Set defaults.
self.enable_padding(length=max_length, pad_id=pad_token_id)
@@ -156,3 +159,6 @@ class GloveTokenizer(WordLevelTokenizer):
add_special_tokens=False,
)
return [x.ids for x in encodings]
def convert_ids_to_tokens(self, ids):
return [self.convert_id_to_word(_id) for _id in ids]

View File

@@ -21,40 +21,110 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
self.tokenizer = tokenizer
self.batch_size = batch_size
def _model_predict(self, inputs):
"""Turn a list of dicts into a dict of lists.
Then make lists (values of dict) into tensors.
"""
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in inputs] for k in inputs[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
outputs = self.model(**input_dict)
if isinstance(outputs[0], str):
# HuggingFace sequence-to-sequence models return a list of
# string predictions as output. In this case, return the full
# list of outputs.
return outputs
else:
# HuggingFace classification models return a tuple as output
# where the first item in the tuple corresponds to the list of
# scores for each input.
return outputs[0]
def __call__(self, text_input_list):
"""Passes inputs to HuggingFace models as keyword arguments.
(Regular PyTorch ``nn.Module`` models typically take inputs as
positional arguments.)
"""
ids = self.tokenize(text_input_list)
def model_predict(inputs):
"""Turn a list of dicts into a dict of lists.
Then make lists (values of dict) into tensors.
"""
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in inputs] for k in inputs[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
outputs = self.model(**input_dict)
if isinstance(outputs[0], str):
# HuggingFace sequence-to-sequence models return a list of
# string predictions as output. In this case, return the full
# list of outputs.
return outputs
else:
# HuggingFace classification models return a tuple as output
# where the first item in the tuple corresponds to the list of
# scores for each input.
return outputs[0]
ids = self.encode(text_input_list)
with torch.no_grad():
outputs = textattack.shared.utils.batch_model_predict(
model_predict, ids, batch_size=self.batch_size
self._model_predict, ids, batch_size=self.batch_size
)
return outputs
def get_grad(self, text_input):
"""Get gradient of loss with respect to input tokens.
Args:
text_input (str): input string
Returns:
Dict of ids, tokens, and gradient as numpy array.
"""
if isinstance(self.model, textattack.models.helpers.T5ForTextToText):
raise NotImplementedError(
"`get_grads` for T5FotTextToText has not been implemented yet."
)
self.model.train()
embedding_layer = self.model.get_input_embeddings()
original_state = embedding_layer.weight.requires_grad
embedding_layer.weight.requires_grad = True
emb_grads = []
def grad_hook(module, grad_in, grad_out):
emb_grads.append(grad_out[0])
emb_hook = embedding_layer.register_backward_hook(grad_hook)
self.model.zero_grad()
model_device = next(self.model.parameters()).device
ids = self.encode([text_input])
predictions = self._model_predict(ids)
model_device = next(self.model.parameters()).device
input_dict = {k: [_dict[k] for _dict in ids] for k in ids[0]}
input_dict = {
k: torch.tensor(v).to(model_device) for k, v in input_dict.items()
}
try:
labels = predictions.argmax(dim=1)
loss = self.model(**input_dict, labels=labels)[0]
except TypeError:
raise TypeError(
f"{type(self.model)} class does not take in `labels` to calculate loss. "
"One cause for this might be if you instantiatedyour model using `transformer.AutoModel` "
"(instead of `transformers.AutoModelForSequenceClassification`)."
)
loss.backward()
# grad w.r.t to word embeddings
grad = emb_grads[0][0].cpu().numpy()
embedding_layer.weight.requires_grad = original_state
emb_hook.remove()
self.model.eval()
output = {"ids": ids[0]["input_ids"], "gradient": grad}
return output
def _tokenize(self, inputs):
"""Helper method that for `tokenize`
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
return [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x)["input_ids"])
for x in inputs
]
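The `get_grad` method above hinges on a backward hook that captures the gradient at the embedding layer's output. A standalone sketch of that pattern on toy modules (illustrative, not TextAttack API):

import torch

embedding = torch.nn.Embedding(10, 4)
classifier = torch.nn.Linear(4, 2)
emb_grads = []

def grad_hook(module, grad_in, grad_out):
    # grad_out[0] is d(loss)/d(embedding output)
    emb_grads.append(grad_out[0])

hook = embedding.register_backward_hook(grad_hook)

ids = torch.tensor([[1, 2, 3]])
logits = classifier(embedding(ids).mean(dim=1))
labels = logits.argmax(dim=1)  # use the model's own prediction as the label
loss = torch.nn.functional.cross_entropy(logits, labels)
loss.backward()

grad = emb_grads[0][0].cpu().numpy()  # shape: (seq_len, emb_dim)
hook.remove()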

View File

@@ -15,10 +15,49 @@ class ModelWrapper(ABC):
def __call__(self, text_list):
raise NotImplementedError()
def tokenize(self, inputs):
@abstractmethod
def get_grad(self, text_input):
"""Get gradient of loss with respect to input tokens."""
raise NotImplementedError()
def encode(self, inputs):
"""Helper method that calls ``tokenizer.batch_encode`` if possible, and
if not, falls back to calling ``tokenizer.encode`` for each input."""
if not, falls back to calling ``tokenizer.encode`` for each input.
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[int]]): List of list of ids
"""
if hasattr(self.tokenizer, "batch_encode"):
return self.tokenizer.batch_encode(inputs)
else:
return [self.tokenizer.encode(x) for x in inputs]
def _tokenize(self, inputs):
"""Helper method for `tokenize`"""
raise NotImplementedError()
def tokenize(self, inputs, strip_prefix=False):
"""Helper method that tokenizes input strings
Args:
inputs (list[str]): list of input strings
strip_prefix (bool): If `True`, we strip auxiliary characters added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for RoBERTa)
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
tokens = self._tokenize(inputs)
if strip_prefix:
# `strip_chars` are known auxiliary characters that are added to tokens
strip_chars = ["##", "Ġ", "__"]
# TODO: Find a better way to identify prefixes. These depend on the model, so cannot be resolved in ModelWrapper.
def strip(s, chars):
for c in chars:
s = s.replace(c, "")
return s
tokens = [[strip(t, strip_chars) for t in x] for x in tokens]
return tokens
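For intuition, the prefix stripping rewrites model-specific subword markers into bare strings (token values illustrative):

tokens = [["em", "##bed", "##ding", "layer"]]
strip_chars = ["##", "Ġ", "__"]

def strip(s, chars):
    for c in chars:
        s = s.replace(c, "")
    return s

print([[strip(t, strip_chars) for t in x] for x in tokens])
# [['em', 'bed', 'ding', 'layer']]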

View File

@@ -5,6 +5,7 @@ PyTorch Model Wrapper
import torch
from torch.nn import CrossEntropyLoss
import textattack
@@ -12,7 +13,13 @@ from .model_wrapper import ModelWrapper
class PyTorchModelWrapper(ModelWrapper):
"""Loads a PyTorch model (`nn.Module`) and tokenizer."""
"""Loads a PyTorch model (`nn.Module`) and tokenizer.
Args:
model (torch.nn.Module): PyTorch model
tokenizer: tokenizer whose output can be packed as a tensor and passed to the model.
No type requirement, but it must provide an `encode` method that accepts a string; `batch_encode` is used instead when available.
"""
def __init__(self, model, tokenizer, batch_size=32):
if not isinstance(model, torch.nn.Module):
@@ -26,7 +33,7 @@ class PyTorchModelWrapper(ModelWrapper):
def __call__(self, text_input_list):
model_device = next(self.model.parameters()).device
ids = self.tokenize(text_input_list)
ids = self.encode(text_input_list)
ids = torch.tensor(ids).to(model_device)
with torch.no_grad():
@@ -35,3 +42,67 @@ class PyTorchModelWrapper(ModelWrapper):
)
return outputs
def get_grad(self, text_input, loss_fn=CrossEntropyLoss()):
"""Get gradient of loss with respect to input tokens.
Args:
text_input (str): input string
loss_fn (torch.nn.Module): loss function. Default is `torch.nn.CrossEntropyLoss`
Returns:
Dict of ids, tokens, and gradient as numpy array.
"""
if not hasattr(self.model, "get_input_embeddings"):
raise AttributeError(
f"{type(self.model)} must have method `get_input_embeddings` that returns `torch.nn.Embedding` object that represents input embedding layer"
)
if not isinstance(loss_fn, torch.nn.Module):
raise ValueError("Loss function must be of type `torch.nn.Module`.")
self.model.train()
embedding_layer = self.model.get_input_embeddings()
original_state = embedding_layer.weight.requires_grad
embedding_layer.weight.requires_grad = True
emb_grads = []
def grad_hook(module, grad_in, grad_out):
emb_grads.append(grad_out[0])
emb_hook = embedding_layer.register_backward_hook(grad_hook)
self.model.zero_grad()
model_device = next(self.model.parameters()).device
ids = self.encode([text_input])
ids = torch.tensor(ids).to(model_device)
predictions = self.model(ids)
output = predictions.argmax(dim=1)
loss = loss_fn(predictions, output)
loss.backward()
# grad w.r.t to word embeddings
grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()
embedding_layer.weight.requires_grad = original_state
emb_hook.remove()
self.model.eval()
output = {"ids": ids[0].tolist(), "gradient": grad}
return output
def _tokenize(self, inputs):
"""Helper method that for `tokenize`
Args:
inputs (list[str]): list of input strings
Returns:
tokens (list[list[str]]): List of list of tokens as strings
"""
return [
self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x))
for x in inputs
]
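A self-contained sketch of `get_grad` in use, with toy stand-ins for the model and tokenizer (both hypothetical; the toy model embeds sequence-first, matching the transpose in `get_grad` above):

import torch.nn as nn
from textattack.models.wrappers import PyTorchModelWrapper

class ToyTokenizer:
    vocab = {"<pad>": 0, "the": 1, "movie": 2, "was": 3, "great": 4}

    def encode(self, text):
        return [self.vocab.get(w, 0) for w in text.lower().split()]

    def convert_ids_to_tokens(self, ids):
        rev = {i: w for w, i in self.vocab.items()}
        return [rev[i] for i in ids]

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(5, 8)
        self.out = nn.Linear(8, 2)

    def forward(self, ids):
        emb = self.embedding(ids.t())     # (seq_len, batch, emb_dim)
        return self.out(emb.mean(dim=0))  # (batch, num_labels)

    def get_input_embeddings(self):
        return self.embedding

wrapper = PyTorchModelWrapper(ToyModel(), ToyTokenizer())
grad_output = wrapper.get_grad("the movie was great")
print(grad_output["ids"])             # [1, 2, 3, 4]
print(grad_output["gradient"].shape)  # (4, 8): one row per input token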

View File

@@ -27,3 +27,6 @@ class SklearnModelWrapper(ModelWrapper):
encoded_text_matrix, columns=self.tokenizer.get_feature_names()
)
return self.model.predict_proba(tokenized_text_df)
def get_grad(self, text_input):
raise NotImplementedError()

View File

@@ -27,3 +27,6 @@ class TensorFlowModelWrapper(ModelWrapper):
text_array = np.array(text_input_list)
preds = self.model(text_array)
return preds.numpy()
def get_grad(self, text_input):
raise NotImplementedError()

View File

@@ -49,5 +49,9 @@ class BeamSearch(SearchMethod):
beam = [potential_next_beam[i] for i in best_indices]
return best_result
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return ["beam_width"]

View File

@@ -285,6 +285,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return [
"pop_size",

View File

@@ -28,6 +28,7 @@ class GreedyWordSwapWIR(SearchMethod):
Args:
wir_method: method for ranking most important words
model_wrapper: model wrapper used for gradient-based ranking
"""
def __init__(self, wir_method="unk"):
@@ -44,6 +45,7 @@ class GreedyWordSwapWIR(SearchMethod):
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
@@ -74,12 +76,30 @@ class GreedyWordSwapWIR(SearchMethod):
delta_ps.append(max_score_change)
index_scores = softmax_saliency_scores * np.array(delta_ps)
elif self.wir_method == "delete":
leave_one_texts = [
initial_text.delete_word_at_index(i) for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "gradient":
victim_model = self.get_model()
index_scores = np.zeros(initial_text.num_words)
grad_output = victim_model.get_grad(initial_text.tokenizer_input)
gradient = grad_output["gradient"]
word2token_mapping = initial_text.align_with_model_tokens(victim_model)
for i, word in enumerate(initial_text.words):
matched_tokens = word2token_mapping[word]
if not matched_tokens:
index_scores[i] = 0.0
else:
agg_grad = np.mean(gradient[matched_tokens], axis=0)
index_scores[i] = np.linalg.norm(agg_grad, ord=1)
search_over = False
elif self.wir_method == "random":
index_order = np.arange(len_text)
np.random.shuffle(index_order)
@@ -146,5 +166,12 @@ class GreedyWordSwapWIR(SearchMethod):
limited to word swap and deletion transformations."""
return transformation_consists_of_word_swaps_and_deletions(transformation)
@property
def is_black_box(self):
if self.wir_method == "gradient":
return False
else:
return True
def extra_repr_keys(self):
return ["wir_method"]

View File

@@ -329,6 +329,10 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_black_box(self):
return True
def extra_repr_keys(self):
return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]

View File

@@ -32,6 +32,12 @@ class SearchMethod(ABC):
raise AttributeError(
"Search Method must have access to filter_transformations method"
)
if not self.is_black_box and not hasattr(self, "get_model"):
raise AttributeError(
"Search Method must have access to get_model method if it is a white-box method"
)
return self._perform_search(initial_result)
@abstractmethod
@@ -48,6 +54,12 @@ class SearchMethod(ABC):
``transformation``."""
return True
@property
def is_black_box(self):
"""Returns `True` if search method does not require access to victim
model's internal states."""
raise NotImplementedError()
def extra_repr_keys(self):
return []
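The contract for subclasses, as a minimal assumed example:

class MyGreedySearch(SearchMethod):
    def _perform_search(self, initial_result):
        ...  # query self.get_goal_results(...) on candidate texts

    @property
    def is_black_box(self):
        # This method only needs model outputs, never gradients, so it is
        # black-box. Gradient-based methods return False and gain access
        # to the victim model through get_model (wired up by Attack).
        return True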

View File

@@ -67,7 +67,9 @@ class Attack:
self.transformation = transformation
if not self.transformation:
raise NameError("Cannot instantiate attack without transformation")
self.is_black_box = getattr(transformation, "is_black_box", True)
self.is_black_box = (
getattr(transformation, "is_black_box", True) and search_method.is_black_box
)
if not self.search_method.check_transformation_compatibility(
self.transformation
@@ -114,6 +116,8 @@ class Attack:
)
)
self.search_method.filter_transformations = self.filter_transformations
if not search_method.is_black_box:
self.search_method.get_model = lambda: self.goal_function.model
def clear_cache(self, recursive=True):
self.constraints_cache.clear()

View File

@@ -431,6 +431,40 @@ class AttackedText:
assert self.num_words == x.num_words
return float(np.sum(self.words != x.words)) / self.num_words
def align_with_model_tokens(self, model_wrapper):
"""Align AttackedText's `words` with target model's tokenization scheme
(e.g. word, character, subword). Specifically, we map each word to the
list of indices of the tokens that compose it (e.g. "embedding" -->
["em", "##bed", "##ding"]).
Args:
model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model
Returns:
word2token_mapping (dict[str, list[int]]): Dictionary that maps each word to the list of indices of its tokens.
"""
tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
word2token_mapping = {}
j = 0
last_matched = 0
for i, word in enumerate(self.words):
matched_tokens = []
while j < len(tokens) and len(word) > 0:
token = tokens[j].lower()
idx = word.find(token)
if idx == 0:
word = word[idx + len(token) :]
matched_tokens.append(j)
last_matched = j
j += 1
if not matched_tokens:
j = last_matched
else:
word2token_mapping[self.words[i]] = matched_tokens
return word2token_mapping
@property
def tokenizer_input(self):
"""The tuple of inputs to be passed to the tokenizer."""

View File

@@ -98,14 +98,6 @@ def load_textattack_model_from_path(model_name, model_path):
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("bert"):
model_path, num_labels = model_path
textattack.shared.logger.info(
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
)
model = textattack.models.helpers.BERTForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("t5"):
model = textattack.models.helpers.T5ForTextToText(model_path)
else:

View File

@@ -35,8 +35,8 @@ def words_from_text(s, words_to_ignore=[]):
for c in " ".join(s.split()):
if c.isalnum():
word += c
elif c in "'-" and len(word) > 0:
# Allow apostrophes and hyphens as long as they don't begin the
elif c in "'-_*@" and len(word) > 0:
# Allow apostrophes, hyphens, underscores, asterisks, and at signs as long as they don't begin the
# word.
word += c
elif word:
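The effect of the widened character set, assuming `words_from_text` is importable from `textattack.shared.utils`:

from textattack.shared.utils import words_from_text

print(words_from_text("don't re-enter my_var rate*time"))
# ["don't", 're-enter', 'my_var', 'rate*time']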

View File

@@ -20,12 +20,12 @@ from . import logger
# A list of goal functions and the corresponding available models.
MODELS_BY_GOAL_FUNCTIONS = {
(TargetedClassification, UntargetedClassification, InputReduction): [
r"^textattack.models.lstm_for_classification.*",
r"^textattack.models.helpers.lstm_for_classification.*",
r"^textattack.models.helpers.word_cnn_for_classification.*",
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
],
(NonOverlappingOutput, MinimizeBleu,): [
r"^textattack.models.translation.*",
r"^textattack.models.summarization.*",
r"^textattack.models.helpers.t5_for_text_to_text.*",
],
}
@@ -51,7 +51,8 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
try:
matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
except KeyError:
raise ValueError(f"No entry found for goal function {goal_function_class}.")
matching_model_globs = []
logger.warn(f"No entry found for goal function {goal_function_class}.")
# Get options for this goal function.
# model_module = model_class.__module__
model_module_path = ".".join((model_class.__module__, model_class.__name__))
@@ -61,28 +62,28 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
logger.info(
f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
)
return True
return
# If we got here, the model does not match the intended goal function.
for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
for glob in globs:
if re.match(glob, model_module_path):
raise ValueError(
logger.warn(
f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
" Found match with other goal functions: {goal_functions}."
f" Found match with other goal functions: {goal_functions}."
)
# If it matches another goal function, throw an error.
return
# If it matches another goal function, warn user.
# Otherwise, this is an unknown model, perhaps user-provided, or we forgot to
# update the corresponding dictionary. Warn user and return.
logger.warn(
f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
)
return True
def validate_model_gradient_word_swap_compatibility(model):
"""Determines if ``model`` is task-compatible with
``radientBasedWordSwap``.
``GradientBasedWordSwap``.
We can only take the gradient with respect to an individual word if
the model uses a word-based tokenizer.

View File

@@ -52,13 +52,13 @@ class WordSwapChangeName(WordSwap):
replacement_words = []
tag = word_part_of_speech
if (
tag.value == "B-PER"
tag.value in ("B-PER", "S-PER")
and tag.score >= self.confidence_score
and not self.last_only
):
replacement_words = self._get_firstname(word)
elif (
tag.value == "E-PER"
tag.value in ("E-PER", "S-PER")
and tag.score >= self.confidence_score
and not self.first_only
):

View File

@@ -34,25 +34,15 @@ class WordSwapGradientBased(Transformation):
# Make sure we know how to compute the gradient for this model.
validate_model_gradient_word_swap_compatibility(self.model)
# Make sure this model has all of the required properties.
if not hasattr(self.model, "word_embeddings"):
if not hasattr(self.model, "get_input_embeddings"):
raise ValueError(
"Model needs word embedding matrix for gradient-based word swap"
)
if not hasattr(self.model, "lookup_table"):
raise ValueError("Model needs lookup table for gradient-based word swap")
if not hasattr(self.model, "zero_grad"):
raise ValueError("Model needs `zero_grad()` for gradient-based word swap")
if not hasattr(self.tokenizer, "convert_id_to_word"):
if not hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id:
raise ValueError(
"Tokenizer needs `convert_id_to_word()` for gradient-based word swap"
"Tokenizer needs to have `pad_token_id` for gradient-based word swap"
)
if not hasattr(self.tokenizer, "pad_id"):
raise ValueError("Tokenizer needs `pad_id` for gradient-based word swap")
if not hasattr(self.tokenizer, "oov_id"):
raise ValueError("Tokenizer needs `oov_id` for gradient-based word swap")
self.loss = torch.nn.CrossEntropyLoss()
self.pad_id = self.model_wrapper.tokenizer.pad_id
self.oov_id = self.model_wrapper.tokenizer.oov_id
self.top_n = top_n
self.is_black_box = False
@@ -64,45 +54,28 @@ class WordSwapGradientBased(Transformation):
attacked_text (AttackedText): The full text input to perturb
word_index (int): index of the word to replace
"""
self.model.train()
self.model.emb_layer.embedding.weight.requires_grad = True
lookup_table = self.model.lookup_table.to(utils.device)
lookup_table_transpose = lookup_table.transpose(0, 1)
# get word IDs
text_ids = self.tokenizer.encode(attacked_text.tokenizer_input)
# set backward hook on the word embeddings for input x
emb_hook = Hook(self.model.word_embeddings, backward=True)
self.model.zero_grad()
predictions = self._call_model(text_ids)
original_label = predictions.argmax()
y_true = torch.Tensor([original_label]).long().to(utils.device)
loss = self.loss(predictions, y_true)
loss.backward()
# grad w.r.t to word embeddings
emb_grad = emb_hook.output[0].to(utils.device).squeeze()
lookup_table = self.model.get_input_embeddings().weight.data.cpu()
grad_output = self.model_wrapper.get_grad(attacked_text.tokenizer_input)
emb_grad = torch.tensor(grad_output["gradient"])
text_ids = grad_output["ids"]
# grad differences between all flips and original word (eq. 1 from paper)
vocab_size = lookup_table.size(0)
diffs = torch.zeros(len(indices_to_replace), vocab_size)
indices_to_replace = list(indices_to_replace)
for j, word_idx in enumerate(indices_to_replace):
# Make sure the word is in bounds.
if word_idx >= len(emb_grad):
continue
# Get the grad w.r.t the one-hot index of the word.
b_grads = (
emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
)
b_grads = lookup_table.mv(emb_grad[word_idx]).squeeze()
a_grad = b_grads[text_ids[word_idx]]
diffs[j] = b_grads - a_grad
# Don't change to the pad token.
diffs[:, self.tokenizer.pad_id] = float("-inf")
diffs[:, self.tokenizer.pad_token_id] = float("-inf")
# Find best indices within 2-d tensor by flattening.
word_idxs_sorted_by_grad = (-diffs).flatten().argsort()
@@ -121,17 +94,8 @@ class WordSwapGradientBased(Transformation):
if len(candidates) == self.top_n:
break
self.model.eval()
self.model.emb_layer.embedding.weight.requires_grad = (
self.model.emb_layer_trainable
)
return candidates
def _call_model(self, text_ids):
"""A helper function to query `self.model` with AttackedText `text`."""
model_input = torch.tensor([text_ids]).to(textattack.shared.utils.device)
return self.model(model_input)
def _get_transformations(self, attacked_text, indices_to_replace):
"""Returns a list of all possible transformations for `text`.
@@ -147,18 +111,3 @@ class WordSwapGradientBased(Transformation):
def extra_repr_keys(self):
return ["top_n"]
class Hook:
def __init__(self, module, backward=False):
if backward:
self.hook = module.register_backward_hook(self.hook_fn)
else:
self.hook = module.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output):
self.input = [x.to(utils.device) for x in input]
self.output = [x.to(utils.device) for x in output]
def close(self):
self.hook.remove()
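For reference, a toy version of the gradient-difference score computed above (the "eq. 1" in the comment): the first-order estimate of the loss change from swapping position j's word to vocabulary word w is grad_j · (E[w] - E[current]). All tensors below are random placeholders:

import torch

vocab_size, emb_dim, seq_len = 8, 4, 3
lookup_table = torch.randn(vocab_size, emb_dim)  # embedding matrix E
emb_grad = torch.randn(seq_len, emb_dim)         # loss gradient per position
text_ids = [2, 5, 1]                             # current token ids

word_idx = 0
b_grads = lookup_table.mv(emb_grad[word_idx])    # grad_j . E[w] for every w
a_grad = b_grads[text_ids[word_idx]]             # grad_j . E[current word]
diffs = b_grads - a_grad                         # first-order score per swap
print(diffs.argsort(descending=True)[:3])        # top-3 candidate token ids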