textattack/models/tokenizers/auto_tokenizer.py (from https://github.com/QData/TextAttack)
import transformers

class AutoTokenizer:
    """A generic class that converts text to tokens and tokens to IDs.

    Supports any type of tokenization, be it word, wordpiece, or
    character-based. Based on the ``AutoTokenizer`` from the
    ``transformers`` library, but standardizes the functionality for
    TextAttack.

    Args:
        name: the identifying name of the tokenizer (see AutoTokenizer,
            https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_auto.py)
        max_length: if set, will truncate & pad tokens to fit this length
    """

    def __init__(
        self, name="bert-base-uncased", max_length=256, use_fast=True,
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            name, use_fast=use_fast
        )
        self.max_length = max_length
        self.save_pretrained = self.tokenizer.save_pretrained

    def encode(self, input_text):
        """Encodes ``input_text``.

        ``input_text`` may be a string or a tuple of strings, depending on
        whether the model takes one or multiple inputs. The
        ``transformers.AutoTokenizer`` will handle either case automatically.
        """
        if isinstance(input_text, str):
            input_text = (input_text,)
        encoded_text = self.tokenizer.encode_plus(
            *input_text,
            max_length=self.max_length,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
        )
        return dict(encoded_text)

    def batch_encode(self, input_text_list):
        """The batch equivalent of ``encode``."""
        if hasattr(self.tokenizer, "batch_encode_plus"):
            if isinstance(input_text_list[0], tuple) and len(input_text_list[0]) == 1:
                # Unroll tuples of length 1.
                input_text_list = [t[0] for t in input_text_list]
            encodings = self.tokenizer.batch_encode_plus(
                input_text_list,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=True,
                padding="max_length",
            )
            # ``encodings`` is a ``transformers.BatchEncoding`` object:
            # essentially a big dictionary with a key for all input IDs,
            # a key for all attention masks, etc. Turn this dict of lists
            # into a list of dicts, one per input.
            dict_of_lists = {k: list(v) for k, v in encodings.data.items()}
            list_of_dicts = [
                {key: value[index] for key, value in dict_of_lists.items()}
                for index in range(max(map(len, dict_of_lists.values())))
            ]
            return list_of_dicts
        else:
            return [self.encode(input_text) for input_text in input_text_list]
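
A minimal usage sketch (not part of the file above): it assumes the class is in scope and that ``transformers`` can download ``bert-base-uncased``.

tokenizer = AutoTokenizer(name="bert-base-uncased", max_length=128)

# Single input: a dict with keys such as ``input_ids`` and
# ``attention_mask``, each padded/truncated to ``max_length``.
encoded = tokenizer.encode("the quick brown fox")
print(sorted(encoded.keys()))

# Paired input, e.g. for models that take a premise and a hypothesis.
encoded_pair = tokenizer.encode(("a premise", "a hypothesis"))

# Batch input: one dict per input text.
batch = tokenizer.batch_encode(["first text", "second text"])
print(len(batch))  # 2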