mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
v0.2.5: tokenizer fixes
@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
 author = "UVA QData Lab"
 
 # The full version, including alpha/beta/rc tags
-release = "0.2.4"
+release = "0.2.5"
 
 # Set master doc to `index.rst`.
 master_doc = "index"
@@ -9,7 +9,6 @@ nlp
 nltk
 numpy
 pandas>=1.0.1
-pyarrow<1.0
 scikit-learn
 scipy==1.4.1
 sentence_transformers>0.2.6
@@ -1,7 +1,7 @@
 import torch
 from transformers.modeling_bert import BertForSequenceClassification
 
-from textattack.models.tokenizers import BERTTokenizer
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.shared import utils
 
 
@@ -22,7 +22,7 @@ class BERTForClassification:
 
         self.model.to(utils.device)
         self.model.eval()
-        self.tokenizer = BERTTokenizer(model_file_path)
+        self.tokenizer = AutoTokenizer(model_file_path)
 
     def __call__(self, input_ids=None, **kwargs):
         # The tokenizer will return ``input_ids`` along with ``token_type_ids``
@@ -1,4 +1,3 @@
 from .auto_tokenizer import AutoTokenizer
-from .bert_tokenizer import BERTTokenizer
 from .glove_tokenizer import GloveTokenizer
 from .t5_tokenizer import T5Tokenizer
@@ -14,11 +14,23 @@ class AutoTokenizer:
     """
 
     def __init__(
-        self, name="bert-base-uncased", max_length=256, use_fast=True,
+        self,
+        tokenizer_path="bert-base-uncased",
+        tokenizer=None,
+        max_length=256,
+        use_fast=True,
     ):
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            name, use_fast=use_fast
-        )
+        if not (tokenizer_path or tokenizer):
+            raise ValueError("Must pass tokenizer path or tokenizer")
+        if tokenizer_path and tokenizer:
+            raise ValueError("Cannot pass both tokenizer path and tokenizer")
+
+        if tokenizer_path:
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                tokenizer_path, use_fast=use_fast
+            )
+        else:
+            self.tokenizer = tokenizer
         self.max_length = max_length
         self.save_pretrained = self.tokenizer.save_pretrained
 
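The constructor now has two mutually exclusive modes: load a tokenizer by name/path, or wrap an existing ``transformers`` tokenizer instance. A minimal usage sketch of both, with names taken from the hunk above; note that ``tokenizer_path`` defaults to ``"bert-base-uncased"``, so the wrapping mode has to clear it explicitly to avoid the "Cannot pass both" error:

    import transformers

    from textattack.models.tokenizers import AutoTokenizer

    # Mode 1: load by name or path; delegates to
    # transformers.AutoTokenizer.from_pretrained() as in the hunk above.
    tok = AutoTokenizer(tokenizer_path="bert-base-uncased", max_length=256)

    # Mode 2: wrap an already-constructed transformers tokenizer.
    # tokenizer_path must be cleared, since its default would otherwise
    # trigger the "Cannot pass both" ValueError.
    hf_tok = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    tok = AutoTokenizer(tokenizer_path=None, tokenizer=hf_tok)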
@@ -1,11 +0,0 @@
-from textattack.models.tokenizers import AutoTokenizer
-
-
-class BERTTokenizer(AutoTokenizer):
-    """A generic class that convert text to tokens and tokens to IDs.
-
-    Intended for fine-tuned BERT models.
-    """
-
-    def __init__(self, name="bert-base-uncased", max_length=256):
-        super().__init__(name, max_length=max_length)
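With ``bert_tokenizer.py`` deleted, former ``BERTTokenizer`` call sites map directly onto ``AutoTokenizer``; a hedged migration sketch using the new keyword from the ``auto_tokenizer.py`` hunk above:

    from textattack.models.tokenizers import AutoTokenizer

    # Replaces the removed BERTTokenizer("bert-base-uncased", max_length=256).
    tokenizer = AutoTokenizer(tokenizer_path="bert-base-uncased", max_length=256)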
@@ -1,4 +1,5 @@
 import torch
+import transformers
 
 import textattack
 
@@ -8,6 +9,13 @@ from .pytorch_model_wrapper import PyTorchModelWrapper
 class HuggingFaceModelWrapper(PyTorchModelWrapper):
     """Loads a HuggingFace ``transformers`` model and tokenizer."""
 
     def __init__(self, model, tokenizer, batch_size=32):
         self.model = model
+        if isinstance(tokenizer, transformers.PreTrainedTokenizer):
+            tokenizer = textattack.models.tokenizers.AutoTokenizer(tokenizer=tokenizer)
         self.tokenizer = tokenizer
         self.batch_size = batch_size
 
     def __call__(self, text_input_list):
         """Passes inputs to HuggingFace models as keyword arguments.
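After this change a raw ``transformers`` tokenizer can be handed straight to the wrapper, which converts it to a TextAttack ``AutoTokenizer`` internally. A minimal sketch, assuming the wrapper is exported from ``textattack.models.wrappers`` and using an illustrative checkpoint name:

    import transformers

    from textattack.models.wrappers import HuggingFaceModelWrapper

    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

    # The isinstance(tokenizer, transformers.PreTrainedTokenizer) branch above
    # wraps this tokenizer in textattack.models.tokenizers.AutoTokenizer.
    wrapper = HuggingFaceModelWrapper(model, tokenizer)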
@@ -26,6 +34,8 @@ class HuggingFaceModelWrapper(PyTorchModelWrapper):
             k: torch.tensor(v).to(textattack.shared.utils.device)
             for k, v in input_dict.items()
         }
+        for k, v in input_dict.items():
+            break
         outputs = self.model(**input_dict)
 
         if isinstance(outputs[0], str):
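For reference, the added ``for k, v in input_dict.items(): break`` pattern only binds ``k`` and ``v`` to the first key/value pair and leaves the dict untouched; a standalone illustration:

    # Illustrative dict shaped like a tokenizer's encoded output.
    input_dict = {"input_ids": [[101, 2023]], "attention_mask": [[1, 1]]}
    for k, v in input_dict.items():
        break
    # Now k == "input_ids" and v == [[101, 2023]]; input_dict is unchanged.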