From 46b08305087905587d6ff914c7815179b743c974 Mon Sep 17 00:00:00 2001
From: Jack Morris
Date: Mon, 27 Jul 2020 11:48:14 -0400
Subject: [PATCH] fix BERT-attack tokenization bug

---
 textattack/transformations/word_swap_masked_lm.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/textattack/transformations/word_swap_masked_lm.py b/textattack/transformations/word_swap_masked_lm.py
index 85e501df..73c01a4b 100644
--- a/textattack/transformations/word_swap_masked_lm.py
+++ b/textattack/transformations/word_swap_masked_lm.py
@@ -87,7 +87,6 @@ class WordSwapMaskedLM(WordSwap):
 
         mask_token_probs = preds[0, masked_index]
         topk = torch.topk(mask_token_probs, self.max_candidates)
-        # top_logits = topk[0].tolist()
         top_ids = topk[1].tolist()
 
         replacement_words = []
@@ -118,7 +117,7 @@ class WordSwapMaskedLM(WordSwap):
         )
         current_inputs = self._encode_text(masked_text.text)
         current_ids = current_inputs["input_ids"].tolist()[0]
-        word_tokens = self._lm_tokenizer.tokenize(current_text.words[index])
+        word_tokens = self._lm_tokenizer.encode(current_text.words[index], add_special_tokens=False)
 
         try:
             # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
@@ -127,13 +126,9 @@ class WordSwapMaskedLM(WordSwap):
             return []
 
         # List of indices of tokens that are part of the target word
-        target_ids_pos = []
-        for i in range(len(word_tokens)):
-            loc = masked_index + i
-            if loc < self.max_length:
-                target_ids_pos.append(loc)
+        target_ids_pos = list(range(masked_index, min(masked_index + len(word_tokens), self.max_length)))
 
-        if not target_ids_pos:
+        if not len(target_ids_pos):
             return []
         elif len(target_ids_pos) == 1:
             # Word to replace is tokenized as a single word
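
Note (not part of the commit): the core of the fix is swapping tokenize() for encode(..., add_special_tokens=False). The sketch below illustrates the difference, assuming self._lm_tokenizer is a Hugging Face tokenizer; "bert-base-uncased" and the sample word are illustrative choices, not taken from the patch.

    # Hypothetical demo of the two tokenizer calls involved in the fix.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    word = "adversarial"

    # tokenize() returns subword *strings*, e.g. ['ad', '##vers', ...]
    print(tokenizer.tokenize(word))

    # encode(..., add_special_tokens=False) returns the matching integer
    # vocabulary IDs, the same representation as current_ids above, so
    # position arithmetic stays in a single ID space.
    print(tokenizer.encode(word, add_special_tokens=False))

    # Without add_special_tokens=False, encode() would also prepend the
    # [CLS] ID and append the [SEP] ID, throwing off the offsets.
    print(tokenizer.encode(word))

Since current_ids holds integer IDs produced by _encode_text, keeping word_tokens in that same ID space avoids mixing token strings and token IDs when locating the target word.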
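
Note (not part of the commit): the new target_ids_pos one-liner is behavior-equivalent to the deleted append loop, including the clipping at self.max_length. A quick check with made-up numbers:

    # Illustrative values only; masked_index near max_length exercises the clip.
    masked_index, num_word_tokens, max_length = 254, 4, 256

    old_style = []
    for i in range(num_word_tokens):
        loc = masked_index + i
        if loc < max_length:
            old_style.append(loc)

    new_style = list(range(masked_index, min(masked_index + num_word_tokens, max_length)))

    assert old_style == new_style == [254, 255]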