mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00
relax transformers package requirement, remove BERTForClassification
@@ -7,12 +7,6 @@ textattack.models.helpers package
    :show-inheritance:


-.. automodule:: textattack.models.helpers.bert_for_classification
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-
 .. automodule:: textattack.models.helpers.glove_embedding_layer
    :members:
    :undoc-members:
@@ -11,9 +11,8 @@ numpy<1.19.0 #TF 2.0 requires this
 pandas>=1.0.1
 scipy==1.4.1
 torch
-transformers==3.3.0
+transformers>=3.3.0
 terminaltables
-tokenizers==0.8.1-rc2
 tqdm>=4.27,<4.50.0
 word2number
 num2words
@@ -5,9 +5,10 @@ CoLA for Grammaticality
 """
 import lru
 import nltk
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification

 from textattack.constraints import Constraint
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.models.wrappers import HuggingFaceModelWrapper


@@ -45,7 +46,7 @@ class COLA(Constraint):
         self.model_name = model_name
         self._reference_score_cache = lru.LRU(2 ** 10)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer(model_name)
         self.model = HuggingFaceModelWrapper(model, tokenizer)

     def clear_cache(self):
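For reference, the hunk above swaps the tokenizer construction from the transformers factory call to TextAttack's own AutoTokenizer wrapper, which is built directly from the model name. A minimal before/after sketch; the checkpoint name below is only a placeholder, not taken from this diff:

# Before: transformers' AutoTokenizer factory
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# After: TextAttack's wrapper, constructed directly from the model name
from textattack.models.tokenizers import AutoTokenizer

model_name = "textattack/bert-base-uncased-CoLA"  # placeholder checkpoint name
tokenizer = AutoTokenizer(model_name)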
@@ -1,39 +0,0 @@
-"""
-BERT Classification
-^^^^^^^^^^^^^^^^^^^^^
-
-"""
-
-
-import torch
-from transformers.modeling_bert import BertForSequenceClassification
-
-from textattack.models.tokenizers import AutoTokenizer
-from textattack.shared import utils
-
-
-class BERTForClassification:
-    """BERT fine-tuned for textual classification.
-
-    Args:
-        model_path(:obj:`string`): Path to the pre-trained model.
-        num_labels(:obj:`int`, optional): Number of class labels for
-            prediction, if different than 2.
-    """
-
-    def __init__(self, model_path, num_labels=2):
-        model_file_path = utils.download_if_needed(model_path)
-        self.model = BertForSequenceClassification.from_pretrained(
-            model_file_path, num_labels=num_labels
-        )
-
-        self.model.to(utils.device)
-        self.model.eval()
-        self.tokenizer = AutoTokenizer(model_file_path)
-
-    def __call__(self, input_ids=None, **kwargs):
-        # The tokenizer will return ``input_ids`` along with ``token_type_ids``
-        # and an ``attention_mask``. Our pre-trained models only need the input
-        # IDs.
-        pred = self.model(input_ids=input_ids)[0]
-        return torch.nn.functional.softmax(pred, dim=-1)
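With BERTForClassification removed, a caller that previously loaded a fine-tuned BERT checkpoint through this helper would instead load it with the transformers auto classes and wrap it, mirroring the cola.py change above. A rough sketch of that replacement pattern; the checkpoint path is a placeholder and is not part of this commit:

# Approximate equivalent of the removed helper, assuming a sequence-classification
# checkpoint on disk or on the Hugging Face hub.
import torch
from transformers import AutoModelForSequenceClassification

from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper

model_path = "path/to/fine-tuned-bert"  # placeholder checkpoint path
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2)
model.eval()
tokenizer = AutoTokenizer(model_path)
wrapped = HuggingFaceModelWrapper(model, tokenizer)

# Probabilities from raw input IDs, as the deleted __call__ returned:
# probs = torch.nn.functional.softmax(model(input_ids=input_ids)[0], dim=-1)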