
v0.2.5: tokenizer fixes

Jack Morris
2020-08-17 16:15:46 -04:00
parent d13971813e
commit 7a2fde1520
7 changed files with 29 additions and 20 deletions

docs/conf.py

@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
 author = "UVA QData Lab"
 # The full version, including alpha/beta/rc tags
-release = "0.2.4"
+release = "0.2.5"
 # Set master doc to `index.rst`.
 master_doc = "index"

requirements.txt

@@ -9,7 +9,6 @@ nlp
 nltk
 numpy
 pandas>=1.0.1
-pyarrow<1.0
 scikit-learn
 scipy==1.4.1
 sentence_transformers>0.2.6

textattack/models/helpers/bert_for_classification.py

@@ -1,7 +1,7 @@
 import torch
 from transformers.modeling_bert import BertForSequenceClassification
-from textattack.models.tokenizers import BERTTokenizer
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.shared import utils
@@ -22,7 +22,7 @@ class BERTForClassification:
         self.model.to(utils.device)
         self.model.eval()
-        self.tokenizer = BERTTokenizer(model_file_path)
+        self.tokenizer = AutoTokenizer(model_file_path)

     def __call__(self, input_ids=None, **kwargs):
         # The tokenizer will return ``input_ids`` along with ``token_type_ids``

textattack/models/tokenizers/__init__.py

@@ -1,4 +1,3 @@
 from .auto_tokenizer import AutoTokenizer
-from .bert_tokenizer import BERTTokenizer
 from .glove_tokenizer import GloveTokenizer
 from .t5_tokenizer import T5Tokenizer

textattack/models/tokenizers/auto_tokenizer.py

@@ -14,11 +14,23 @@ class AutoTokenizer:
     """

     def __init__(
-        self, name="bert-base-uncased", max_length=256, use_fast=True,
+        self,
+        tokenizer_path="bert-base-uncased",
+        tokenizer=None,
+        max_length=256,
+        use_fast=True,
     ):
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            name, use_fast=use_fast
-        )
+        if not (tokenizer_path or tokenizer):
+            raise ValueError("Must pass tokenizer path or tokenizer")
+        if tokenizer_path and tokenizer:
+            raise ValueError("Cannot pass both tokenizer path and tokenizer")
+        if tokenizer_path:
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                tokenizer_path, use_fast=use_fast
+            )
+        else:
+            self.tokenizer = tokenizer
         self.max_length = max_length
         self.save_pretrained = self.tokenizer.save_pretrained
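
For orientation, a minimal sketch of the two construction modes this hunk introduces; nothing below is in the commit itself, and "bert-base-uncased" is simply the default ``tokenizer_path``:

import transformers
from textattack.models.tokenizers import AutoTokenizer

# Mode 1: load by path or name, as before (the parameter was renamed
# from ``name`` to ``tokenizer_path``).
tokenizer = AutoTokenizer(tokenizer_path="bert-base-uncased")

# Mode 2 (new): wrap an existing ``transformers`` tokenizer instance.
hf_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer(tokenizer=hf_tokenizer)

# Passing both arguments, or neither, raises ValueError per the checks above.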

textattack/models/tokenizers/bert_tokenizer.py

@@ -1,11 +0,0 @@
-from textattack.models.tokenizers import AutoTokenizer
-
-
-class BERTTokenizer(AutoTokenizer):
-    """A generic class that convert text to tokens and tokens to IDs.
-
-    Intended for fine-tuned BERT models.
-    """
-
-    def __init__(self, name="bert-base-uncased", max_length=256):
-        super().__init__(name, max_length=max_length)
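
Callers that built a ``BERTTokenizer`` directly can migrate in one line; a sketch, with ``model_file_path`` standing in for whatever path the caller already has:

from textattack.models.tokenizers import AutoTokenizer

# Before (removed in this commit):
#   tokenizer = BERTTokenizer(model_file_path)
# After: ``tokenizer_path`` is the first positional parameter.
tokenizer = AutoTokenizer(model_file_path)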

textattack/models/wrappers/huggingface_model_wrapper.py

@@ -1,4 +1,5 @@
 import torch
+import transformers

 import textattack
@@ -8,6 +9,13 @@ from .pytorch_model_wrapper import PyTorchModelWrapper
 class HuggingFaceModelWrapper(PyTorchModelWrapper):
     """Loads a HuggingFace ``transformers`` model and tokenizer."""

     def __init__(self, model, tokenizer, batch_size=32):
         self.model = model
+        if isinstance(tokenizer, transformers.PreTrainedTokenizer):
+            tokenizer = textattack.models.tokenizers.AutoTokenizer(tokenizer=tokenizer)
         self.tokenizer = tokenizer
         self.batch_size = batch_size

     def __call__(self, text_input_list):
         """Passes inputs to HuggingFace models as keyword arguments.
@@ -26,6 +34,8 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
             k: torch.tensor(v).to(textattack.shared.utils.device)
             for k, v in input_dict.items()
         }
+        for k, v in input_dict.items():
+            break
         outputs = self.model(**input_dict)

         if isinstance(outputs[0], str):
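
Taken together with the ``AutoTokenizer`` change above, the wrapper now accepts a raw ``transformers`` tokenizer directly. A minimal usage sketch, assuming a sequence-classification checkpoint such as "textattack/bert-base-uncased-SST-2":

import transformers
import textattack

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-SST-2"
)
# A plain transformers tokenizer now works here; the isinstance check
# above wraps it in TextAttack's AutoTokenizer automatically.
hf_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-SST-2"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, hf_tokenizer)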