textattack-nlp-transformer/textattack/models/tokenizers/auto_tokenizer.py

"""
AutoTokenizer
^^^^^^^^^^^^^^
"""
import transformers


class AutoTokenizer:
"""A generic class that convert text to tokens and tokens to IDs. Supports
any type of tokenization, be it word, wordpiece, or character-based. Based
on the ``AutoTokenizer`` from the ``transformers`` library, but
standardizes the functionality for TextAttack.
Args:
name: the identifying name of the tokenizer, for example, ``bert-base-uncased``
(see AutoTokenizer,
https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_auto.py)
max_length: if set, will truncate & pad tokens to fit this length
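
    Example (an illustrative sketch; assumes the ``bert-base-uncased`` tokenizer
    can be downloaded from the Hugging Face model hub)::

        # Load a tokenizer by name or path...
        tokenizer = AutoTokenizer(tokenizer_path="bert-base-uncased", max_length=128)
        # ...or wrap a ``transformers`` tokenizer that was constructed elsewhere.
        hf_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        tokenizer = AutoTokenizer(tokenizer=hf_tokenizer, max_length=128)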
"""
def __init__(
self,
tokenizer_path=None,
tokenizer=None,
max_length=256,
use_fast=True,
):
if not (tokenizer_path or tokenizer):
raise ValueError("Must pass tokenizer path or tokenizer")
if tokenizer_path and tokenizer:
raise ValueError("Cannot pass both tokenizer path and tokenizer")
if tokenizer_path:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=use_fast
)
else:
self.tokenizer = tokenizer
self.max_length = max_length
self.save_pretrained = self.tokenizer.save_pretrained

    def encode(self, input_text):
        """Encodes ``input_text``.

        ``input_text`` may be a string or a tuple of strings, depending on
        whether the model takes one or multiple inputs. The
        ``transformers.AutoTokenizer`` will automatically handle either
        case.
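
        Example (illustrative; assumes ``tokenizer`` is an instance of this class)::

            tokenizer.encode("a single sentence")            # single-input models
            tokenizer.encode(("a premise", "a hypothesis"))  # paired-input models, e.g. NLI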
"""
if isinstance(input_text, str):
input_text = (input_text,)
encoded_text = self.tokenizer.encode_plus(
*input_text,
max_length=self.max_length,
add_special_tokens=True,
padding="max_length",
truncation=True,
)
return dict(encoded_text)

    def batch_encode(self, input_text_list):
        """The batch equivalent of ``encode``."""
        if hasattr(self.tokenizer, "batch_encode_plus"):
            if isinstance(input_text_list[0], tuple) and len(input_text_list[0]) == 1:
                # Unroll tuples of length 1.
                input_text_list = [t[0] for t in input_text_list]
            encodings = self.tokenizer.batch_encode_plus(
                input_text_list,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=True,
                padding="max_length",
            )
            # ``encodings`` is a ``transformers`` ``BatchEncoding`` object, which
            # is basically a big dictionary that contains a key for all input
            # IDs, a key for all attention masks, etc.
            dict_of_lists = {k: list(v) for k, v in encodings.data.items()}
            # Turn this dict of lists into a list of dicts, one per input.
            list_of_dicts = [
                {key: value[index] for key, value in dict_of_lists.items()}
                for index in range(max(map(len, dict_of_lists.values())))
            ]
            return list_of_dicts
        else:
            return [self.encode(input_text) for input_text in input_text_list]

    def convert_ids_to_tokens(self, ids):
        return self.tokenizer.convert_ids_to_tokens(ids)

    @property
    def pad_token_id(self):
        if hasattr(self.tokenizer, "pad_token_id"):
            return self.tokenizer.pad_token_id
        else:
            raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")

    @property
    def mask_token_id(self):
        if hasattr(self.tokenizer, "mask_token_id"):
            return self.tokenizer.mask_token_id
        else:
            raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")