1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/models/wrappers/sklearn_model_wrapper.py
2020-07-31 15:35:03 -04:00

28 lines
768 B
Python

import pickle
import pandas as pd
import textattack
from .model_wrapper import ModelWrapper
class SklearnModelWrapper(ModelWrapper):
    """Loads a scikit-learn model and tokenizer (tokenizer implements
    `transform` and model implements `predict_proba`).

    The tokenizer must additionally expose its vocabulary through
    `get_feature_names_out` (scikit-learn >= 1.0) or the legacy
    `get_feature_names`; those names become the column labels of the
    DataFrame handed to the model.

    May need to be extended and modified for different types of
    tokenizers.
    """

    def __init__(self, model, tokenizer):
        """Store the fitted estimator and the fitted tokenizer.

        Args:
            model: object implementing `predict_proba`.
            tokenizer: object implementing `transform` and a
                feature-name accessor (see class docstring).
        """
        self.model = model
        self.tokenizer = tokenizer

    def _feature_names(self):
        """Return the tokenizer's feature names across sklearn versions."""
        # `get_feature_names` was deprecated in scikit-learn 1.0 and removed
        # in 1.2; prefer the replacement and only fall back for tokenizers
        # that do not provide it.
        get_names = getattr(self.tokenizer, "get_feature_names_out", None)
        if get_names is None:
            get_names = self.tokenizer.get_feature_names
        return get_names()

    def __call__(self, text_input_list):
        """Encode the given texts and return the model's `predict_proba` output.

        Args:
            text_input_list: the texts to score, passed unchanged to
                `tokenizer.transform`.

        Returns:
            Whatever `self.model.predict_proba` returns for the encoded
            inputs.
        """
        # The result of `transform` must expose `.toarray()` (e.g. a SciPy
        # sparse matrix) so it can be densified for the DataFrame.
        encoded_text_matrix = self.tokenizer.transform(text_input_list).toarray()
        tokenized_text_df = pd.DataFrame(
            encoded_text_matrix, columns=self._feature_names()
        )
        return self.model.predict_proba(tokenized_text_df)