1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/models/tokenizers/t5_tokenizer.py
2020-06-16 23:52:03 -04:00

52 lines
2.0 KiB
Python

from textattack.models.tokenizers import AutoTokenizer
class T5Tokenizer(AutoTokenizer):
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:
        * summarization: summarize English text (CNN/Daily Mail dataset)
        * english_to_german: translate English to German (WMT dataset)
        * english_to_french: translate English to French (WMT dataset)
        * english_to_romanian: translate English to Romanian (WMT dataset)

    Args:
        mode: One of the mode names listed above. It selects the task
            prefix that is prepended to every text before encoding.
        max_length: Forwarded unchanged to ``AutoTokenizer.__init__``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    # Task prefix prepended to the input text for each supported mode.
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        # Only strings can name a mode; the isinstance guard keeps an
        # unhashable argument from escaping as a TypeError out of the lookup.
        prefix = self._MODE_PREFIXES.get(mode) if isinstance(mode, str) else None
        if prefix is None:
            # Report the value that was actually passed in. The previous
            # message interpolated an undefined name and so raised NameError
            # instead of this ValueError.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.")
        # Set the prefix before delegating to the parent constructor, as the
        # original implementation did.
        self.tokenization_prefix = prefix
        super().__init__(name="t5-base", max_length=max_length)

    def encode(self, text):
        """Encodes a string into IDs of tokens.

        This prepares an input to be passed into T5: the task prefix chosen
        at construction time is prepended to ``text`` and the result is
        handed to the parent class's ``encode``.

        Args:
            text: The input string, or a tuple whose first element is the
                input string (only that first element is used).

        Returns:
            Whatever ``AutoTokenizer.encode`` returns for the prefixed text.

        Raises:
            TypeError: If the (unwrapped) input is not a ``str``.
        """
        if isinstance(text, tuple):
            # Tuple inputs are unwrapped to their first element.
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
        text_to_encode = self.tokenization_prefix + text
        return super().encode(text_to_encode)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string.

        Delegates directly to ``self.tokenizer.decode``.
        """
        return self.tokenizer.decode(ids)