Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

Commit: remove spacy; add 3 MR models
@@ -2,13 +2,10 @@
 Tokenizers
 ===========
 
-.. automodule:: textattack.models.tokenizers.tokenizer
-   :members:
-
 .. automodule:: textattack.models.tokenizers.auto_tokenizer
    :members:
 
-.. automodule:: textattack.models.tokenizers.spacy_tokenizer
+.. automodule:: textattack.models.tokenizers.glove_tokenizer
    :members:
 
 .. automodule:: textattack.models.tokenizers.t5_tokenizer
examples/augmentation/example.csv (new file)
@@ -0,0 +1,2 @@
+"text",label
+"it's a mystery how the movie could be released in this condition .", 0
@@ -10,7 +10,6 @@ pandas
 scikit-learn
 scipy==1.4.1
 sentence_transformers
-spacy
 torch
 transformers>=2.5.1
 tensorflow>=2
@@ -53,6 +53,10 @@ HUGGINGFACE_DATASET_BY_MODEL = {
         "textattack/bert-base-uncased-WNLI",
         ("glue", "wnli", "validation"),
     ),
+    "bert-base-uncased-mr": (
+        "textattack/bert-base-uncased-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
     #
     # distilbert-base-cased
     #
@@ -139,6 +143,17 @@ HUGGINGFACE_DATASET_BY_MODEL = {
         "textattack/roberta-base-WNLI",
         ("glue", "wnli", "validation"),
     ),
+    "roberta-base-mr": (
+        "textattack/roberta-base-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
+    #
+    # albert-base-v2 (ALBERT is cased by default)
+    #
+    "albert-base-v2-mr": (
+        "textattack/albert-base-v2-rotten_tomatoes",
+        ("rotten_tomatoes", None, "test"),
+    ),
 }
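The three new "-mr" entries follow the existing registry shape: a short model key maps to a HuggingFace model path plus a (dataset, subset, split) triple. Below is a minimal sketch, not TextAttack's own loading code, of what such an entry resolves to; it assumes the `transformers` and `datasets` packages (the project itself loaded data through the `nlp` package at the time).

    # Sketch: resolving one of the new "-mr" registry entries by hand.
    import datasets
    import transformers

    model_path, (dataset_name, subset, split) = (
        "textattack/bert-base-uncased-rotten_tomatoes",
        ("rotten_tomatoes", None, "test"),
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_path)

    # `subset` is None for rotten_tomatoes, so only the split matters here.
    dataset = datasets.load_dataset(dataset_name, subset, split=split)
    print(dataset[0]["text"], dataset[0]["label"])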
@@ -170,10 +185,6 @@ TEXTATTACK_DATASET_BY_MODEL = {
     #
     # Text classification models
     #
-    "bert-base-uncased-mr": (
-        ("models/classification/bert/mr-uncased", 2),
-        ("rotten_tomatoes", None, "train"),
-    ),
     "bert-base-cased-imdb": (
         ("models/classification/bert/imdb-cased", 2),
         ("imdb", None, "test"),
@@ -24,8 +24,11 @@ def make_directories(output_dir):
 
 
 def batch_encode(tokenizer, text_list):
-    # TODO configure batch encoding to work with fast tokenizer
-    return [tokenizer.encode(text_input) for text_input in text_list]
+    if hasattr(tokenizer, 'batch_encode'):
+        print('batch_encode')
+        return tokenizer.batch_encode(text_list)
+    else:
+        return [tokenizer.encode(text_input) for text_input in text_list]
 
 
 def train_model(args):
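The rewritten helper duck-types on the tokenizer: anything exposing `batch_encode` (such as the updated `AutoTokenizer` and the new `GloveTokenizer`) gets the whole list in one call, everything else falls back to per-example `encode`. A minimal sketch of that dispatch with a hypothetical stand-in tokenizer:

    # ToyTokenizer is hypothetical; it only exists to exercise the fallback branch.
    class ToyTokenizer:
        def encode(self, text):
            return [len(word) for word in text.split()]


    def batch_encode(tokenizer, text_list):
        if hasattr(tokenizer, "batch_encode"):
            return tokenizer.batch_encode(text_list)
        else:
            return [tokenizer.encode(text_input) for text_input in text_list]


    print(batch_encode(ToyTokenizer(), ["a toy example", "another one"]))
    # -> [[1, 3, 7], [7, 3]]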
@@ -264,10 +267,10 @@
         return loss
 
     for epoch in tqdm.trange(
-        int(args.num_train_epochs), desc="Epoch", position=0, leave=True
+        int(args.num_train_epochs), desc="Epoch", position=0, leave=False
     ):
         prog_bar = tqdm.tqdm(
-            train_dataloader, desc="Iteration", position=0, leave=False
+            train_dataloader, desc="Iteration", position=1, leave=False
         )
         for step, batch in enumerate(prog_bar):
            input_ids, labels = batch
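The progress-bar change keeps the epoch bar on row 0 but stops it from persisting once it finishes (`leave=False`), and moves the per-batch bar to its own row (`position=1`) so the two nested bars no longer overwrite each other. A self-contained illustration of those `tqdm` arguments; the loop sizes and sleep are arbitrary stand-ins:

    import time

    import tqdm

    for epoch in tqdm.trange(3, desc="Epoch", position=0, leave=False):
        for step in tqdm.tqdm(range(50), desc="Iteration", position=1, leave=False):
            time.sleep(0.01)  # stand-in for one training step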
@@ -70,7 +70,7 @@ def dataset_from_args(args):
             args.dataset_dev_split = "validation"
         except KeyError:
             raise KeyError(
-                f"Could not find `dev` or `test` split in dataset {args.dataset}."
+                f"Could not find `dev`, `eval`, or `validation` split in dataset {args.dataset}."
             )
     eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
 
@@ -122,20 +122,22 @@ def write_readme(args, best_eval_score, best_eval_score_epoch):
     epoch_info = f"{best_eval_score_epoch} epoch" + (
         "s" if best_eval_score_epoch > 1 else ""
     )
-    readme_text = f"""
-## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
-
-This `{args.model}` model was fine-tuned for sequence classification using TextAttack
-and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
-for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
-rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
-Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
-The best score the model achieved on this task was {best_eval_score}, as measured by the
-eval set {metric_name}, found after {epoch_info}.
-
-For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
-
-"""
+    readme_text = \
+        f"""
+## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
+
+This `{args.model}` model was fine-tuned for sequence classification using TextAttack
+and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
+for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
+rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
+Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
+The best score the model achieved on this task was {best_eval_score}, as measured by the
+eval set {metric_name}, found after {epoch_info}.
+
+For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
+
+
+"""
     with open(readme_save_path, "w", encoding="utf-8") as f:
         f.write(readme_text.strip() + "\n")
     logger.info(f"Wrote README to {readme_save_path}.")
@@ -41,8 +41,9 @@ class LSTMForClassification(nn.Module):
         )
         d_out = hidden_size
         self.out = nn.Linear(d_out, num_labels)
-        self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
-            self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
+        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
+            word_id_map=self.word2id, unk_token_id=self.emb_layer.oovid,
+            pad_token_id=self.emb_layer.padid, max_length=max_seq_length
         )
 
         if model_path is not None:
@@ -32,8 +32,9 @@ class WordCNNForClassification(nn.Module):
         )
         d_out = 3 * hidden_size
         self.out = nn.Linear(d_out, num_labels)
-        self.tokenizer = textattack.models.tokenizers.SpacyTokenizer(
-            self.word2id, self.emb_layer.oovid, self.emb_layer.padid, max_seq_length
+        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
+            word_id_map=self.word2id, unk_token_id=self.emb_layer.oovid,
+            pad_token_id=self.emb_layer.padid, max_length=max_seq_length
         )
 
         if model_path is not None:
@@ -1,6 +1,4 @@
-from .tokenizer import Tokenizer
-
 from .auto_tokenizer import AutoTokenizer
 from .bert_tokenizer import BERTTokenizer
-from .spacy_tokenizer import SpacyTokenizer
+from .glove_tokenizer import GloveTokenizer
 from .t5_tokenizer import T5Tokenizer
@@ -1,15 +1,15 @@
 import torch
 import transformers
 
-from textattack.models.tokenizers import Tokenizer
 from textattack.shared import AttackedText
 
 
-class AutoTokenizer(Tokenizer):
+class AutoTokenizer:
     """
     A generic class that convert text to tokens and tokens to IDs. Supports
     any type of tokenization, be it word, wordpiece, or character-based.
-    Based on the ``AutoTokenizer`` from the ``transformers`` library.
+    Based on the ``AutoTokenizer`` from the ``transformers`` library, but
+    standardizes the functionality for TextAttack.
 
     Args:
         name: the identifying name of the tokenizer (see AutoTokenizer,
@@ -51,13 +51,20 @@ class AutoTokenizer(Tokenizer):
     def batch_encode(self, input_text_list):
         """ The batch equivalent of ``encode``."""
         if hasattr(self.tokenizer, "batch_encode_plus"):
-            print("utilizing batch encode")
-            return self.tokenizer.batch_encode_plus(
+            encodings = self.tokenizer.batch_encode_plus(
                 input_text_list,
                 truncation=True,
                 max_length=self.max_length,
-                add_special_tokens=True,
+                add_special_tokens=False,
                 pad_to_max_length=True,
             )
+            # Encodings is a `transformers.utils.BatchEncode` object, which
+            # is basically a big dictionary that contains a key for all input
+            # IDs, a key for all attention masks, etc.
+            dict_of_lists = {k: list(v) for k, v in encodings.data.items()}
+            # We need to turn this dict of lists into a list of dicts.
+            list_of_dicts = [{key: value[index] for key, value in dict_of_lists.items()}
+                for index in range(max(map(len, dict_of_lists.values())))]
+            return list_of_dicts
         else:
             return [self.encode(input_text) for input_text in input_text_list]
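`batch_encode_plus` returns a single `BatchEncoding` whose `.data` maps each key ("input_ids", "attention_mask", ...) to one list per input text; the added code transposes that into one dict per input text, matching what a single `encode` call produces. The same transposition on a hand-made dictionary with made-up IDs, no tokenizer involved:

    # Toy stand-in for BatchEncoding.data: two examples, two keys, made-up IDs.
    dict_of_lists = {
        "input_ids": [[101, 7592, 102], [101, 2088, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1]],
    }

    list_of_dicts = [
        {key: value[index] for key, value in dict_of_lists.items()}
        for index in range(max(map(len, dict_of_lists.values())))
    ]

    print(list_of_dicts[0])
    # -> {'input_ids': [101, 7592, 102], 'attention_mask': [1, 1, 1]}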
textattack/models/tokenizers/glove_tokenizer.py (new file)
@@ -0,0 +1,126 @@
+import json
+import os
+import tempfile
+
+import numpy as np
+import tokenizers as hf_tokenizers
+
+import textattack
+
+
+class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
+    """ WordLevelTokenizer.
+
+    Represents a simple word level tokenization using the internals of BERT's
+    tokenizer.
+
+    Based off the `tokenizers` BertWordPieceTokenizer (https://github.com/huggingface/tokenizers/blob/704cf3fdd2f607ead58a561b892b510b49c301db/bindings/python/tokenizers/implementations/bert_wordpiece.py).
+    """
+
+    def __init__(
+        self,
+        word_id_map={},
+        pad_token_id=None,
+        unk_token_id=None,
+        unk_token="[UNK]",
+        sep_token="[SEP]",
+        cls_token="[CLS]",
+        pad_token="[PAD]",
+        lowercase: bool = False,
+        unicode_normalizer=None,
+    ):
+        if pad_token_id:
+            word_id_map[pad_token] = pad_token_id
+        if unk_token_id:
+            word_id_map[unk_token] = unk_token_id
+        max_id = max(word_id_map.values())
+        for idx, token in enumerate((unk_token, sep_token, cls_token, pad_token)):
+            if token not in word_id_map:
+                word_id_map[token] = max_id + idx
+        # The HuggingFace tokenizer expects a path to a `*.json` file to read the
+        # vocab from. I think this is kind of a silly constraint, but for now
+        # we write the vocab to a temporary file before initialization.
+        word_list_file = tempfile.NamedTemporaryFile()
+        word_list_file.write(json.dumps(word_id_map).encode())
+
+        word_level = hf_tokenizers.models.WordLevel(word_list_file.name, unk_token=str(unk_token))
+        tokenizer = hf_tokenizers.Tokenizer(word_level)
+
+        # Let the tokenizer know about special tokens if they are part of the vocab
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
+        if tokenizer.token_to_id(str(sep_token)) is not None:
+            tokenizer.add_special_tokens([str(sep_token)])
+        if tokenizer.token_to_id(str(cls_token)) is not None:
+            tokenizer.add_special_tokens([str(cls_token)])
+        if tokenizer.token_to_id(str(pad_token)) is not None:
+            tokenizer.add_special_tokens([str(pad_token)])
+
+        # Check for Unicode normalization first (before everything else)
+        normalizers = []
+
+        if unicode_normalizer:
+            normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
+
+        if lowercase:
+            normalizers += [hf_tokenizers.normalizers.Lowercase()]
+
+        # Create the normalizer structure
+        if len(normalizers) > 0:
+            if len(normalizers) > 1:
+                tokenizer.normalizer = hf_tokenizers.normalizers.Sequence(normalizers)
+            else:
+                tokenizer.normalizer = normalizers[0]
+
+        tokenizer.pre_tokenizer = hf_tokenizers.pre_tokenizers.WhitespaceSplit()
+
+        sep_token_id = tokenizer.token_to_id(str(sep_token))
+        if sep_token_id is None:
+            raise TypeError("sep_token not found in the vocabulary")
+        cls_token_id = tokenizer.token_to_id(str(cls_token))
+        if cls_token_id is None:
+            raise TypeError("cls_token not found in the vocabulary")
+
+        tokenizer.post_processor = hf_tokenizers.processors.BertProcessing(
+            (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
+        )
+
+        parameters = {
+            "model": "WordLevel",
+            "unk_token": unk_token,
+            "sep_token": sep_token,
+            "cls_token": cls_token,
+            "pad_token": pad_token,
+            "lowercase": lowercase,
+            "unicode_normalizer": unicode_normalizer,
+        }
+
+        super().__init__(tokenizer, parameters)
+
+
+class GloveTokenizer(WordLevelTokenizer):
+    """ A word-level tokenizer with GloVe 200-dimensional vectors.
+
+    Lowercased, since GloVe vectors are lowercased.
+    """
+
+    def __init__(self, word_id_map={}, pad_token_id=None, unk_token_id=None, max_length=256):
+        super().__init__(word_id_map=word_id_map, unk_token_id=unk_token_id,
+            pad_token_id=pad_token_id, lowercase=True)
+        print('pad_token_id:', pad_token_id)
+        # Set defaults.
+        self.enable_padding(max_length=max_length, pad_id=pad_token_id)
+        self.enable_truncation(max_length=max_length)
+
+    def convert_id_to_word(self, word):
+        """ Returns the `id` associated with `word`. If not found, returns
+            None.
+        """
+        return self.token_to_id(word)
+
+    def encode(self, text):
+        return super().encode(text, add_special_tokens=False).ids
+
+    def batch_encode(self, input_text_list):
+        """ The batch equivalent of ``encode``."""
+        encodings = self.encode_batch(
+            list(input_text_list),
+            add_special_tokens=False,
+        )
+        return [x.ids for x in encodings]
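Taken together, the new module gives the word-level models a `tokenizers`-backed tokenizer that lowercases, maps unknown words to `unk_token_id`, and pads and truncates to `max_length`. A rough usage sketch with a toy vocabulary; the words and IDs are made up (in the LSTM/CNN models above they come from the GloVe embedding layer), and it assumes the pinned `tokenizers` package behaves as this module expects:

    from textattack.models.tokenizers import GloveTokenizer

    # Toy vocabulary with made-up IDs; real use passes the GloVe word2id map.
    word_id_map = {"the": 2, "movie": 3, "was": 4, "great": 5}

    tokenizer = GloveTokenizer(
        word_id_map=word_id_map,
        unk_token_id=6,
        pad_token_id=1,
        max_length=8,
    )

    print(tokenizer.encode("the movie was great"))
    # expected: [2, 3, 4, 5] padded with pad_token_id up to max_length

    print(tokenizer.batch_encode(["the movie", "was awful"]))
    # "awful" is not in the vocabulary, so it maps to unk_token_id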
@@ -1,55 +0,0 @@
-import spacy
-
-from textattack.models.tokenizers import Tokenizer
-
-
-class SpacyTokenizer(Tokenizer):
-    """ A basic implementation of the spaCy English tokenizer.
-
-    Params:
-        word2id (dict<string, int>): A dictionary that matches words to IDs
-        oov_id (int): An out-of-vocabulary ID
-    """
-
-    def __init__(self, word2id, oov_id, pad_id, max_length=128):
-        self.tokenizer = spacy.load("en").tokenizer
-        self.word2id = word2id
-        self.id2word = {v: k for k, v in word2id.items()}
-        self.oov_id = oov_id
-        self.pad_id = pad_id
-        self.max_length = max_length
-
-    def convert_text_to_tokens(self, text):
-        if isinstance(text, tuple):
-            if len(text) > 1:
-                raise TypeError(
-                    "Cannot train LSTM/CNN models with multi-sequence inputs."
-                )
-            text = text[0]
-        if not isinstance(text, str):
-            raise TypeError(
-                f"SpacyTokenizer can only tokenize `str`, got type {type(text)}"
-            )
-        spacy_tokens = [t.text for t in self.tokenizer(text)]
-        return spacy_tokens[: self.max_length]
-
-    def convert_tokens_to_ids(self, tokens):
-        ids = []
-        for raw_token in tokens:
-            token = raw_token.lower()
-            if token in self.word2id:
-                ids.append(self.word2id[token])
-            else:
-                ids.append(self.oov_id)
-        pad_ids_to_add = [self.pad_id] * (self.max_length - len(ids))
-        ids += pad_ids_to_add
-        return ids
-
-    def convert_id_to_word(self, _id):
-        """
-        Takes an integer input and returns the corresponding word from the
-        vocabulary.
-
-        Raises: KeyError on OOV.
-        """
-        return self.id2word[_id]
@@ -1,18 +0,0 @@
-class Tokenizer:
-    """
-    A generic class that convert text to tokens and tokens to IDs. Supports
-    any type of tokenization, be it word, wordpiece, or character-based.
-    """
-
-    def convert_text_to_tokens(self, text):
-        raise NotImplementedError()
-
-    def convert_tokens_to_ids(self, ids):
-        raise NotImplementedError()
-
-    def encode(self, text):
-        """
-        Converts text directly to IDs.
-        """
-        tokens = self.convert_text_to_tokens(text)
-        return self.convert_tokens_to_ids(tokens)
@@ -107,10 +107,6 @@ def _post_install():
     logger.info(
         "First time running textattack: downloading remaining required packages."
     )
-    logger.info("Downloading spaCy required packages.")
-    import spacy
-
-    spacy.cli.download("en")
     logger.info("Downloading NLTK required packages.")
     import nltk
 