1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/models/helpers/bert_for_classification.py

35 lines
1.1 KiB
Python

import torch
from transformers.modeling_bert import BertForSequenceClassification
from textattack.models.tokenizers import BERTTokenizer
from textattack.shared import utils
class BERTForClassification:
"""
BERT fine-tuned for textual classification.
Args:
model_path(:obj:`string`): Path to the pre-trained model.
num_labels(:obj:`int`, optional): Number of class labels for
prediction, if different than 2.
"""
def __init__(self, model_path, num_labels=2):
model_file_path = utils.download_if_needed(model_path)
self.model = BertForSequenceClassification.from_pretrained(
model_file_path, num_labels=num_labels
)
self.model.to(utils.device)
self.model.eval()
self.tokenizer = BERTTokenizer(model_file_path)
def __call__(self, input_ids=None, **kwargs):
# The tokenizer will return ``input_ids`` along with ``token_type_ids``
# and an ``attention_mask``. Our pre-trained models only need the input
# IDs.
pred = self.model(input_ids=input_ids)[0]
return torch.nn.functional.softmax(pred, dim=-1)