mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix test outputs and bugs

Jin Yong Yoo
2020-11-01 16:16:55 -05:00
parent 866301ba08
commit 8dad54a145
6 changed files with 51 additions and 25 deletions


@@ -116,3 +116,15 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
         output = {"ids": ids[0]["input_ids"], "gradient": grad}
 
         return output
+
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`.
+        Args:
+            inputs (list[str]): list of input strings
+        Returns:
+            tokens (list[list[str]]): List of list of tokens as strings
+        """
+        return [
+            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x)["input_ids"])
+            for x in inputs
+        ]
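
For context, a minimal sketch of what this helper computes, assuming the `transformers` package and a BERT-style tokenizer (none of this is part of the commit; a bare `transformers` tokenizer's `encode` returns a plain list of ids, whereas TextAttack's wrapped tokenizer returns a dict, hence the `["input_ids"]` lookup above):

    # Illustrative sketch only; assumes `pip install transformers`.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    ids = tokenizer.encode("an embedding layer")  # plain list[int], incl. [CLS]/[SEP]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    # e.g. something like ['[CLS]', 'an', 'em', '##bed', '##ding', 'layer', '[SEP]']
    # (the exact word pieces depend on the vocabulary)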


@@ -23,35 +23,36 @@ class ModelWrapper(ABC):
     def encode(self, inputs):
         """Helper method that calls ``tokenizer.batch_encode`` if possible, and
         if not, falls back to calling ``tokenizer.encode`` for each input.
 
         Args:
             inputs (list[str]): list of input strings
         Returns:
             tokens (list[list[int]]): List of list of ids
-        """"
+        """
         if hasattr(self.tokenizer, "batch_encode"):
             return self.tokenizer.batch_encode(inputs)
         else:
             return [self.tokenizer.encode(x) for x in inputs]
 
-    def tokenize(self, inputs, strip=True):
-        """Helper method that calls ``tokenizer.tokenize``.
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`"""
+        raise NotImplementedError()
+
+    def tokenize(self, inputs, strip_prefix=False):
+        """Helper method that tokenizes input strings
 
         Args:
             inputs (list[str]): list of input strings
-            strip (bool): If `True`, we strip auxiliary characters added to tokens (e.g. "##" for BERT, "Ġ" for RoBERTa)
+            strip_prefix (bool): If `True`, we strip auxiliary characters added to tokens as prefixes (e.g. "##" for BERT, "Ġ" for RoBERTa)
         Returns:
             tokens (list[list[str]]): List of list of tokens as strings
-        """"
-        tokens = [self.tokenizer.tokenize(x) for x in inputs]
-        if strip:
-            #`aux_chars` are known auxiliary characters that are added to tokens
+        """
+        tokens = self._tokenize(inputs)
+        if strip_prefix:
+            # `aux_chars` are known auxiliary characters that are added to tokens
             strip_chars = ["##", "Ġ", "__"]
-            # Try dummy string "aaaaaaaaaaaaaaaaaaaaaaaaaa" and identify possible prefix
-            # TODO: Find a better way to identify prefixes
-            strip_charas.append(self.tokenize(["aaaaaaaaaaaaaaaaaaaaaaaaaa"])[0][2].replace("a", ""))
+            # TODO: Find a better way to identify prefixes. These depend on the model, so cannot be resolved in ModelWrapper.
 
             def strip(s, chars):
                 for c in chars:
                     s = s.replace(c, "")
@@ -59,4 +60,4 @@ class ModelWrapper(ABC):
                 return s
 
             tokens = [[strip(t, strip_chars) for t in x] for x in tokens]
-            return tokens
+        return tokens
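
The stripping step at the end of `tokenize` is plain string manipulation; here is a self-contained sketch of the same logic on hypothetical tokens:

    # Standalone sketch of the prefix-stripping logic above (hypothetical input).
    strip_chars = ["##", "Ġ", "__"]

    def strip(s, chars):
        for c in chars:
            s = s.replace(c, "")
        return s

    tokens = [["em", "##bed", "##ding"], ["Ġhello"]]
    print([[strip(t, strip_chars) for t in x] for x in tokens])
    # [['em', 'bed', 'ding'], ['hello']]

Note that `str.replace` removes the marker anywhere in the token, not only as a prefix; the sketch mirrors the diff's behavior rather than correcting it.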


@@ -94,3 +94,15 @@ class PyTorchModelWrapper(ModelWrapper):
         output = {"ids": ids[0].tolist(), "gradient": grad}
 
         return output
+
+    def _tokenize(self, inputs):
+        """Helper method for `tokenize`.
+        Args:
+            inputs (list[str]): list of input strings
+        Returns:
+            tokens (list[list[str]]): List of list of tokens as strings
+        """
+        return [
+            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x))
+            for x in inputs
+        ]
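
This `_tokenize` assumes only that `self.tokenizer` exposes `encode` (returning a list of ids) and `convert_ids_to_tokens`. A hypothetical minimal tokenizer satisfying that contract (not a TextAttack class):

    # Hypothetical tokenizer exposing just the two methods `_tokenize` relies on.
    class WhitespaceTokenizer:
        def __init__(self):
            self._vocab = {}    # token -> id
            self._inverse = {}  # id -> token

        def encode(self, text):
            ids = []
            for token in text.lower().split():
                if token not in self._vocab:
                    idx = len(self._vocab)
                    self._vocab[token] = idx
                    self._inverse[idx] = token
                ids.append(self._vocab[token])
            return ids

        def convert_ids_to_tokens(self, ids):
            return [self._inverse[i] for i in ids]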


@@ -432,17 +432,19 @@ class AttackedText:
         return float(np.sum(self.words != x.words)) / self.num_words
 
     def align_with_model_tokens(self, model_wrapper):
-        """Align AttackedText's `words` with target model's tokenization scheme (e.g. word, character, subword).
-        Specifically, we map each word to list of indices of tokens that compose the word (e.g. embedding --> ["em", "##bed", "##ding"])
+        """Align AttackedText's `words` with target model's tokenization scheme
+        (e.g. word, character, subword). Specifically, we map each word to list
+        of indices of tokens that compose the word (e.g. embedding --> ["em",
+        "##bed", "##ding"])
 
         Args:
             model_wrapper (textattack.models.wrappers.ModelWrapper): ModelWrapper of the target model
         Returns:
             word2token_mapping (dict[str, list[int]]): Dictionary that maps word to list of indices.
         """
-        tokens = model_wrapper.tokenize([self.tokenizer_input], strip=True)[0]
-        word2token_mapping = {k:[] for k in self.words}
+        tokens = model_wrapper.tokenize([self.tokenizer_input], strip_prefix=True)[0]
+        word2token_mapping = {}
         j = 0
+        last_matched = 0
         for i, word in enumerate(self.words):
@@ -451,16 +453,15 @@ class AttackedText:
                 token = tokens[j].lower()
                 idx = word.find(token)
                 if idx == 0:
-                    word = word[idx + len(token):]
+                    word = word[idx + len(token) :]
                     matched_tokens.append(j)
+                    last_matched = j
                 j += 1
             if not matched_tokens:
+                # Reset j to most recent match
+                j = last_matched
             else:
-                word2token_mapping[word] = matched_tokens
+                word2token_mapping[self.words[i]] = matched_tokens
         return word2token_mapping
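
To make the matching loop concrete, a self-contained rerun on hypothetical inputs (the `while` header is not shown in this diff, so the condition below is an assumption):

    # Hypothetical, self-contained version of the alignment loop above.
    words = ["an", "embedding", "layer"]
    tokens = ["an", "em", "bed", "ding", "layer"]  # "##" prefixes already stripped

    word2token_mapping = {}
    j = 0
    last_matched = 0
    for i, word in enumerate(words):
        matched_tokens = []
        while j < len(tokens) and len(word) > 0:  # assumed loop condition
            token = tokens[j].lower()
            if word.find(token) == 0:
                word = word[len(token):]
                matched_tokens.append(j)
                last_matched = j
            j += 1
        if not matched_tokens:
            j = last_matched  # reset to most recent match
        else:
            word2token_mapping[words[i]] = matched_tokens

    print(word2token_mapping)
    # {'an': [0], 'embedding': [1, 2, 3], 'layer': [4]}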