mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from textattack.shared import utils
|
|
import torch
|
|
|
|
from textattack.tokenizers import BERTTokenizer, BERTEntailmentTokenizer
|
|
from transformers.modeling_bert import BertForSequenceClassification
|
|
|
|
|
|
class BERTForClassification:
    """
    BERT fine-tuned for textual classification.

    Wraps a ``BertForSequenceClassification`` model together with the
    tokenizer that goes with it, and exposes the model as a callable that
    returns softmax-normalized scores.

    Args:
        model_path(:obj:`string`): Path to the pre-trained model. It is
            resolved through ``utils.download_if_needed`` before loading.
        num_labels(:obj:`int`, optional): Number of class labels for
            prediction, if different than 2.
        entailment(:obj:`bool`, optional): When true, a
            ``BERTEntailmentTokenizer`` is used instead of the regular
            ``BERTTokenizer``. Defaults to ``False``.
    """

    def __init__(self, model_path, num_labels=2, entailment=False):
        # Resolve the given path to the location the weights are loaded from.
        checkpoint = utils.download_if_needed(model_path)
        print('TextAttack BERTForClassification loading from path', checkpoint)

        # Load the classifier, move it to the device chosen by the shared
        # utilities, and switch it to evaluation mode.
        classifier = BertForSequenceClassification.from_pretrained(
            checkpoint,
            num_labels=num_labels,
        )
        classifier.to(utils.get_device())
        classifier.eval()
        self.model = classifier

        # The entailment tokenizer takes no path; the plain one is built from
        # the same resolved path as the model.
        self.tokenizer = (
            BERTEntailmentTokenizer()
            if entailment
            else BERTTokenizer(checkpoint)
        )

    def __call__(self, *params):
        """
        Run the wrapped model and return softmax scores.

        All positional arguments are forwarded unchanged to the underlying
        model. The first element of the model's output (presumably the
        classification logits) is normalized with a softmax over its last
        dimension.
        """
        outputs = self.model(*params)
        scores = outputs[0]
        return torch.nn.functional.softmax(scores, dim=-1)
|