mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

relax transformers package requirement, remove BERTForClassification

Jin Yong Yoo
2020-12-25 06:39:22 -05:00
parent 3a27cb0d36
commit d6cdae4091
4 changed files with 4 additions and 49 deletions

docs/apidoc/textattack.models.helpers.rst

@@ -7,12 +7,6 @@ textattack.models.helpers package
    :show-inheritance:
 
-.. automodule:: textattack.models.helpers.bert_for_classification
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
 .. automodule:: textattack.models.helpers.glove_embedding_layer
    :members:
    :undoc-members:

requirements.txt

@@ -11,9 +11,8 @@ numpy<1.19.0 #TF 2.0 requires this
 pandas>=1.0.1
 scipy==1.4.1
 torch
-transformers==3.3.0
+transformers>=3.3.0
 terminaltables
-tokenizers==0.8.1-rc2
 tqdm>=4.27,<4.50.0
 word2number
 num2words
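
With the exact pin dropped (and the stale tokenizers==0.8.1-rc2 pin gone with it), any transformers release at or above 3.3.0 now satisfies the requirement. A minimal runtime sanity check, not part of the commit:

    # Not from the commit: verify the installed transformers release
    # satisfies the relaxed >=3.3.0 requirement.
    from packaging import version

    import transformers

    assert version.parse(transformers.__version__) >= version.parse("3.3.0")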

textattack/constraints/grammaticality/cola.py

@@ -5,9 +5,10 @@ CoLA for Grammaticality
 """
 
 import lru
 import nltk
 
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 from textattack.constraints import Constraint
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.models.wrappers import HuggingFaceModelWrapper
@@ -45,7 +46,7 @@ class COLA(Constraint):
         self.model_name = model_name
         self._reference_score_cache = lru.LRU(2 ** 10)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer(model_name)
         self.model = HuggingFaceModelWrapper(model, tokenizer)
 
     def clear_cache(self):
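
Net effect: COLA now builds its tokenizer through TextAttack's own AutoTokenizer wrapper, which takes the model name directly, instead of transformers' AutoTokenizer.from_pretrained. A hypothetical usage sketch; the max_diff argument and the CoLA checkpoint name are assumptions not shown in this diff:

    # Hypothetical usage sketch: constrain attacks so the CoLA
    # acceptability score of a perturbed text stays close to the original.
    from textattack.constraints.grammaticality import COLA

    constraint = COLA(max_diff=0.1, model_name="textattack/bert-base-uncased-CoLA")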

textattack/models/helpers/bert_for_classification.py

@@ -1,39 +0,0 @@
-"""
-BERT Classification
-^^^^^^^^^^^^^^^^^^^^^
-
-"""
-
-import torch
-from transformers.modeling_bert import BertForSequenceClassification
-
-from textattack.models.tokenizers import AutoTokenizer
-from textattack.shared import utils
-
-
-class BERTForClassification:
-    """BERT fine-tuned for textual classification.
-
-    Args:
-        model_path(:obj:`string`): Path to the pre-trained model.
-        num_labels(:obj:`int`, optional): Number of class labels for
-            prediction, if different than 2.
-    """
-
-    def __init__(self, model_path, num_labels=2):
-        model_file_path = utils.download_if_needed(model_path)
-        self.model = BertForSequenceClassification.from_pretrained(
-            model_file_path, num_labels=num_labels
-        )
-        self.model.to(utils.device)
-        self.model.eval()
-        self.tokenizer = AutoTokenizer(model_file_path)
-
-    def __call__(self, input_ids=None, **kwargs):
-        # The tokenizer will return ``input_ids`` along with ``token_type_ids``
-        # and an ``attention_mask``. Our pre-trained models only need the input
-        # IDs.
-        pred = self.model(input_ids=input_ids)[0]
-        return torch.nn.functional.softmax(pred, dim=-1)
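
Code that depended on the deleted helper can build the same model through the generic HuggingFace wrapper, mirroring what cola.py now does. A migration sketch, assuming a fine-tuned checkpoint such as textattack/bert-base-uncased-CoLA:

    # Migration sketch (the checkpoint name is an assumption): replace the
    # deleted BERTForClassification helper with the generic wrapper.
    from transformers import AutoModelForSequenceClassification

    from textattack.models.tokenizers import AutoTokenizer
    from textattack.models.wrappers import HuggingFaceModelWrapper

    model = AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-CoLA"
    )
    tokenizer = AutoTokenizer("textattack/bert-base-uncased-CoLA")
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

One behavioral difference visible in the deleted code: the old helper applied a softmax to the logits in __call__, whereas HuggingFaceModelWrapper returns the model's raw logits.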