mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix BERT-attack tokenization bug

Jack Morris
2020-07-27 11:48:14 -04:00
parent 589113f82f
commit 46b0830508


@@ -87,7 +87,6 @@ class WordSwapMaskedLM(WordSwap):
         mask_token_probs = preds[0, masked_index]
         topk = torch.topk(mask_token_probs, self.max_candidates)
-        # top_logits = topk[0].tolist()
         top_ids = topk[1].tolist()
         replacement_words = []
@@ -118,7 +117,7 @@ class WordSwapMaskedLM(WordSwap):
         )
         current_inputs = self._encode_text(masked_text.text)
         current_ids = current_inputs["input_ids"].tolist()[0]
-        word_tokens = self._lm_tokenizer.tokenize(current_text.words[index])
+        word_tokens = self._lm_tokenizer.encode(current_text.words[index], add_special_tokens=False)
         try:
             # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
@@ -127,13 +126,9 @@ class WordSwapMaskedLM(WordSwap):
             return []
         # List of indices of tokens that are part of the target word
-        target_ids_pos = []
-        for i in range(len(word_tokens)):
-            loc = masked_index + i
-            if loc < self.max_length:
-                target_ids_pos.append(loc)
+        target_ids_pos = list(range(masked_index, min(masked_index + len(word_tokens), self.max_length)))
-        if not target_ids_pos:
+        if not len(target_ids_pos):
             return []
         elif len(target_ids_pos) == 1:
             # Word to replace is tokenized as a single word
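
The substance of the fix is the switch from _lm_tokenizer.tokenize(...) to _lm_tokenizer.encode(..., add_special_tokens=False): tokenize() yields subword strings, while encode() yields the integer vocabulary IDs that line up with the input_ids the rest of the method works with. A minimal sketch of the difference, assuming a standard Hugging Face BERT tokenizer ("bert-base-uncased" and the example word are illustrative choices, not taken from the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative model choice
    word = "unaffable"

    # tokenize() returns subword *strings*, e.g. ["un", "##aff", "##able"]
    subword_strings = tokenizer.tokenize(word)

    # encode(..., add_special_tokens=False) returns the matching integer IDs,
    # one per subword, which is what can be compared against entries of input_ids
    subword_ids = tokenizer.encode(word, add_special_tokens=False)

    print(subword_strings, subword_ids)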
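
The last hunk also collapses the position-collecting loop into a single range() clipped at max_length; the two forms produce the same list. A small sketch of the equivalence with stand-in values (masked_index, the subword count, and max_length below are made up for illustration):

    masked_index = 5      # position of the [MASK] token in input_ids (stand-in value)
    num_word_tokens = 3   # len(word_tokens) for the word being replaced (stand-in value)
    max_length = 7        # tokenizer truncation limit (stand-in value)

    # Old form: accumulate positions one at a time, skipping anything past max_length
    target_ids_pos_loop = []
    for i in range(num_word_tokens):
        loc = masked_index + i
        if loc < max_length:
            target_ids_pos_loop.append(loc)

    # New form: one clipped range
    target_ids_pos_range = list(range(masked_index, min(masked_index + num_word_tokens, max_length)))

    assert target_ids_pos_loop == target_ids_pos_range == [5, 6]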