"""
T5 Tokenizer
^^^^^^^^^^^^^^^^^
"""

from textattack.models.tokenizers import AutoTokenizer


class T5Tokenizer(AutoTokenizer):
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian
    """

    # Task prefix prepended to every input, keyed by the ``mode`` argument.
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        """Select the task prefix and load the ``t5-base`` tokenizer.

        Args:
            mode: One of the modes listed in the class docstring.
            max_length: Forwarded unchanged to ``AutoTokenizer``.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except (KeyError, TypeError):
            # TypeError covers unhashable ``mode`` values, which the original
            # equality-based checks also rejected with ValueError.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None
        super().__init__(tokenizer_path="t5-base", max_length=max_length)

    def _add_prefix(self, text):
        """Validate one input and return it with the task prefix prepended.

        Args:
            text: A ``str``, or a tuple containing exactly one ``str``.

        Raises:
            ValueError: If ``text`` is a tuple whose length is not 1
                (including the empty tuple).
            TypeError: If the (unwrapped) input is not a ``str``.
        """
        if isinstance(text, tuple):
            # `!= 1` rather than `> 1`: an empty tuple must be rejected here
            # instead of failing later with an IndexError on `text[0]`.
            if len(text) != 1:
                raise ValueError(
                    f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
                )
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
        return self.tokenization_prefix + text

    def encode(self, text):
        """Encodes a string into IDs of tokens. This prepares an input to be
        passed into T5.

        Args:
            text: A ``str``, or a 1-tuple containing a ``str``.

        Returns:
            Whatever ``AutoTokenizer.encode`` returns for the prefixed text.

        Raises:
            ValueError: If ``text`` is a tuple whose length is not 1.
            TypeError: If the input is not a ``str``.
        """
        return super().encode(self._add_prefix(text))

    def batch_encode(self, input_text_list):
        """Encode a batch of inputs, prepending the task prefix to each.

        Every item is validated exactly like in ``encode``. All items are
        validated before anything is passed to ``AutoTokenizer.batch_encode``.

        Args:
            input_text_list: Iterable of ``str`` or 1-tuples of ``str``.

        Returns:
            Whatever ``AutoTokenizer.batch_encode`` returns for the prefixed
            texts.

        Raises:
            ValueError: If any item is a tuple whose length is not 1.
            TypeError: If any (unwrapped) item is not a ``str``.
        """
        prefixed = [self._add_prefix(text) for text in input_text_list]
        return super().batch_encode(prefixed)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string."""
        # Delegates directly to the wrapped tokenizer; the task prefix is not
        # stripped here.
        return self.tokenizer.decode(ids)