mirror of https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
fix BERT-attack tokenization bug
@@ -87,7 +87,6 @@ class WordSwapMaskedLM(WordSwap):
         mask_token_probs = preds[0, masked_index]
         topk = torch.topk(mask_token_probs, self.max_candidates)
-        # top_logits = topk[0].tolist()
         top_ids = topk[1].tolist()
 
         replacement_words = []
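
For context: torch.topk returns a (values, indices) pair, so topk[0] held the logits the removed comment referred to and topk[1] holds the candidate vocabulary ids the code keeps. A minimal standalone sketch with illustrative numbers (not from the repo):

    import torch

    # Stand-in for preds[0, masked_index]: one score per vocabulary entry.
    mask_token_probs = torch.tensor([0.125, 0.0625, 0.5, 0.25])

    topk = torch.topk(mask_token_probs, k=2)
    print(topk.values.tolist())   # [0.5, 0.25] -- the unused top_logits
    print(topk.indices.tolist())  # [2, 3]      -- what the code keeps as top_ids
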
@@ -118,7 +117,7 @@ class WordSwapMaskedLM(WordSwap):
         )
         current_inputs = self._encode_text(masked_text.text)
         current_ids = current_inputs["input_ids"].tolist()[0]
-        word_tokens = self._lm_tokenizer.tokenize(current_text.words[index])
+        word_tokens = self._lm_tokenizer.encode(current_text.words[index], add_special_tokens=False)
 
         try:
             # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
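
This one-line swap is the tokenization bug the commit message names: tokenize returns subword strings, but the method does its bookkeeping in the integer id space of current_ids, so the target word must be encoded into that same id space, with add_special_tokens=False so no [CLS]/[SEP] ids are mixed in. A sketch of the difference, assuming Hugging Face transformers and bert-base-uncased (exact subwords and ids depend on the vocabulary):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Old call: subword strings, not comparable to entries of current_ids.
    print(tokenizer.tokenize("unhappiness"))
    # e.g. ['un', '##hap', '##pi', '##ness']

    # New call: integer ids in the same space as current_ids, no [CLS]/[SEP].
    print(tokenizer.encode("unhappiness", add_special_tokens=False))
    # four integer ids, one per subword above
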
@@ -127,13 +126,9 @@ class WordSwapMaskedLM(WordSwap):
             return []
 
         # List of indices of tokens that are part of the target word
-        target_ids_pos = []
-        for i in range(len(word_tokens)):
-            loc = masked_index + i
-            if loc < self.max_length:
-                target_ids_pos.append(loc)
+        target_ids_pos = list(range(masked_index, min(masked_index + len(word_tokens), self.max_length)))
 
-        if not target_ids_pos:
+        if not len(target_ids_pos):
             return []
         elif len(target_ids_pos) == 1:
             # Word to replace is tokenized as a single word
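
The removed loop and the added range() expression compute the same list: the contiguous positions covered by the word's subtokens, truncated at max_length. (The guard change from `if not target_ids_pos` to `if not len(target_ids_pos)` is likewise behavior-preserving for a list.) A standalone check with hypothetical values:

    # Hypothetical values: word at position 3 splits into 3 subwords, max_length 5.
    masked_index, num_word_tokens, max_length = 3, 3, 5

    # Removed version: append each in-bounds position.
    old_positions = []
    for i in range(num_word_tokens):
        loc = masked_index + i
        if loc < max_length:
            old_positions.append(loc)

    # Added version: same positions as a capped range.
    new_positions = list(range(masked_index, min(masked_index + num_word_tokens, max_length)))

    assert old_positions == new_positions == [3, 4]  # position 5 is truncated
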