mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
relax transformers package requirement, remove BERTForClassification
@@ -7,12 +7,6 @@ textattack.models.helpers package
    :show-inheritance:
 
 
-.. automodule:: textattack.models.helpers.bert_for_classification
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-
 .. automodule:: textattack.models.helpers.glove_embedding_layer
    :members:
    :undoc-members:
@@ -11,9 +11,8 @@ numpy<1.19.0 #TF 2.0 requires this
 pandas>=1.0.1
 scipy==1.4.1
 torch
-transformers==3.3.0
+transformers>=3.3.0
 terminaltables
-tokenizers==0.8.1-rc2
 tqdm>=4.27,<4.50.0
 word2number
 num2words
@@ -5,9 +5,10 @@ CoLA for Grammaticality
 """
 import lru
 import nltk
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 from textattack.constraints import Constraint
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.models.wrappers import HuggingFaceModelWrapper
 
 
@@ -45,7 +46,7 @@ class COLA(Constraint):
         self.model_name = model_name
         self._reference_score_cache = lru.LRU(2 ** 10)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer(model_name)
         self.model = HuggingFaceModelWrapper(model, tokenizer)
 
     def clear_cache(self):
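
For context, a minimal usage sketch of the CoLA constraint after this change. Construction is unchanged for callers; only the tokenizer wiring inside __init__ changes. The import path and the max_diff argument are assumptions about the TextAttack API, not part of this diff:

# Sketch only: the import path and max_diff are assumed, not shown in this commit.
from textattack.constraints.grammaticality import COLA

# Reject perturbations whose CoLA acceptability differs too much from the original.
cola_constraint = COLA(max_diff=0.1)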
@@ -1,39 +0,0 @@
-"""
-BERT Classification
-^^^^^^^^^^^^^^^^^^^^^
-
-"""
-
-
-import torch
-from transformers.modeling_bert import BertForSequenceClassification
-
-from textattack.models.tokenizers import AutoTokenizer
-from textattack.shared import utils
-
-
-class BERTForClassification:
-    """BERT fine-tuned for textual classification.
-
-    Args:
-        model_path(:obj:`string`): Path to the pre-trained model.
-        num_labels(:obj:`int`, optional): Number of class labels for
-          prediction, if different than 2.
-    """
-
-    def __init__(self, model_path, num_labels=2):
-        model_file_path = utils.download_if_needed(model_path)
-        self.model = BertForSequenceClassification.from_pretrained(
-            model_file_path, num_labels=num_labels
-        )
-
-        self.model.to(utils.device)
-        self.model.eval()
-        self.tokenizer = AutoTokenizer(model_file_path)
-
-    def __call__(self, input_ids=None, **kwargs):
-        # The tokenizer will return ``input_ids`` along with ``token_type_ids``
-        # and an ``attention_mask``. Our pre-trained models only need the input
-        # IDs.
-        pred = self.model(input_ids=input_ids)[0]
-        return torch.nn.functional.softmax(pred, dim=-1)
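
For anyone who previously relied on the deleted helper, a minimal replacement sketch that mirrors what the updated cola.py now does. The checkpoint name "bert-base-uncased" is only a placeholder:

# Sketch only: load a fine-tuned sequence-classification model through the generic
# transformers auto classes instead of the removed BERTForClassification helper.
from transformers import AutoModelForSequenceClassification

from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # placeholder checkpoint
tokenizer = AutoTokenizer("bert-base-uncased")  # TextAttack's tokenizer wrapper, as used in cola.py above
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)  # callable wrapper TextAttack components expect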