Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

Commit: remove spacy; add 3 MR models
@@ -2,13 +2,10 @@
 Tokenizers
 ===========
 
-.. automodule:: textattack.models.tokenizers.tokenizer
-   :members:
-
 .. automodule:: textattack.models.tokenizers.auto_tokenizer
    :members:
 
-.. automodule:: textattack.models.tokenizers.spacy_tokenizer
+.. automodule:: textattack.models.tokenizers.glove_tokenizer
    :members:
 
 .. automodule:: textattack.models.tokenizers.t5_tokenizer
examples/augmentation/example.csv (new file, 2 lines)
@@ -0,0 +1,2 @@
+"text",label
+"it's a mystery how the movie could be released in this condition .", 0
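The new CSV gives the augmentation example a two-column input: a quoted `text` field and an integer `label`. As a hedged illustration (not part of this commit), it could be fed through one of TextAttack's augmenters; the choice of `WordNetAugmenter` and its default settings are assumptions here:

# Illustrative sketch only: read the new example CSV and augment its text
# column. WordNetAugmenter is an assumed choice, not something this commit adds.
import pandas as pd
from textattack.augmentation import WordNetAugmenter

df = pd.read_csv("examples/augmentation/example.csv")
augmenter = WordNetAugmenter()

for text, label in zip(df["text"], df["label"]):
    for augmented in augmenter.augment(text):
        print(label, augmented)
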
@@ -10,7 +10,6 @@ pandas
 scikit-learn
 scipy==1.4.1
 sentence_transformers
-spacy
 torch
 transformers>=2.5.1
 tensorflow>=2
@@ -53,6 +53,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
         "textattack/bert-base-uncased-WNLI",
         ("glue", "wnli", "validation"),
     ),
+    "bert-base-uncased-mr": (
+        "textattack/bert-base-uncased-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
     #
     # distilbert-base-cased
     #
@@ -139,6 +143,17 @@ HUGGINGFACE_DATASET_BY_MODEL = {
         "textattack/roberta-base-WNLI",
         ("glue", "wnli", "validation"),
     ),
+    "roberta-base-mr": (
+        "textattack/roberta-base-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
+    #
+    # albert-base-v2 (ALBERT is cased by default)
+    #
+    "albert-base-v2-mr": (
+        "textattack/albert-base-v2-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
 }
 
 
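Each new entry maps a short model key to a Hub model name plus dataset arguments. A minimal sketch of how such an entry might be resolved; the `resolve` helper is hypothetical and only illustrates the shape of the data, not how TextAttack consumes it:

# Hypothetical helper (not from the commit) showing how a registry entry of the
# form {key: (hub_model_name, (dataset, subset, split))} can be resolved.
import transformers

HUGGINGFACE_DATASET_BY_MODEL = {
    "bert-base-uncased-mr": (
        "textattack/bert-base-uncased-rotten_tomatoes",
        ("rotten_tomatoes", None, "test"),
    ),
}

def resolve(key):
    hub_name, (dataset, subset, split) = HUGGINGFACE_DATASET_BY_MODEL[key]
    tokenizer = transformers.AutoTokenizer.from_pretrained(hub_name)
    model = transformers.AutoModelForSequenceClassification.from_pretrained(hub_name)
    return model, tokenizer, (dataset, subset, split)

model, tokenizer, dataset_args = resolve("bert-base-uncased-mr")
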
@@ -170,10 +185,6 @@ TEXTATTACK_DATASET_BY_MODEL = {
     #
     # Text classification models
     #
-    "bert-base-uncased-mr": (
-        ("models/classification/bert/mr-uncased", 2),
-        ("rotten_tomatoes", None, "train"),
-    ),
     "bert-base-cased-imdb": (
         ("models/classification/bert/imdb-cased", 2),
         ("imdb", None, "test"),
@@ -24,8 +24,11 @@ def make_directories(output_dir):
 
 
 def batch_encode(tokenizer, text_list):
-    # TODO configure batch encoding to work with fast tokenizer
-    return [tokenizer.encode(text_input) for text_input in text_list]
+    if hasattr(tokenizer, 'batch_encode'):
+        print('batch_encode')
+        return tokenizer.batch_encode(text_list)
+    else:
+        return [tokenizer.encode(text_input) for text_input in text_list]
 
 
 def train_model(args):
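The rewritten helper duck-types on the tokenizer: anything exposing `batch_encode` (such as the new `GloveTokenizer` or TextAttack's `AutoTokenizer`) encodes the whole list at once, and everything else falls back to per-text `encode`. A self-contained sketch of the same dispatch, with dummy tokenizers standing in for the real classes:

# Self-contained sketch of the duck-typed dispatch in batch_encode().
# The Dummy* classes are illustrative stand-ins, not TextAttack classes.
def batch_encode(tokenizer, text_list):
    if hasattr(tokenizer, "batch_encode"):
        return tokenizer.batch_encode(text_list)
    return [tokenizer.encode(text) for text in text_list]

class DummyBatchTokenizer:
    def batch_encode(self, texts):
        return [text.split() for text in texts]

class DummyPlainTokenizer:
    def encode(self, text):
        return text.split()

texts = ["a fine movie", "a mystery"]
assert batch_encode(DummyBatchTokenizer(), texts) == batch_encode(DummyPlainTokenizer(), texts)
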
@@ -264,10 +267,10 @@ def train_model(args):
         return loss
 
     for epoch in tqdm.trange(
-        int(args.num_train_epochs), desc="Epoch", position=0, leave=True
+        int(args.num_train_epochs), desc="Epoch", position=0, leave=False
     ):
         prog_bar = tqdm.tqdm(
-            train_dataloader, desc="Iteration", position=0, leave=False
+            train_dataloader, desc="Iteration", position=1, leave=False
         )
         for step, batch in enumerate(prog_bar):
             input_ids, labels = batch
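The `position`/`leave` changes pin the epoch bar to the first terminal row and the iteration bar to the row below it, and discard both when they finish, so nested bars no longer leave stale lines behind. A standalone sketch of the same nesting:

# Minimal sketch of nested tqdm progress bars with fixed positions,
# mirroring the position/leave settings in the hunk above.
import time
import tqdm

for epoch in tqdm.trange(2, desc="Epoch", position=0, leave=False):
    for step in tqdm.tqdm(range(5), desc="Iteration", position=1, leave=False):
        time.sleep(0.01)  # stand-in for a training step
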
@@ -70,7 +70,7 @@ def dataset_from_args(args):
                     args.dataset_dev_split = "validation"
                 except KeyError:
                     raise KeyError(
-                        f"Could not find `dev` or `test` split in dataset {args.dataset}."
+                        f"Could not find `dev`, `eval`, or `validation` split in dataset {args.dataset}."
                     )
     eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
 
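Only the error message changes here, so it now names the split fallbacks the surrounding code actually tries. A simplified sketch of that fallback order; the real helper works directly on the loaded dataset object rather than on a set of split names:

# Simplified sketch of the dev-split fallback order the error message describes.
def pick_eval_split(available_splits, dataset_name):
    for candidate in ("dev", "eval", "validation"):
        if candidate in available_splits:
            return candidate
    raise KeyError(
        f"Could not find `dev`, `eval`, or `validation` split in dataset {dataset_name}."
    )

assert pick_eval_split({"train", "validation"}, "rotten_tomatoes") == "validation"
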
@@ -122,20 +122,22 @@ def write_readme(args, best_eval_score, best_eval_score_epoch):
     epoch_info = f"{best_eval_score_epoch} epoch" + (
         "s" if best_eval_score_epoch > 1 else ""
     )
-    readme_text = f""" 
-    ## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
-    
-    This `{args.model}` model was fine-tuned for sequence classificationusing TextAttack 
-    and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned 
-    for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning 
-    rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}. 
-    Since this was a {task_name} task, the model was trained with a {loss_func} loss function. 
-    The best score the model achieved on this task was {best_eval_score}, as measured by the 
-    eval set {metric_name}, found after {epoch_info}.
-    
-    For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
-    
-    """
+    readme_text = \
+f""" 
+## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
+
+This `{args.model}` model was fine-tuned for sequence classificationusing TextAttack 
+and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned 
+for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning 
+rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}. 
+Since this was a {task_name} task, the model was trained with a {loss_func} loss function. 
+The best score the model achieved on this task was {best_eval_score}, as measured by the 
+eval set {metric_name}, found after {epoch_info}.
+
+For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
+
+"""
+
 
     with open(readme_save_path, "w", encoding="utf-8") as f:
         f.write(readme_text.strip() + "\n")
     logger.info(f"Wrote README to {readme_save_path}.")
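De-indenting the template matters because the README is Markdown: lines indented by four or more spaces render as code blocks, so the previously indented f-string would not display as normal prose. A quick demonstration of the difference (`textwrap.dedent` would be an alternative to moving the literal to column zero):

# Quick sketch: an indented triple-quoted template keeps the leading spaces on
# every line after the first, and Markdown renders such lines as code blocks.
model = "bert-base-uncased"

indented = f"""
    ## {model} fine-tuned with TextAttack
    Trained for 5 epochs.
    """
flush_left = f"""
## {model} fine-tuned with TextAttack
Trained for 5 epochs.
"""

print(indented.strip().splitlines()[1])    # '    Trained for 5 epochs.'
print(flush_left.strip().splitlines()[1])  # 'Trained for 5 epochs.'
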
@@ -41,8 +41,9 @@ class LSTMForClassification(nn.Module):
         )
         d_out = hidden_size
         self.out = nn.Linear(d_out, num_labels)
-        self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
-            self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
+        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
+            word_id_map=self.word2id, unk_token_id=self.emb_layer.oovid, 
+            pad_token_id=self.emb_layer.padid, max_length=max_seq_length
         )
 
         if model_path is not None:
@@ -32,8 +32,9 @@ class WordCNNForClassification(nn.Module):
         )
         d_out = 3 * hidden_size
         self.out = nn.Linear(d_out, num_labels)
-        self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
-            self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
+        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
+            word_id_map=self.word2id, unk_token_id=self.emb_layer.oovid, 
+            pad_token_id=self.emb_layer.padid, max_length=max_seq_length
         )
 
         if model_path is not None:
@@ -1,6 +1,4 @@
-from .tokenizer import Tokenizer
-
 from .auto_tokenizer import AutoTokenizer
 from .bert_tokenizer import BERTTokenizer
-from .spacy_tokenizer import SpacyTokenizer
+from .glove_tokenizer import GloveTokenizer
 from .t5_tokenizer import T5Tokenizer
@@ -1,15 +1,15 @@
 import torch
 import transformers
 
-from textattack.models.tokenizers import Tokenizer
 from textattack.shared import AttackedText
 
 
-class AutoTokenizer(Tokenizer):
+class AutoTokenizer:
     """ 
     A generic class that convert text to tokens and tokens to IDs. Supports
     any type of tokenization, be it word, wordpiece, or character-based.
-    Based on the ``AutoTokenizer`` from the ``transformers`` library.
+    Based on the ``AutoTokenizer`` from the ``transformers`` library, but 
+    standardizes the functionality for TextAttack.
     
     Args: 
         name: the identifying name of the tokenizer (see AutoTokenizer,
@@ -51,13 +51,20 @@ class AutoTokenizer(Tokenizer):
     def batch_encode(self, input_text_list):
         """ The batch equivalent of ``encode``."""
         if hasattr(self.tokenizer, "batch_encode_plus"):
-            print("utilizing batch encode")
-            return self.tokenizer.batch_encode_plus(
+            encodings = self.tokenizer.batch_encode_plus(
                 input_text_list,
                 truncation=True,
                 max_length=self.max_length,
-                add_special_tokens=True,
+                add_special_tokens=False,
                 pad_to_max_length=True,
             )
+            # Encodings is a `transformers.utils.BatchEncode` object, which
+            # is basically a big dictionary that contains a key for all input
+            # IDs, a key for all attention masks, etc.
+            dict_of_lists = {k: list(v) for k, v in encodings.data.items()}
+            list_of_dicts = [{key:value[index] for key,value in dict_of_lists.items()}
+                for index in range(max(map(len, dict_of_lists.values())))]
+            # We need to turn this dict of lists into a dict of lists.
+            return list_of_dicts
         else:
             return [self.encode(input_text) for input_text in input_text_list]
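The new comments describe reshaping the batch encoding's dict of per-field lists into a list of per-example dicts. A toy, self-contained version of that transposition:

# Toy version of the dict-of-lists -> list-of-dicts transposition used above.
dict_of_lists = {
    "input_ids": [[101, 7592, 102], [101, 2088, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}

list_of_dicts = [
    {key: values[index] for key, values in dict_of_lists.items()}
    for index in range(max(map(len, dict_of_lists.values())))
]

assert list_of_dicts[0] == {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]}
assert len(list_of_dicts) == 2
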
textattack/models/tokenizers/glove_tokenizer.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+import json
+import numpy as np
+import os
+import textattack
+import tempfile
+import tokenizers as hf_tokenizers
+
+class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
+    """ WordLevelTokenizer. 
+    
+    Represents a simple word level tokenization using the internals of BERT's
+    tokenizer.
+    
+    Based off the `tokenizers` BertWordPieceTokenizer (https://github.com/huggingface/tokenizers/blob/704cf3fdd2f607ead58a561b892b510b49c301db/bindings/python/tokenizers/implementations/bert_wordpiece.py).
+    """
+
+    def __init__(
+        self,
+        word_id_map = {},
+        pad_token_id = None,
+        unk_token_id = None,
+        unk_token = "[UNK]",
+        sep_token = "[SEP]",
+        cls_token = "[CLS]",
+        pad_token = "[PAD]",
+        lowercase: bool = False,
+        unicode_normalizer = None,
+    ):
+        if pad_token_id:
+            word_id_map[pad_token] = pad_token_id
+        if unk_token_id:
+            word_id_map[unk_token] = unk_token_id
+        max_id = max(word_id_map.values())
+        for idx, token in enumerate((unk_token, sep_token, cls_token, pad_token)):
+            if token not in word_id_map:
+                word_id_map[token] = max_id + idx
+        # HuggingFace tokenizer expects a path to a `*.json` file to read the
+        # vocab from. I think this is kind of a silly constraint, but for now
+        # we write the vocab to a temporary file before initialization.
+        word_list_file = tempfile.NamedTemporaryFile()
+        word_list_file.write(json.dumps(word_id_map).encode())
+        
+        word_level = hf_tokenizers.models.WordLevel(word_list_file.name, unk_token=str(unk_token))
+        tokenizer = hf_tokenizers.Tokenizer(word_level)
+
+        # Let the tokenizer know about special tokens if they are part of the vocab
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
+        if tokenizer.token_to_id(str(sep_token)) is not None:
+            tokenizer.add_special_tokens([str(sep_token)])
+        if tokenizer.token_to_id(str(cls_token)) is not None:
+            tokenizer.add_special_tokens([str(cls_token)])
+        if tokenizer.token_to_id(str(pad_token)) is not None:
+            tokenizer.add_special_tokens([str(pad_token)])
+
+        # Check for Unicode normalization first (before everything else)
+        normalizers = []
+
+        if unicode_normalizer:
+            normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
+
+        if lowercase:
+            normalizers += [hf_tokenizers.normalizers.Lowercase()]
+
+        # Create the normalizer structure
+        if len(normalizers) > 0:
+            if len(normalizers) > 1:
+                tokenizer.normalizer = hf_tokenizers.normalizers.Sequence(normalizers)
+            else:
+                tokenizer.normalizer = normalizers[0]
+
+        tokenizer.pre_tokenizer = hf_tokenizers.pre_tokenizers.WhitespaceSplit()
+
+        sep_token_id = tokenizer.token_to_id(str(sep_token))
+        if sep_token_id is None:
+            raise TypeError("sep_token not found in the vocabulary")
+        cls_token_id = tokenizer.token_to_id(str(cls_token))
+        if cls_token_id is None:
+            raise TypeError("cls_token not found in the vocabulary")
+
+        tokenizer.post_processor = hf_tokenizers.processors.BertProcessing(
+            (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
+        )
+
+        parameters = {
+            "model": "WordLevel",
+            "unk_token": unk_token,
+            "sep_token": sep_token,
+            "cls_token": cls_token,
+            "pad_token": pad_token,
+            "lowercase": lowercase,
+            "unicode_normalizer": unicode_normalizer,
+        }
+
+        super().__init__(tokenizer, parameters)
+    
+
+class GloveTokenizer(WordLevelTokenizer):
+    """ A word-level tokenizer with GloVe 200-dimensional vectors. 
+    
+        Lowercased, since GloVe vectors are lowercased.
+    """
+    def __init__(self, word_id_map={}, pad_token_id=None, unk_token_id=None, max_length=256):
+        super().__init__(word_id_map=word_id_map, unk_token_id=unk_token_id,
+        pad_token_id=pad_token_id, lowercase=True)
+        print('pad_token_id:', pad_token_id)
+        # Set defaults.
+        self.enable_padding(max_length=max_length, pad_id=pad_token_id)
+        self.enable_truncation(max_length=max_length)
+    
+    def convert_id_to_word(word):
+        """ Returns the `id` associated with `word`. If not found, returns
+            None. 
+        """
+        return gt2.token_to_id(word)
+    
+    def encode(self, text):
+        return super().encode(text, add_special_tokens=False).ids
+    
+    def batch_encode(self, input_text_list):
+        """ The batch equivalent of ``encode``."""
+        encodings = self.encode_batch(
+            list(input_text_list),
+            add_special_tokens=False,
+        )
+        return [x.ids for x in encodings]
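Behaviorally, the new tokenizer lowercases and whitespace-splits text, maps each word through `word_id_map` (falling back to the unknown-token id), and pads or truncates to `max_length`, delegating the heavy lifting to the `tokenizers` library. A dependency-free sketch of that observable behavior, not of the implementation above:

# Dependency-free sketch of the GloveTokenizer's observable behavior:
# lowercase, whitespace-split, map words to ids (unknown id on a miss),
# then pad/truncate to a fixed length. Not the hf `tokenizers` implementation.
class ToyWordLevelTokenizer:
    def __init__(self, word_id_map, unk_token_id, pad_token_id, max_length=256):
        self.word_id_map = word_id_map
        self.unk_token_id = unk_token_id
        self.pad_token_id = pad_token_id
        self.max_length = max_length

    def encode(self, text):
        ids = [
            self.word_id_map.get(word, self.unk_token_id)
            for word in text.lower().split()
        ]
        ids = ids[: self.max_length]
        return ids + [self.pad_token_id] * (self.max_length - len(ids))

    def batch_encode(self, text_list):
        return [self.encode(text) for text in text_list]

tok = ToyWordLevelTokenizer({"a": 0, "mystery": 1}, unk_token_id=99, pad_token_id=100, max_length=4)
assert tok.encode("A mystery movie") == [0, 1, 99, 100]
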
textattack/models/tokenizers/spacy_tokenizer.py (deleted, 55 lines)
@@ -1,55 +0,0 @@
-import spacy
-
-from textattack.models.tokenizers import Tokenizer
-
-
-class SpacyTokenizer(Tokenizer):
-    """ A basic implementation of the spaCy English tokenizer. 
-    
-        Params:
-            word2id (dict<string, int>): A dictionary that matches words to IDs
-            oov_id (int): An out-of-variable ID
-    """
-
-    def __init__(self, word2id, oov_id, pad_id, max_length=128):
-        self.tokenizer = spacy.load("en").tokenizer
-        self.word2id = word2id
-        self.id2word = {v: k for k, v in word2id.items()}
-        self.oov_id = oov_id
-        self.pad_id = pad_id
-        self.max_length = max_length
-
-    def convert_text_to_tokens(self, text):
-        if isinstance(text, tuple):
-            if len(text) > 1:
-                raise TypeError(
-                    "Cannot train LSTM/CNN models with multi-sequence inputs."
-                )
-            text = text[0]
-        if not isinstance(text, str):
-            raise TypeError(
-                f"SpacyTokenizer can only tokenize `str`, got type {type(text)}"
-            )
-        spacy_tokens = [t.text for t in self.tokenizer(text)]
-        return spacy_tokens[: self.max_length]
-
-    def convert_tokens_to_ids(self, tokens):
-        ids = []
-        for raw_token in tokens:
-            token = raw_token.lower()
-            if token in self.word2id:
-                ids.append(self.word2id[token])
-            else:
-                ids.append(self.oov_id)
-        pad_ids_to_add = [self.pad_id] * (self.max_length - len(ids))
-        ids += pad_ids_to_add
-        return ids
-
-    def convert_id_to_word(self, _id):
-        """
-        Takes an integer input and returns the corresponding word from the 
-        vocabulary.
-        
-        Raises: KeyError on OOV.
-        """
-        return self.id2word[_id]
textattack/models/tokenizers/tokenizer.py (deleted, 18 lines)
@@ -1,18 +0,0 @@
-class Tokenizer:
-    """ 
-    A generic class that convert text to tokens and tokens to IDs. Supports
-    any type of tokenization, be it word, wordpiece, or character-based.
-    """
-
-    def convert_text_to_tokens(self, text):
-        raise NotImplementedError()
-
-    def convert_tokens_to_ids(self, ids):
-        raise NotImplementedError()
-
-    def encode(self, text):
-        """ 
-        Converts text directly to IDs. 
-        """
-        tokens = self.convert_text_to_tokens(text)
-        return self.convert_tokens_to_ids(tokens)
@@ -107,10 +107,6 @@ def _post_install():
     logger.info(
         "First time running textattack: downloading remaining required packages."
    )
-    logger.info("Downloading spaCy required packages.")
-    import spacy
-
-    spacy.cli.download("en")
     logger.info("Downloading NLTK required packages.")
     import nltk
 