"""
|
|
CoLA for Grammaticality
|
|
--------------------------
|
|
|
|
"""

import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from textattack.constraints import Constraint
from textattack.models.wrappers import HuggingFaceModelWrapper
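
# Note: ``lru`` above is the third-party lru-dict package, which provides the
# fixed-size ``lru.LRU`` mapping used for the score cache below; it is not
# ``functools.lru_cache``.

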
class COLA(Constraint):
    """Constrains an attack to text that has a similar number of linguistically
    acceptable sentences as the original text. Linguistic acceptability is
    determined by a model pre-trained on the
    `CoLA dataset <https://nyu-mll.github.io/CoLA/>`_. By default a BERT model
    is used; see the
    `pre-trained models README <https://github.com/QData/TextAttack/tree/master/textattack/models>`_
    for a full list of available models, or provide your own model from the
    HuggingFace model hub.

    Args:
        max_diff (float or int): The maximum difference allowed between the
            number of valid sentences in the reference text and the number of
            valid sentences in the attacked text. Interpreted as an absolute
            count if an int or a float greater than or equal to 1, and as a
            percentage of the reference count if a float less than 1.
        model_name (str): The name of the pre-trained model to use for
            classification. The model must be available on the HuggingFace
            model hub.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(max_diff, float) and not isinstance(max_diff, int):
            raise TypeError("max_diff must be a float or int")
        if max_diff < 0.0:
            raise ValueError("max_diff must be greater than or equal to 0.0")

        self.max_diff = max_diff
        self.model_name = model_name
        # Cache the reference text's acceptable-sentence count so repeated
        # checks against the same reference do not rerun the model.
        self._reference_score_cache = lru.LRU(2 ** 10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        self._reference_score_cache.clear()
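
    # Illustrative sketch of the decision rule (assumed model outputs, not
    # real scores): self.model(["Good sentence.", "Bad grammar is."]) might
    # return logits [[-1.2, 2.3], [0.8, -0.5]]; argmax(axis=1) gives [1, 0],
    # and .sum() counts one acceptable sentence.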
    def _check_constraint(self, transformed_text, reference_text):
        if reference_text not in self._reference_score_cache:
            # Split the text into sentences before predicting validity.
            reference_sentences = nltk.sent_tokenize(reference_text.text)
            # A label of 1 indicates the sentence is linguistically acceptable.
            num_valid = self.model(reference_sentences).argmax(axis=1).sum()
            self._reference_score_cache[reference_text] = num_valid

        sentences = nltk.sent_tokenize(transformed_text.text)
        predictions = self.model(sentences)
        num_valid = predictions.argmax(axis=1).sum()
        reference_score = self._reference_score_cache[reference_text]

        if isinstance(self.max_diff, int) or self.max_diff >= 1:
            # Absolute difference in acceptable-sentence counts.
            threshold = reference_score - self.max_diff
        else:
            # Percent difference relative to the reference count.
            threshold = reference_score - (reference_score * self.max_diff)

        if num_valid < threshold:
            return False
        return True

    def extra_repr_keys(self):
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()
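

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumes
    # textattack and nltk are installed and the "punkt" tokenizer data has
    # been downloaded via nltk.download("punkt"); the sample sentences are
    # made up for illustration.
    from textattack.shared import AttackedText

    constraint = COLA(max_diff=1)
    reference = AttackedText("The cat sat on the mat. She reads every evening.")
    transformed = AttackedText("Cat the sat mat on. She reads every evening.")
    # True if the transformed text loses at most one acceptable sentence.
    print(constraint._check_constraint(transformed, reference))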