mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

make WordSwapGradientBased work with get_grad

Author: Jin Yong Yoo
Date: 2020-10-05 18:45:27 -04:00
Parent: 3ffd776b63
Commit: 6273b19c19
7 changed files with 43 additions and 73 deletions

View File

@@ -312,9 +312,7 @@ def parse_model_from_args(args):
         )
     # Choose the approprate model wrapper (based on whether or not this is
     # a HuggingFace model).
-    if isinstance(
-        model, textattack.models.helpers.BERTForClassification
-    ) or isinstance(model, textattack.models.helpers.T5ForTextToText):
+    if isinstance(model, textattack.models.helpers.T5ForTextToText):
         model = textattack.models.wrappers.HuggingFaceModelWrapper(
             model, model.tokenizer, batch_size=args.model_batch_size
         )
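For reference, the wrapper constructed in this branch can also be built by hand. A minimal sketch, assuming a HuggingFace sequence-classification checkpoint; the model name and batch size below are illustrative and not part of this commit:

import transformers
import textattack

# Any HuggingFace sequence-classification checkpoint works; this one is just an example.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
    model, tokenizer, batch_size=32
)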

View File

@@ -31,7 +31,7 @@ def run(args, checkpoint=None):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Disable tensorflow logs, except in the case of an error.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"
# Fix TensorFlow GPU memory growth
import tensorflow as tf

View File

@@ -86,3 +86,20 @@ class AutoTokenizer:
             return list_of_dicts
         else:
             return [self.encode(input_text) for input_text in input_text_list]
+
+    def convert_ids_to_tokens(self, ids):
+        return self.tokenizer.convert_ids_to_tokens(ids)
+
+    @property
+    def pad_token_id(self):
+        if hasattr(self.tokenizer, "pad_token_id"):
+            return self.tokenizer.pad_token_id
+        else:
+            raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")
+
+    @property
+    def mask_token_id(self):
+        if hasattr(self.tokenizer, "mask_token_id"):
+            return self.tokenizer.mask_token_id
+        else:
+            raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")
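
A rough usage sketch of the new accessors, assuming the wrapped tokenizer is `bert-base-uncased`; the ids and tokens in the comments are those of BERT's WordPiece vocabulary, and the constructor call assumes the class's usual single `name` argument:

from textattack.models.tokenizers import AutoTokenizer

tokenizer = AutoTokenizer("bert-base-uncased")
print(tokenizer.convert_ids_to_tokens([101, 7592, 102]))  # ['[CLS]', 'hello', '[SEP]']
print(tokenizer.pad_token_id)   # 0
print(tokenizer.mask_token_id)  # 103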

View File

@@ -104,6 +104,9 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
"unicode_normalizer": unicode_normalizer,
}
self.unk_token = unk_token
self.pad_token = pad_token
super().__init__(tokenizer, parameters)
@@ -122,8 +125,8 @@ class GloveTokenizer(WordLevelTokenizer):
             pad_token_id=pad_token_id,
             lowercase=True,
         )
-        self.pad_id = pad_token_id
-        self.oov_id = unk_token_id
+        self.pad_token_id = pad_token_id
+        self.oov_token_id = unk_token_id
         self.convert_id_to_word = self.id_to_token
         # Set defaults.
         self.enable_padding(length=max_length, pad_id=pad_token_id)
@@ -156,3 +159,6 @@ class GloveTokenizer(WordLevelTokenizer):
             add_special_tokens=False,
         )
         return [x.ids for x in encodings]
+
+    def convert_ids_to_tokens(self, ids):
+        return [self.convert_id_to_word(_id) for _id in ids]
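
A small sketch of the renamed attributes and the new helper on a made-up toy vocabulary; the keyword argument names are assumed from the class's constructor and may differ slightly by version:

from textattack.models.tokenizers import GloveTokenizer

word_id_map = {"the": 0, "movie": 1, "was": 2, "great": 3}
tokenizer = GloveTokenizer(
    word_id_map=word_id_map, unk_token_id=4, pad_token_id=5, max_length=8
)
print(tokenizer.pad_token_id, tokenizer.oov_token_id)  # 5 4
print(tokenizer.convert_ids_to_tokens([0, 1, 2, 3]))   # ['the', 'movie', 'was', 'great']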

View File

@@ -82,7 +82,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)
 def validate_model_gradient_word_swap_compatibility(model):
     """Determines if ``model`` is task-compatible with
-    ``radientBasedWordSwap``.
+    ``GradientBasedWordSwap``.

     We can only take the gradient with respect to an individual word if
     the model uses a word-based tokenizer.

View File

@@ -34,25 +34,15 @@ class WordSwapGradientBased(Transformation):
         # Make sure we know how to compute the gradient for this model.
         validate_model_gradient_word_swap_compatibility(self.model)
         # Make sure this model has all of the required properties.
-        if not hasattr(self.model, "word_embeddings"):
+        if not hasattr(self.model, "get_input_embeddings"):
             raise ValueError(
                 "Model needs word embedding matrix for gradient-based word swap"
             )
-        if not hasattr(self.model, "lookup_table"):
-            raise ValueError("Model needs lookup table for gradient-based word swap")
-        if not hasattr(self.model, "zero_grad"):
-            raise ValueError("Model needs `zero_grad()` for gradient-based word swap")
-        if not hasattr(self.tokenizer, "convert_id_to_word"):
+        if not hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id:
             raise ValueError(
-                "Tokenizer needs `convert_id_to_word()` for gradient-based word swap"
+                "Tokenizer needs to have `pad_token_id` for gradient-based word swap"
             )
-        if not hasattr(self.tokenizer, "pad_id"):
-            raise ValueError("Tokenizer needs `pad_id` for gradient-based word swap")
-        if not hasattr(self.tokenizer, "oov_id"):
-            raise ValueError("Tokenizer needs `oov_id` for gradient-based word swap")

         self.loss = torch.nn.CrossEntropyLoss()
-        self.pad_id = self.model_wrapper.tokenizer.pad_id
-        self.oov_id = self.model_wrapper.tokenizer.oov_id
         self.top_n = top_n
         self.is_black_box = False
@@ -64,45 +54,28 @@ class WordSwapGradientBased(Transformation):
             attacked_text (AttackedText): The full text input to perturb
             word_index (int): index of the word to replace
         """
-        self.model.train()
-        self.model.emb_layer.embedding.weight.requires_grad = True
-
-        lookup_table = self.model.lookup_table.to(utils.device)
-        lookup_table_transpose = lookup_table.transpose(0, 1)
-
-        # get word IDs
-        text_ids = self.tokenizer.encode(attacked_text.tokenizer_input)
-
-        # set backward hook on the word embeddings for input x
-        emb_hook = Hook(self.model.word_embeddings, backward=True)
-
-        self.model.zero_grad()
-        predictions = self._call_model(text_ids)
-        original_label = predictions.argmax()
-        y_true = torch.Tensor([original_label]).long().to(utils.device)
-        loss = self.loss(predictions, y_true)
-        loss.backward()
-
-        # grad w.r.t to word embeddings
-        emb_grad = emb_hook.output[0].to(utils.device).squeeze()
+        lookup_table = self.model.get_input_embeddings().weight.data.cpu()
+        grad_output = self.model_wrapper.get_grad(attacked_text.tokenizer_input)
+        emb_grad = torch.tensor(grad_output["gradient"])
+        text_ids = grad_output["ids"]

         # grad differences between all flips and original word (eq. 1 from paper)
         vocab_size = lookup_table.size(0)
         diffs = torch.zeros(len(indices_to_replace), vocab_size)
         indices_to_replace = list(indices_to_replace)

         for j, word_idx in enumerate(indices_to_replace):
             # Make sure the word is in bounds.
             if word_idx >= len(emb_grad):
                 continue
             # Get the grad w.r.t the one-hot index of the word.
-            b_grads = (
-                emb_grad[word_idx].view(1, -1).mm(lookup_table_transpose).squeeze()
-            )
+            b_grads = lookup_table.mv(emb_grad[word_idx]).squeeze()
             a_grad = b_grads[text_ids[word_idx]]
             diffs[j] = b_grads - a_grad

         # Don't change to the pad token.
-        diffs[:, self.tokenizer.pad_id] = float("-inf")
+        diffs[:, self.tokenizer.pad_token_id] = float("-inf")

         # Find best indices within 2-d tensor by flattening.
         word_idxs_sorted_by_grad = (-diffs).flatten().argsort()
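
The loop above is the HotFlip-style first-order estimate referenced by the "eq. 1 from paper" comment: for position i holding word w_i, the score of swapping to vocabulary word v is the gradient at position i dotted with (e_v - e_{w_i}). A standalone numeric sketch with toy tensor sizes, not tied to any model:

import torch

vocab_size, emb_dim, seq_len = 10, 4, 3
lookup_table = torch.randn(vocab_size, emb_dim)  # one embedding row per vocab word
emb_grad = torch.randn(seq_len, emb_dim)         # d(loss)/d(embedding) at each position
text_ids = [2, 5, 7]                             # original word ids of the input

word_idx = 1                                     # position to consider flipping
b_grads = lookup_table.mv(emb_grad[word_idx])    # grad_i . e_v for every vocab word v
a_grad = b_grads[text_ids[word_idx]]             # grad_i . e_{w_i} for the original word
diffs = b_grads - a_grad                         # estimated loss increase per candidate
best_swap = diffs.argmax().item()                # highest-scoring replacement word id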
@@ -121,17 +94,8 @@ class WordSwapGradientBased(Transformation):
             if len(candidates) == self.top_n:
                 break

-        self.model.eval()
-        self.model.emb_layer.embedding.weight.requires_grad = (
-            self.model.emb_layer_trainable
-        )
         return candidates

-    def _call_model(self, text_ids):
-        """A helper function to query `self.model` with AttackedText `text`."""
-        model_input = torch.tensor([text_ids]).to(textattack.shared.utils.device)
-        return self.model(model_input)
-
     def _get_transformations(self, attacked_text, indices_to_replace):
         """Returns a list of all possible transformations for `text`.
@@ -147,18 +111,3 @@ class WordSwapGradientBased(Transformation):
     def extra_repr_keys(self):
         return ["top_n"]

-
-class Hook:
-    def __init__(self, module, backward=False):
-        if backward:
-            self.hook = module.register_backward_hook(self.hook_fn)
-        else:
-            self.hook = module.register_forward_hook(self.hook_fn)
-
-    def hook_fn(self, module, input, output):
-        self.input = [x.to(utils.device) for x in input]
-        self.output = [x.to(utils.device) for x in output]
-
-    def close(self):
-        self.hook.remove()
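
The deleted Hook class is superseded by the `model_wrapper.get_grad` call used above, which this transformation now expects to return a dict with a "gradient" entry (the loss gradient at each input embedding, indexed like the token ids) and an "ids" entry. A rough sketch of a wrapper method satisfying that contract, mirroring the removed hook logic; the class below is hypothetical and not TextAttack's actual implementation:

import torch

class GradModelWrapper:
    """Hypothetical wrapper sketching one way to satisfy the get_grad contract."""

    def __init__(self, model, tokenizer):
        self.model = model          # PyTorch classifier exposing get_input_embeddings()
        self.tokenizer = tokenizer  # must provide encode(text) -> list of token ids
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def get_grad(self, text):
        self.model.train()
        embedding_layer = self.model.get_input_embeddings()
        grads = []
        # Capture d(loss)/d(embedding output) with a backward hook, as the old Hook class did.
        handle = embedding_layer.register_backward_hook(
            lambda module, grad_in, grad_out: grads.append(grad_out[0])
        )
        ids = self.tokenizer.encode(text)
        self.model.zero_grad()
        logits = self.model(torch.tensor([ids]))
        label = logits.argmax(dim=1)
        self.loss_fn(logits, label).backward()
        handle.remove()
        self.model.eval()
        # Gradient rows line up with the returned ids.
        return {"gradient": grads[0][0].detach().cpu().numpy(), "ids": ids}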