mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

relax transformers package requirement, remove BERTForSequenceClassification

Jin Yong Yoo
2020-12-25 06:39:22 -05:00
parent 3a27cb0d36
commit d6cdae4091
4 changed files with 4 additions and 49 deletions

View File

@@ -7,12 +7,6 @@ textattack.models.helpers package
    :show-inheritance:
-.. automodule:: textattack.models.helpers.bert_for_classification
-   :members:
-   :undoc-members:
-   :show-inheritance:
 .. automodule:: textattack.models.helpers.glove_embedding_layer
    :members:
    :undoc-members:

View File

@@ -11,9 +11,8 @@ numpy<1.19.0 #TF 2.0 requires this
 pandas>=1.0.1
 scipy==1.4.1
 torch
-transformers==3.3.0
+transformers>=3.3.0
 terminaltables
-tokenizers==0.8.1-rc2
 tqdm>=4.27,<4.50.0
 word2number
 num2words
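
The functional change here is the version specifier on transformers: the exact pin ==3.3.0 becomes the lower bound >=3.3.0, and the companion tokenizers pin is dropped so transformers can resolve its own tokenizers dependency. A minimal sketch of what the two specifiers admit, using the packaging library (an assumption; it ships with modern pip and setuptools):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# "==3.3.0" admits exactly one release; ">=3.3.0" admits 3.3.0 and anything newer.
pinned = SpecifierSet("==3.3.0")
relaxed = SpecifierSet(">=3.3.0")

print(Version("4.0.0") in pinned)   # False
print(Version("4.0.0") in relaxed)  # True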

View File

@@ -5,9 +5,10 @@ CoLA for Grammaticality
"""
import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForSequenceClassification
from textattack.constraints import Constraint
from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper
@@ -45,7 +46,7 @@ class COLA(Constraint):
         self.model_name = model_name
         self._reference_score_cache = lru.LRU(2 ** 10)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer(model_name)
         self.model = HuggingFaceModelWrapper(model, tokenizer)

     def clear_cache(self):
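
In isolation, the pattern cola.py now uses: the classifier still comes from Hugging Face's AutoModelForSequenceClassification, but the tokenizer is TextAttack's own AutoTokenizer wrapper, constructed directly from the model name rather than via from_pretrained. A minimal sketch; the checkpoint name is illustrative:

from transformers import AutoModelForSequenceClassification

from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper

model_name = "textattack/bert-base-uncased-CoLA"  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer(model_name)  # TextAttack wrapper, not transformers.AutoTokenizer
wrapper = HuggingFaceModelWrapper(model, tokenizer)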

View File

@@ -1,39 +0,0 @@
"""
BERT Classification
^^^^^^^^^^^^^^^^^^^^^
"""
import torch
from transformers.modeling_bert import BertForSequenceClassification
from textattack.models.tokenizers import AutoTokenizer
from textattack.shared import utils
class BERTForClassification:
"""BERT fine-tuned for textual classification.
Args:
model_path(:obj:`string`): Path to the pre-trained model.
num_labels(:obj:`int`, optional): Number of class labels for
prediction, if different than 2.
"""
def __init__(self, model_path, num_labels=2):
model_file_path = utils.download_if_needed(model_path)
self.model = BertForSequenceClassification.from_pretrained(
model_file_path, num_labels=num_labels
)
self.model.to(utils.device)
self.model.eval()
self.tokenizer = AutoTokenizer(model_file_path)
def __call__(self, input_ids=None, **kwargs):
# The tokenizer will return ``input_ids`` along with ``token_type_ids``
# and an ``attention_mask``. Our pre-trained models only need the input
# IDs.
pred = self.model(input_ids=input_ids)[0]
return torch.nn.functional.softmax(pred, dim=-1)
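
With the helper class gone, the same behavior can be recovered from the generic pieces used elsewhere in the codebase (as in cola.py above): load the checkpoint with AutoModelForSequenceClassification, pair it with TextAttack's AutoTokenizer, and wrap both in HuggingFaceModelWrapper. A hedged migration sketch; the checkpoint name is a placeholder:

from transformers import AutoModelForSequenceClassification

from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.shared import utils

# Placeholder checkpoint; any BERT sequence-classification checkpoint works.
model_path = "textattack/bert-base-uncased-CoLA"

model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2)
model.to(utils.device)
model.eval()
tokenizer = AutoTokenizer(model_path)

# Wrap model and tokenizer for use inside TextAttack attacks and constraints.
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

# Note: the removed class applied torch.nn.functional.softmax to the logits in
# __call__; apply it to the model output yourself if probabilities are needed.

This mirrors how cola.py builds its model after this commit, so no dedicated helper class is required.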