From d6cdae40915819a948a3f992990a18e2207ff631 Mon Sep 17 00:00:00 2001
From: Jin Yong Yoo
Date: Fri, 25 Dec 2020 06:39:22 -0500
Subject: [PATCH] relax transformers package requirement, remove BERTForSequenceClassification

---
 docs/apidoc/textattack.models.helpers.rst     |  6 ---
 requirements.txt                              |  3 +-
 textattack/constraints/grammaticality/cola.py |  5 ++-
 .../models/helpers/bert_for_classification.py | 39 -------------------
 4 files changed, 4 insertions(+), 49 deletions(-)
 delete mode 100644 textattack/models/helpers/bert_for_classification.py

diff --git a/docs/apidoc/textattack.models.helpers.rst b/docs/apidoc/textattack.models.helpers.rst
index 7f003bfb..9255d774 100644
--- a/docs/apidoc/textattack.models.helpers.rst
+++ b/docs/apidoc/textattack.models.helpers.rst
@@ -7,12 +7,6 @@ textattack.models.helpers package
    :show-inheritance:
 
-.. automodule:: textattack.models.helpers.bert_for_classification
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-
 .. automodule:: textattack.models.helpers.glove_embedding_layer
    :members:
    :undoc-members:
    :show-inheritance:
diff --git a/requirements.txt b/requirements.txt
index dd05eb72..64996d07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,9 +11,8 @@ numpy<1.19.0 #TF 2.0 requires this
 pandas>=1.0.1
 scipy==1.4.1
 torch
-transformers==3.3.0
+transformers>=3.3.0
 terminaltables
-tokenizers==0.8.1-rc2
 tqdm>=4.27,<4.50.0
 word2number
 num2words
diff --git a/textattack/constraints/grammaticality/cola.py b/textattack/constraints/grammaticality/cola.py
index dd01f3b0..58179616 100644
--- a/textattack/constraints/grammaticality/cola.py
+++ b/textattack/constraints/grammaticality/cola.py
@@ -5,9 +5,10 @@ CoLA for Grammaticality
 """
 import lru
 import nltk
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 from textattack.constraints import Constraint
+from textattack.models.tokenizers import AutoTokenizer
 from textattack.models.wrappers import HuggingFaceModelWrapper
 
 
@@ -45,7 +46,7 @@ class COLA(Constraint):
         self.model_name = model_name
         self._reference_score_cache = lru.LRU(2 ** 10)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer(model_name)
         self.model = HuggingFaceModelWrapper(model, tokenizer)
 
     def clear_cache(self):
diff --git a/textattack/models/helpers/bert_for_classification.py b/textattack/models/helpers/bert_for_classification.py
deleted file mode 100644
index 98afd84e..00000000
--- a/textattack/models/helpers/bert_for_classification.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""
-BERT Classification
-^^^^^^^^^^^^^^^^^^^^^
-
-"""
-
-
-import torch
-from transformers.modeling_bert import BertForSequenceClassification
-
-from textattack.models.tokenizers import AutoTokenizer
-from textattack.shared import utils
-
-
-class BERTForClassification:
-    """BERT fine-tuned for textual classification.
-
-    Args:
-        model_path(:obj:`string`): Path to the pre-trained model.
-        num_labels(:obj:`int`, optional): Number of class labels for
-            prediction, if different than 2.
- """ - - def __init__(self, model_path, num_labels=2): - model_file_path = utils.download_if_needed(model_path) - self.model = BertForSequenceClassification.from_pretrained( - model_file_path, num_labels=num_labels - ) - - self.model.to(utils.device) - self.model.eval() - self.tokenizer = AutoTokenizer(model_file_path) - - def __call__(self, input_ids=None, **kwargs): - # The tokenizer will return ``input_ids`` along with ``token_type_ids`` - # and an ``attention_mask``. Our pre-trained models only need the input - # IDs. - pred = self.model(input_ids=input_ids)[0] - return torch.nn.functional.softmax(pred, dim=-1)