Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
make WordSwapGradientBased work with get_grad
@@ -312,9 +312,7 @@ def parse_model_from_args(args):
         )
     # Choose the approprate model wrapper (based on whether or not this is
     # a HuggingFace model).
-    if isinstance(
-        model, textattack.models.helpers.BERTForClassification
-    ) or isinstance(model, textattack.models.helpers.T5ForTextToText):
+    if isinstance(model, textattack.models.helpers.T5ForTextToText):
         model = textattack.models.wrappers.HuggingFaceModelWrapper(
             model, model.tokenizer, batch_size=args.model_batch_size
         )
@@ -31,7 +31,7 @@ def run(args, checkpoint=None):
         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     # Disable tensorflow logs, except in the case of an error.
     if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
-        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"
     # Fix TensorFlow GPU memory growth
     import tensorflow as tf
 
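
A note on the log-level change above: TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ logging backend the first time the library is imported, so the assignment only takes effect because it happens before `import tensorflow as tf`. A minimal sketch of that ordering (the documented levels are "0" through "3"; "0" shows everything, "3" silences INFO, WARNING, and ERROR):

    import os

    # Must be set before TensorFlow is first imported; the C++ backend reads it
    # at import time.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    import tensorflow as tf  # noqa: E402  (import intentionally after the env var)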
@@ -86,3 +86,20 @@ class AutoTokenizer:
             return list_of_dicts
         else:
             return [self.encode(input_text) for input_text in input_text_list]
+
+    def convert_ids_to_tokens(self, ids):
+        return self.tokenizer.convert_ids_to_tokens(ids)
+
+    @property
+    def pad_token_id(self):
+        if hasattr(self.tokenizer, "pad_token_id"):
+            return self.tokenizer.pad_token_id
+        else:
+            raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")
+
+    @property
+    def mask_token_id(self):
+        if hasattr(self.tokenizer, "mask_token_id"):
+            return self.tokenizer.mask_token_id
+        else:
+            raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")
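
The three additions to `AutoTokenizer` are thin pass-throughs to the wrapped `transformers` tokenizer; they exist so the gradient-based word swap further down can treat any tokenizer uniformly. A rough usage sketch, assuming the class is constructed from a pretrained model name as elsewhere in the library (the name and example ids are illustrative):

    from textattack.models.tokenizers import AutoTokenizer

    tokenizer = AutoTokenizer("bert-base-uncased")

    pad_id = tokenizer.pad_token_id    # pass-through to the underlying HF tokenizer
    mask_id = tokenizer.mask_token_id  # AttributeError if the wrapped tokenizer lacks it
    tokens = tokenizer.convert_ids_to_tokens([101, 7592, 102])  # e.g. ['[CLS]', 'hello', '[SEP]']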
@@ -104,6 +104,9 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
             "unicode_normalizer": unicode_normalizer,
         }
 
+        self.unk_token = unk_token
+        self.pad_token = pad_token
+
         super().__init__(tokenizer, parameters)
 
 
@@ -122,8 +125,8 @@ class GloveTokenizer(WordLevelTokenizer):
             pad_token_id=pad_token_id,
             lowercase=True,
         )
-        self.pad_id = pad_token_id
-        self.oov_id = unk_token_id
+        self.pad_token_id = pad_token_id
+        self.oov_token_id = unk_token_id
         self.convert_id_to_word = self.id_to_token
         # Set defaults.
         self.enable_padding(length=max_length, pad_id=pad_token_id)
@@ -156,3 +159,6 @@ class GloveTokenizer(WordLevelTokenizer):
             add_special_tokens=False,
         )
         return [x.ids for x in encodings]
+
+    def convert_ids_to_tokens(self, ids):
+        return [self.convert_id_to_word(_id) for _id in ids]
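
Together with the rename of `pad_id`/`oov_id` to `pad_token_id`/`oov_token_id`, the new `convert_ids_to_tokens` gives `GloveTokenizer` the same small surface the attack code expects from HuggingFace tokenizers. A sketch of the round trip, assuming constructor keywords `word_id_map`, `unk_token_id`, `pad_token_id`, and `max_length` (the map and ids below are made up):

    from textattack.models.tokenizers import GloveTokenizer

    word_id_map = {"the": 2, "movie": 3, "was": 4, "great": 5}
    tokenizer = GloveTokenizer(
        word_id_map=word_id_map, unk_token_id=0, pad_token_id=1, max_length=8
    )

    ids = tokenizer.encode("the movie was great")    # padded to max_length with id 1
    words = tokenizer.convert_ids_to_tokens(ids)     # back to word strings via id_to_token

    assert tokenizer.pad_token_id == 1   # formerly tokenizer.pad_id
    assert tokenizer.oov_token_id == 0   # formerly tokenizer.oov_id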
@@ -82,7 +82,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
 
 def validate_model_gradient_word_swap_compatibility(model):
     """Determines if ``model`` is task-compatible with
-    ``radientBasedWordSwap``.
+    ``GradientBasedWordSwap``.
 
     We can only take the gradient with respect to an individual word if
     the model uses a word-based tokenizer.
@@ -34,25 +34,15 @@ class WordSwapGradientBased(Transformation):
         # Make sure we know how to compute the gradient for this model.
         validate_model_gradient_word_swap_compatibility(self.model)
         # Make sure this model has all of the required properties.
-        if not hasattr(self.model, "word_embeddings"):
+        if not hasattr(self.model, "get_input_embeddings"):
             raise ValueError(
                 "Model needs word embedding matrix for gradient-based word swap"
             )
-        if not hasattr(self.model, "lookup_table"):
-            raise ValueError("Model needs lookup table for gradient-based word swap")
-        if not hasattr(self.model, "zero_grad"):
-            raise ValueError("Model needs `zero_grad()` for gradient-based word swap")
-        if not hasattr(self.tokenizer, "convert_id_to_word"):
+        if not hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id:
             raise ValueError(
-                "Tokenizer needs `convert_id_to_word()` for gradient-based word swap"
+                "Tokenizer needs to have `pad_token_id` for gradient-based word swap"
             )
-        if not hasattr(self.tokenizer, "pad_id"):
-            raise ValueError("Tokenizer needs `pad_id` for gradient-based word swap")
-        if not hasattr(self.tokenizer, "oov_id"):
-            raise ValueError("Tokenizer needs `oov_id` for gradient-based word swap")
         self.loss = torch.nn.CrossEntropyLoss()
-        self.pad_id = self.model_wrapper.tokenizer.pad_id
-        self.oov_id = self.model_wrapper.tokenizer.oov_id
 
         self.top_n = top_n
         self.is_black_box = False
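
After this change the constructor duck-types against the wrapper's contents: any model exposing `get_input_embeddings()` (as HuggingFace PyTorch models do) and any tokenizer exposing `pad_token_id` is accepted, in place of the old word-level-only `lookup_table`/`pad_id`/`oov_id` attributes. A hedged wiring sketch (the checkpoint name is illustrative, and whether the wrapper takes a raw `transformers` tokenizer or TextAttack's `AutoTokenizer` shim may vary by version):

    import transformers
    from textattack.models.tokenizers import AutoTokenizer
    from textattack.models.wrappers import HuggingFaceModelWrapper
    from textattack.transformations import WordSwapGradientBased

    name = "textattack/bert-base-uncased-SST-2"  # illustrative checkpoint
    model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = AutoTokenizer(name)
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

    # Satisfies the new checks: model.get_input_embeddings() exists and
    # tokenizer.pad_token_id is defined (added to AutoTokenizer in this commit).
    transformation = WordSwapGradientBased(model_wrapper, top_n=5)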
@@ -64,45 +54,28 @@ class WordSwapGradientBased(Transformation):
             attacked_text (AttackedText): The full text input to perturb
             word_index (int): index of the word to replace
         """
-        self.model.train()
-        self.model.emb_layer.embedding.weight.requires_grad = True
-
-        lookup_table = self.model.lookup_table.to(utils.device)
-        lookup_table_transpose = lookup_table.transpose(0, 1)
-
-        # get word IDs
-        text_ids = self.tokenizer.encode(attacked_text.tokenizer_input)
-
-        # set backward hook on the word embeddings for input x
-        emb_hook = Hook(self.model.word_embeddings, backward=True)
-
-        self.model.zero_grad()
-        predictions = self._call_model(text_ids)
-        original_label = predictions.argmax()
-        y_true = torch.Tensor([original_label]).long().to(utils.device)
-        loss = self.loss(predictions, y_true)
-        loss.backward()
 
-        # grad w.r.t to word embeddings
-        emb_grad = emb_hook.output[0].to(utils.device).squeeze()
+        lookup_table = self.model.get_input_embeddings().weight.data.cpu()
+
+        grad_output = self.model_wrapper.get_grad(attacked_text.tokenizer_input)
+        emb_grad = torch.tensor(grad_output["gradient"])
+        text_ids = grad_output["ids"]
         # grad differences between all flips and original word (eq. 1 from paper)
         vocab_size = lookup_table.size(0)
         diffs = torch.zeros(len(indices_to_replace), vocab_size)
         indices_to_replace = list(indices_to_replace)
 
         for j, word_idx in enumerate(indices_to_replace):
             # Make sure the word is in bounds.
             if word_idx >= len(emb_grad):
                 continue
             # Get the grad w.r.t the one-hot index of the word.
-            b_grads = (
-                emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
-            )
+            b_grads = lookup_table.mv(emb_grad[word_idx]).squeeze()
             a_grad = b_grads[text_ids[word_idx]]
             diffs[j] = b_grads - a_grad
 
         # Don't change to the pad token.
-        diffs[:, self.tokenizer.pad_id] = float("-inf")
+        diffs[:, self.tokenizer.pad_token_id] = float("-inf")
 
         # Find best indices within 2-d tensor by flattening.
         word_idxs_sorted_by_grad = (-diffs).flatten().argsort()
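
The loop kept above is the first-order (HotFlip-style) estimate referenced by the "eq. 1" comment: for each candidate position i and every vocabulary word w, the change in loss from swapping the original word for w is approximated by the dot product of the embedding gradient at i with (e_w - e_orig). A standalone restatement of the same computation, consuming the {"gradient", "ids"} dictionary shape used above (the helper name is made up):

    import torch

    def first_order_swap_scores(embedding_matrix, emb_grad, text_ids, positions):
        """score[j, w] ~ grad_i . (e_w - e_orig_i) for each candidate position i."""
        vocab_size = embedding_matrix.size(0)
        scores = torch.zeros(len(positions), vocab_size)
        for j, i in enumerate(positions):
            if i >= len(emb_grad):
                continue  # position falls outside the tokenized sequence
            b_grads = embedding_matrix.mv(emb_grad[i])  # grad_i . e_w for every w
            a_grad = b_grads[text_ids[i]]               # grad_i . e_orig_i
            scores[j] = b_grads - a_grad
        return scores

    # With the wrapper output used above:
    #   grad_output = model_wrapper.get_grad(attacked_text.tokenizer_input)
    #   scores = first_order_swap_scores(
    #       model.get_input_embeddings().weight.data.cpu(),
    #       torch.tensor(grad_output["gradient"]),
    #       grad_output["ids"],
    #       list(indices_to_replace),
    #   )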
@@ -121,17 +94,8 @@ class WordSwapGradientBased(Transformation):
             if len(candidates) == self.top_n:
                 break
 
-        self.model.eval()
-        self.model.emb_layer.embedding.weight.requires_grad = (
-            self.model.emb_layer_trainable
-        )
         return candidates
 
-    def _call_model(self, text_ids):
-        """A helper function to query `self.model` with AttackedText `text`."""
-        model_input = torch.tensor([text_ids]).to(textattack.shared.utils.device)
-        return self.model(model_input)
-
     def _get_transformations(self, attacked_text, indices_to_replace):
         """Returns a list of all possible transformations for `text`.
 
@@ -147,18 +111,3 @@ class WordSwapGradientBased(Transformation):
 
     def extra_repr_keys(self):
         return ["top_n"]
-
-
-class Hook:
-    def __init__(self, module, backward=False):
-        if backward:
-            self.hook = module.register_backward_hook(self.hook_fn)
-        else:
-            self.hook = module.register_forward_hook(self.hook_fn)
-
-    def hook_fn(self, module, input, output):
-        self.input = [x.to(utils.device) for x in input]
-        self.output = [x.to(utils.device) for x in output]
-
-    def close(self):
-        self.hook.remove()
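
The deleted `_call_model` helper and `Hook` class were the transformation's own forward/backward pass and embedding-gradient capture; after this commit that work is expected to live behind `model_wrapper.get_grad`, whose implementation is not part of this diff. As a rough reconstruction from the deleted code, a wrapper-side `get_grad` might look something like the sketch below (the function name and the {"ids", "gradient"} return keys follow the usage in the hunks above; everything else is an assumption, and real HuggingFace models return output objects rather than bare logits):

    import torch

    def get_grad(model, tokenizer, text, device="cpu"):
        """Sketch: gradient of the loss w.r.t. the input's word embeddings."""
        model.train()
        embedding_layer = model.get_input_embeddings()

        # Capture the gradient flowing back into the embedding output,
        # as the deleted Hook class did with register_backward_hook.
        captured = {}

        def backward_hook(module, grad_in, grad_out):
            captured["gradient"] = grad_out[0]

        hook = embedding_layer.register_backward_hook(backward_hook)

        ids = tokenizer.encode(text)
        ids_tensor = torch.tensor([ids], device=device)

        model.zero_grad()
        logits = model(ids_tensor)  # assumes the model returns raw logits
        label = logits.argmax(dim=1)
        loss = torch.nn.CrossEntropyLoss()(logits, label)
        loss.backward()

        hook.remove()
        model.eval()
        return {"ids": ids, "gradient": captured["gradient"][0].detach().cpu().numpy()}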