mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
"""
|
|
T5 Tokenizer
|
|
^^^^^^^^^^^^^^^^^
|
|
|
|
"""
|
|
|
|
|
|
from textattack.models.tokenizers import AutoTokenizer
|
|
|
|
|
|
class T5Tokenizer(AutoTokenizer):
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian
    """

    # Maps each supported mode to the task prefix T5 expects to be
    # prepended to the raw input text (see Appendix D of the T5 paper).
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        """Initialize with a task ``mode`` and a maximum sequence length.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except KeyError:
            # Suppress the KeyError context: the caller only needs to know
            # the mode string was invalid.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None

        super().__init__(tokenizer_path="t5-base", max_length=max_length)

    @staticmethod
    def _unwrap_text(text):
        """Unwrap a length-1 tuple input and validate the result is a str.

        Raises:
            ValueError: If ``text`` is a tuple with more than one element.
            TypeError: If the (unwrapped) input is not a string.
        """
        if isinstance(text, tuple):
            if len(text) > 1:
                raise ValueError(
                    f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
                )
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
        return text

    def encode(self, text):
        """Encodes a string into IDs of tokens.

        This prepares an input to be passed into T5.
        """
        text = self._unwrap_text(text)
        text_to_encode = self.tokenization_prefix + text
        return super().encode(text_to_encode)

    def batch_encode(self, input_text_list):
        """Encodes a list of inputs, applying the task prefix to each.

        Each element may be a string or a length-1 tuple containing a
        string, mirroring :meth:`encode`. (Previously, non-str elements
        slipped past validation here and failed with an opaque
        concatenation error; they are now rejected with the same clear
        TypeError that ``encode`` raises.)
        """
        new_input_text_list = [
            self.tokenization_prefix + self._unwrap_text(text)
            for text in input_text_list
        ]
        return super().batch_encode(new_input_text_list)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string."""
        return self.tokenizer.decode(ids)
|