"""
|
|
AutoTokenizer
|
|
^^^^^^^^^^^^^^
|
|
|
|
"""
|
|
|
|
|
|
import transformers
|
|
|
|
|
|


class AutoTokenizer:
    """A generic class that converts text to tokens and tokens to IDs.
    Supports any type of tokenization, be it word, wordpiece, or
    character-based. Based on the ``AutoTokenizer`` from the ``transformers``
    library, but standardizes the functionality for TextAttack.

    Args:
        tokenizer_path: the identifying name of the tokenizer to load from the
            ``transformers`` library, for example, ``bert-base-uncased``
            (see AutoTokenizer,
            https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_auto.py)
        tokenizer: an already-constructed ``transformers`` tokenizer to wrap
            (mutually exclusive with ``tokenizer_path``)
        max_length: if set, will truncate & pad tokens to fit this length
        use_fast: whether to load the fast, Rust-backed tokenizer when one is
            available
    """

    def __init__(
        self,
        tokenizer_path=None,
        tokenizer=None,
        max_length=256,
        use_fast=True,
    ):
        if not (tokenizer_path or tokenizer):
            raise ValueError("Must pass tokenizer path or tokenizer")
        if tokenizer_path and tokenizer:
            raise ValueError("Cannot pass both tokenizer path and tokenizer")

        if tokenizer_path:
            # Load a pretrained tokenizer by name or path via ``transformers``.
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer_path, use_fast=use_fast
            )
        else:
            self.tokenizer = tokenizer
        self.max_length = max_length
        # Expose the wrapped tokenizer's ``save_pretrained`` directly.
        self.save_pretrained = self.tokenizer.save_pretrained

    def encode(self, input_text):
        """Encodes ``input_text``.

        ``input_text`` may be a string or a tuple of strings, depending on
        whether the model takes one or multiple inputs. The
        ``transformers.AutoTokenizer`` will automatically handle either case.
        """
        if isinstance(input_text, str):
            input_text = (input_text,)
        encoded_text = self.tokenizer.encode_plus(
            *input_text,
            max_length=self.max_length,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
        )
        return dict(encoded_text)

    def batch_encode(self, input_text_list):
        """The batch equivalent of ``encode``."""
        if hasattr(self.tokenizer, "batch_encode_plus"):
            if isinstance(input_text_list[0], tuple) and len(input_text_list[0]) == 1:
                # Unroll tuples of length 1.
                input_text_list = [t[0] for t in input_text_list]
            encodings = self.tokenizer.batch_encode_plus(
                input_text_list,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=True,
                padding="max_length",
            )
            # ``encodings`` is a ``transformers.BatchEncoding`` object, which
            # is basically a big dictionary that contains a key for all input
            # IDs, a key for all attention masks, etc.
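            # Illustrative shapes only: for a batch of two inputs the data
            # might look like
            #   {"input_ids": [[...], [...]], "attention_mask": [[...], [...]]}
            # and the reshaping below converts it to
            #   [{"input_ids": [...], "attention_mask": [...]},
            #    {"input_ids": [...], "attention_mask": [...]}]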
            dict_of_lists = {k: list(v) for k, v in encodings.data.items()}
            # Turn this dict of lists into a list of dicts, one per input.
            list_of_dicts = [
                {key: value[index] for key, value in dict_of_lists.items()}
                for index in range(max(map(len, dict_of_lists.values())))
            ]
            return list_of_dicts
        else:
            return [self.encode(input_text) for input_text in input_text_list]

    def convert_ids_to_tokens(self, ids):
        return self.tokenizer.convert_ids_to_tokens(ids)

    @property
    def pad_token_id(self):
        if hasattr(self.tokenizer, "pad_token_id"):
            return self.tokenizer.pad_token_id
        else:
            raise AttributeError("Tokenizer does not have `pad_token_id` attribute.")

    @property
    def mask_token_id(self):
        if hasattr(self.tokenizer, "mask_token_id"):
            return self.tokenizer.mask_token_id
        else:
            raise AttributeError("Tokenizer does not have `mask_token_id` attribute.")
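

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # ``transformers`` package is installed and that ``bert-base-uncased`` is
    # cached locally or downloadable; any other Hugging Face tokenizer name
    # would work the same way.
    tokenizer = AutoTokenizer(tokenizer_path="bert-base-uncased", max_length=32)

    # Single-input encoding returns a dict with ``input_ids``,
    # ``attention_mask``, etc., padded/truncated to ``max_length``.
    encoded = tokenizer.encode("TextAttack wraps Hugging Face tokenizers.")
    print(sorted(encoded.keys()))

    # Batch encoding returns one such dict per input string.
    batch = tokenizer.batch_encode(
        ["a first example sentence", "a second example sentence"]
    )
    print(len(batch), len(batch[0]["input_ids"]))

    # Special-token IDs are exposed as properties.
    print(tokenizer.pad_token_id, tokenizer.mask_token_id)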