Mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00
fix test outputs and bugs
@@ -116,3 +116,15 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
         output = {"ids": ids[0]["input_ids"], "gradient": grad}
 
         return output
+
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`.
+        Args:
+            inputs (list[str]): list of input strings
+        Returns:
+            tokens (list[list[str]]): List of list of tokens as strings
+        """
+        return [
+            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x)["input_ids"])
+            for x in inputs
+        ]
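For reference (not part of the commit), a minimal sketch of what a `_tokenize` like the one above returns, assuming a standard `transformers.AutoTokenizer`; the checkpoint name and the input sentence are only illustrative:

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any HuggingFace tokenizer works the same way.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Encode to ids, then map the ids back to subword strings, mirroring the
    # convert_ids_to_tokens(encode(x)["input_ids"]) pattern in the hunk above.
    ids = tokenizer("adversarial attacks perturb text")["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    # `tokens` now holds WordPiece strings, with continuation pieces prefixed
    # by "##" and special tokens such as [CLS]/[SEP] included.
    print(tokens)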
@@ -23,35 +23,36 @@ class ModelWrapper(ABC):
     def encode(self, inputs):
         """Helper method that calls ``tokenizer.batch_encode`` if possible, and
         if not, falls back to calling ``tokenizer.encode`` for each input.
 
         Args:
             inputs (list[str]): list of input strings
 
         Returns:
             tokens (list[list[int]]): List of list of ids
         """
         if hasattr(self.tokenizer, "batch_encode"):
             return self.tokenizer.batch_encode(inputs)
         else:
             return [self.tokenizer.encode(x) for x in inputs]
 
-    def tokenize(self, inputs, strip=True):
-        """Helper method that calls ``tokenizer.tokenize``.
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`"""
+        raise NotImplementedError()
+
+    def tokenize(self, inputs, strip_prefix=False):
+        """Helper method that tokenizes input strings
         Args:
             inputs (list[str]): list of input strings
-            strip (bool): If `True`, we strip auxiliary characters added to tokens (e.g. "##" for BERT, "Ġ" for RoBERTa)
+            strip_prefix (bool): If `True`, we strip auxiliary characters added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for RoBERTa)
         Returns:
             tokens (list[list[str]]): List of list of tokens as strings
-        """
-        tokens = [self.tokenizer.tokenize(x) for x in inputs]
-        if strip:
-            # `aux_chars` are known auxiliary characters that are added to tokens
+        """
+        tokens = self._tokenize(inputs)
+        if strip_prefix:
+            # `aux_chars` are known auxiliary characters that are added to tokens
             strip_chars = ["##", "Ġ", "__"]
-            # Try dummy string "aaaaaaaaaaaaaaaaaaaaaaaaaa" and identify possible prefix
-            # TODO: Find a better way to identify prefixes
-            strip_chars.append(self.tokenize(["aaaaaaaaaaaaaaaaaaaaaaaaaa"])[0][2].replace("a", ""))
-
+            # TODO: Find a better way to identify prefixes. These depend on the model, so cannot be resolved in ModelWrapper.
+
             def strip(s, chars):
                 for c in chars:
                     s = s.replace(c, "")
@@ -59,4 +60,4 @@ class ModelWrapper(ABC):
 
             tokens = [[strip(t, strip_chars) for t in x] for x in tokens]
 
-            return tokens
+        return tokens
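To show what the new `strip_prefix` option does, here is a small self-contained sketch of the same stripping logic; the sample token lists are invented for illustration:

    # Known auxiliary prefixes added by subword tokenizers
    # ("##" for BERT WordPiece, "Ġ" for RoBERTa/GPT-2 BPE).
    strip_chars = ["##", "Ġ", "__"]

    def strip(s, chars):
        for c in chars:
            s = s.replace(c, "")
        return s

    # Hypothetical tokenizer output for two inputs.
    tokens = [["em", "##bed", "##ding"], ["Ġhello", "Ġworld"]]
    stripped = [[strip(t, strip_chars) for t in x] for x in tokens]
    print(stripped)  # [['em', 'bed', 'ding'], ['hello', 'world']]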
@@ -94,3 +94,15 @@ class PyTorchModelWrapper(ModelWrapper):
         output = {"ids": ids[0].tolist(), "gradient": grad}
 
         return output
+
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`.
+        Args:
+            inputs (list[str]): list of input strings
+        Returns:
+            tokens (list[list[str]]): List of list of tokens as strings
+        """
+        return [
+            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x))
+            for x in inputs
+        ]
@@ -432,17 +432,19 @@ class AttackedText:
         return float(np.sum(self.words != x.words)) / self.num_words
 
     def align_with_model_tokens(self, model_wrapper):
-        """Align AttackedText's `words` with target model's tokenization scheme (e.g. word, character, subword).
-        Specifically, we map each word to list of indices of tokens that compose the word (e.g. embedding --> ["em", "##bed", "##ding"])
+        """Align AttackedText's `words` with target model's tokenization scheme
+        (e.g. word, character, subword). Specifically, we map each word to list
+        of indices of tokens that compose the word (e.g. embedding --> ["em",
+        "##bed", "##ding"])
 
         Args:
             model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model
 
         Returns:
             word2token_mapping (dict[str, list[int]]): Dictionary that maps word to list of indices.
         """
-        tokens = model_wrapper.tokenize([self.tokenizer_input], strip=True)[0]
-        word2token_mapping = {k: [] for k in self.words}
+        tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
+        word2token_mapping = {}
         j = 0
         last_matched = 0
         for i, word in enumerate(self.words):
@@ -451,16 +453,15 @@ class AttackedText:
                 token = tokens[j].lower()
                 idx = word.find(token)
                 if idx == 0:
-                    word = word[idx + len(token):]
+                    word = word[idx + len(token) :]
                     matched_tokens.append(j)
                     last_matched = j
                 j += 1
 
             if not matched_tokens:
                 # Reset j to most recent match
                 j = last_matched
             else:
-                word2token_mapping[word] = matched_tokens
+                word2token_mapping[self.words[i]] = matched_tokens
 
         return word2token_mapping
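For context, a simplified standalone sketch (not the repository code) of the greedy word-to-token alignment that `align_with_model_tokens` performs once prefixes are stripped; the sample words and tokens are invented:

    # Hypothetical model tokenization of "the embedding layer", prefixes already stripped.
    words = ["the", "embedding", "layer"]
    tokens = ["the", "em", "bed", "ding", "layer"]

    word2token_mapping = {}
    j = 0
    for i, word in enumerate(words):
        matched_tokens = []
        remaining = word.lower()
        # Greedily consume tokens that appear as a prefix of the remaining word.
        while j < len(tokens) and remaining:
            token = tokens[j].lower()
            if remaining.find(token) == 0:
                remaining = remaining[len(token):]
                matched_tokens.append(j)
                j += 1
            else:
                break
        if matched_tokens:
            word2token_mapping[words[i]] = matched_tokens

    print(word2token_mapping)
    # {'the': [0], 'embedding': [1, 2, 3], 'layer': [4]}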