mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Add CoLA based constraint
Adds a constraint that uses a model pre-trained on the CoLA dataset to check that the number of linguistically acceptable sentences in the attacked text does not drop too far below the count in the reference text.
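Below is a minimal usage sketch (not part of the commit) showing how the new constraint could be constructed; the import path and constructor arguments are taken from the cola.py and __init__.py changes in this diff.

    from textattack.constraints.grammaticality import COLA

    # Reject perturbations that lose more than 10% of the linguistically
    # acceptable sentences found in the original text.
    cola_constraint = COLA(
        max_diff=0.1,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    )

    # The constraint is then passed to an attack alongside any other
    # constraints, e.g. constraints = [cola_constraint, ...]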
@@ -111,6 +111,11 @@ Part of Speech
.. automodule:: textattack.constraints.grammaticality.part_of_speech
   :members:


CoLA
######

.. automodule:: textattack.constraints.grammaticality.cola
   :members:


.. _overlap:

Overlap
tests/sample_outputs/run_attack_cnn_cola.txt (new file, 59 lines)
@@ -0,0 +1,59 @@
/.*/Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method): unk
  )
  (goal_function): UntargetedClassification
  (transformation): WordSwapWordNet
  (constraints):
    (0): COLA(
        (max_diff): 0.1
        (model_name): textattack/bert-base-uncased-CoLA
        (compare_against_original): True
      )
    (1): BERTScore(
        (min_bert_score): 0.7
        (model): bert-base-uncased
        (score_type): f1
        (compare_against_original): True
      )
  (is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
[92m2 (72%)[0m --> [91m1 (53%)[0m

Contrary to other reviews, I have zero complaints about the service or the prices. I have been getting tire service here for the past 5 years now, and compared to my experience with places like Pep Boys, these [92mguys[0m are experienced and know what they're doing. \nAlso, this is one place that I do not feel like I am being taken advantage of, just because of my gender. Other auto mechanics have been notorious for capitalizing on my ignorance of cars, and have sucked my bank account dry. But here, my service and road coverage has all been well explained - and let up to me to decide. \nAnd they just renovated the waiting room. It looks a lot better than it did in previous years.

Contrary to other reviews, I have zero complaints about the service or the prices. I have been getting tire service here for the past 5 years now, and compared to my experience with places like Pep Boys, these [91mblackguard[0m are experienced and know what they're doing. \nAlso, this is one place that I do not feel like I am being taken advantage of, just because of my gender. Other auto mechanics have been notorious for capitalizing on my ignorance of cars, and have sucked my bank account dry. But here, my service and road coverage has all been well explained - and let up to me to decide. \nAnd they just renovated the waiting room. It looks a lot better than it did in previous years.


--------------------------------------------- Result 2 ---------------------------------------------
[91m1 (61%)[0m --> [92m2 (51%)[0m

Last summer I had an appointment to get new tires and had to wait a super long time. I also went in this week for them to fix a minor problem with a tire they put on. They \""fixed\"" it for free, and the very next morning I had the same issue. I called to complain, and the \""manager\"" didn't even apologize!!! So frustrated. Never going back. They [91mseem[0m overpriced, too.

Last summer I had an appointment to get new tires and had to wait a super long time. I also went in this week for them to fix a minor problem with a tire they put on. They \""fixed\"" it for free, and the very next morning I had the same issue. I called to complain, and the \""manager\"" didn't even apologize!!! So frustrated. Never going back. They [92mlook[0m overpriced, too.


--------------------------------------------- Result 3 ---------------------------------------------
[92m2 (76%)[0m --> [91m1 (63%)[0m

[92mFriendly[0m staff, same starbucks fair you get anywhere else. Sometimes the lines [92mcan[0m get long.

[91mwell-disposed[0m staff, same starbucks fair you get anywhere else. Sometimes the lines [91mbehind[0m get long.


+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 3      |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 5.18%  |
| Average num. words per input: | 70.33  |
| Avg num queries:              | 77.67  |
+-------------------------------+--------+
@@ -1,4 +1,4 @@
Attack(
/.*/Attack(
  (search_method): GreedySearch
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
@@ -16,7 +16,7 @@ Attack(
    (2): StopwordModification
  (is_black_box): True
)

/.*/
--------------------------------------------- Result 1 ---------------------------------------------
[92mPositive (91%)[0m --> [91mNegative (62%)[0m
@@ -146,6 +146,18 @@ attack_test_params = [
        ),
        "tests/sample_outputs/run_attack_stanza_pos_tagger.txt",
    ),
    #
    # test: run_attack on CNN Yelp using the WordNet transformation and greedy search WIR
    # with a CoLA constraint and BERT score
    #
    (
        "run_attack_cnn_cola",
        (
            "textattack attack --model cnn-yelp --num-examples 3 --search-method greedy-word-wir "
            "--transformation word-swap-wordnet --constraints cola^max_diff=0.1 bert-score^min_bert_score=0.7 --shuffle=False"
        ),
        "tests/sample_outputs/run_attack_cnn_cola.txt",
    ),
]
@@ -366,6 +366,7 @@ CONSTRAINT_CLASS_NAMES = {
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    "cola": "textattack.constraints.grammaticality.COLA",
    #
    # Overlap constraints
    #
@@ -2,3 +2,4 @@ from . import language_models
from .language_tool import LanguageTool
from .part_of_speech import PartOfSpeech
from .cola import COLA
textattack/constraints/grammaticality/cola.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from textattack.constraints import Constraint
from textattack.models.wrappers import HuggingFaceModelWrapper


class COLA(Constraint):
    """Constrains an attack to text that has a similar number of linguistically
    acceptable sentences as the original text. Linguistic acceptability is
    determined by a model pre-trained on the `CoLA dataset <https://nyu-mll.github.io/CoLA/>`_.
    By default a BERT model is used; see the `pre-trained models README
    <https://github.com/QData/TextAttack/tree/master/textattack/models>`_ for a full
    list of available models, or provide your own model from the Hugging Face model hub.

    Args:
        max_diff (float): The absolute (if greater than or equal to 1) or percent (if less than 1)
            maximum difference allowed between the number of valid sentences in the reference
            text and the number of valid sentences in the attacked text.
        model_name (str): The name of the pre-trained model to use for classification. The model must be available on the Hugging Face model hub.
        compare_against_original (bool): If `True`, compare against the original text.
            Otherwise, compare against the most recent text.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(max_diff, float):
            raise TypeError("max_diff must be a float")
        if max_diff < 0.0:
            raise ValueError("max_diff must be a value greater than or equal to 0.0")

        self.max_diff = max_diff
        self.model_name = model_name
        self._reference_score_cache = lru.LRU(2 ** 10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        self._reference_score_cache.clear()

    def _check_constraint(self, transformed_text, reference_text):
        if reference_text not in self._reference_score_cache:
            # Split the text into sentences before predicting validity
            reference_sentences = nltk.sent_tokenize(reference_text.text)
            # A label of 1 indicates the sentence is valid
            num_valid = self.model(reference_sentences).argmax(axis=1).sum()
            self._reference_score_cache[reference_text] = num_valid

        sentences = nltk.sent_tokenize(transformed_text.text)
        predictions = self.model(sentences)
        num_valid = predictions.argmax(axis=1).sum()
        reference_score = self._reference_score_cache[reference_text]
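        # max_diff below 1.0 is interpreted as a fractional drop relative to the
        # reference count (e.g. max_diff=0.1 with 10 acceptable reference sentences
        # requires at least 9 to remain); max_diff of 1.0 or more is an absolute
        # drop in the number of acceptable sentences.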
        if (
            self.max_diff < 1.0
            and num_valid < reference_score - (reference_score * self.max_diff)
        ) or (self.max_diff >= 1.0 and num_valid < reference_score - self.max_diff):
            return False
        return True

    def extra_repr_keys(self):
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()
@@ -114,6 +114,7 @@ def _post_install():
    nltk.download("omw")
    nltk.download("universal_tagset")
    nltk.download("wordnet")
    nltk.download("punkt")

    import stanza
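The punkt download is what backs nltk.sent_tokenize, which the new cola.py constraint calls to split texts into sentences. A quick illustrative check (not part of the commit):

    import nltk

    nltk.download("punkt")
    # sent_tokenize needs the punkt tokenizer models downloaded above
    print(nltk.sent_tokenize("This is fine. So is this."))
    # ['This is fine.', 'So is this.']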