1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Add CoLA based constraint

Added a constraint that uses a model pre-trained on the CoLA dataset to check that the attacked text keeps a sufficient number of linguistically acceptable sentences compared to the reference text.
This commit is contained in:
k-ivey
2020-10-16 17:15:03 -04:00
parent 67e03ac3c0
commit c9b1bcabdd
8 changed files with 154 additions and 2 deletions

View File

@@ -111,6 +111,11 @@ Part of Speech
.. automodule:: textattack.constraints.grammaticality.part_of_speech
:members:
CoLA
######
.. automodule:: textattack.constraints.grammaticality.cola
:members:
.. _overlap:
Overlap

View File

@@ -0,0 +1,59 @@
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): WordSwapWordNet
(constraints):
(0): COLA(
(max_diff): 0.1
(model_name): textattack/bert-base-uncased-CoLA
(compare_against_original): True
)
(1): BERTScore(
(min_bert_score): 0.7
(model): bert-base-uncased
(score_type): f1
(compare_against_original): True
)
(is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
2 (72%) --> 1 (53%)
Contrary to other reviews, I have zero complaints about the service or the prices. I have been getting tire service here for the past 5 years now, and compared to my experience with places like Pep Boys, these guys are experienced and know what they're doing. \nAlso, this is one place that I do not feel like I am being taken advantage of, just because of my gender. Other auto mechanics have been notorious for capitalizing on my ignorance of cars, and have sucked my bank account dry. But here, my service and road coverage has all been well explained - and let up to me to decide. \nAnd they just renovated the waiting room. It looks a lot better than it did in previous years.
Contrary to other reviews, I have zero complaints about the service or the prices. I have been getting tire service here for the past 5 years now, and compared to my experience with places like Pep Boys, these blackguard are experienced and know what they're doing. \nAlso, this is one place that I do not feel like I am being taken advantage of, just because of my gender. Other auto mechanics have been notorious for capitalizing on my ignorance of cars, and have sucked my bank account dry. But here, my service and road coverage has all been well explained - and let up to me to decide. \nAnd they just renovated the waiting room. It looks a lot better than it did in previous years.
--------------------------------------------- Result 2 ---------------------------------------------
1 (61%) --> 2 (51%)
Last summer I had an appointment to get new tires and had to wait a super long time. I also went in this week for them to fix a minor problem with a tire they put on. They \""fixed\"" it for free, and the very next morning I had the same issue. I called to complain, and the \""manager\"" didn't even apologize!!! So frustrated. Never going back. They seem overpriced, too.
Last summer I had an appointment to get new tires and had to wait a super long time. I also went in this week for them to fix a minor problem with a tire they put on. They \""fixed\"" it for free, and the very next morning I had the same issue. I called to complain, and the \""manager\"" didn't even apologize!!! So frustrated. Never going back. They look overpriced, too.
--------------------------------------------- Result 3 ---------------------------------------------
2 (76%) --> 1 (63%)
Friendly staff, same starbucks fair you get anywhere else. Sometimes the lines can get long.
well-disposed staff, same starbucks fair you get anywhere else. Sometimes the lines behind get long.
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 3 |
| Number of failed attacks: | 0 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 5.18% |
| Average num. words per input: | 70.33 |
| Avg num queries: | 77.67 |
+-------------------------------+--------+

View File

@@ -1,4 +1,4 @@
Attack(
/.*/Attack(
(search_method): GreedySearch
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
@@ -16,7 +16,7 @@ Attack(
(2): StopwordModification
(is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (91%) --> Negative (62%)

View File

@@ -146,6 +146,18 @@ attack_test_params = [
),
"tests/sample_outputs/run_attack_stanza_pos_tagger.txt",
),
#
# test: run_attack on CNN Yelp using the WordNet transformation and greedy search WIR
# with a CoLA constraint and BERT score
#
(
"run_attack_cnn_cola",
(
"textattack attack --model cnn-yelp --num-examples 3 --search-method greedy-word-wir "
"--transformation word-swap-wordnet --constraints cola^max_diff=0.1 bert-score^min_bert_score=0.7 --shuffle=False"
),
"tests/sample_outputs/run_attack_cnn_cola.txt",
),
]

View File

@@ -366,6 +366,7 @@ CONSTRAINT_CLASS_NAMES = {
"goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
"gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
"learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
"cola": "textattack.constraints.grammaticality.COLA",
#
# Overlap constraints
#

View File

@@ -2,3 +2,4 @@ from . import language_models
from .language_tool import LanguageTool
from .part_of_speech import PartOfSpeech
from .cola import COLA

View File

@@ -0,0 +1,73 @@
import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from textattack.constraints import Constraint
from textattack.models.wrappers import HuggingFaceModelWrapper
class COLA(Constraint):
    """Constrains an attack to text that has a similar number of
    linguistically acceptable sentences as the original text.

    Linguistic acceptability is determined by a model pre-trained on the
    `CoLA dataset <https://nyu-mll.github.io/CoLA/>`_. By default a BERT
    model is used, see the `pre-trained models README
    <https://github.com/QData/TextAttack/tree/master/textattack/models>`_
    for a full list of available models or provide your own model from the
    huggingface model hub.

    Args:
        max_diff (float or int): The absolute (if greater than or equal to 1)
            or percent (if less than 1) maximum difference allowed between
            the number of valid sentences in the reference text and the
            number of valid sentences in the attacked text.
        model_name (str): The name of the pre-trained model to use for
            classification. The model must be in huggingface model hub.
        compare_against_original (bool): If `True`, compare against the
            original text. Otherwise, compare against the most recent text.

    Raises:
        TypeError: If ``max_diff`` is not a float or an int.
        ValueError: If ``max_diff`` is negative.
    """

    def __init__(
        self,
        max_diff,
        model_name="textattack/bert-base-uncased-CoLA",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        # Whole numbers are a natural way to express an absolute difference
        # (e.g. ``max_diff=2``), so ints are accepted alongside floats.
        # ``bool`` is a subclass of ``int`` and is rejected explicitly.
        if isinstance(max_diff, bool) or not isinstance(max_diff, (int, float)):
            raise TypeError("max_diff must be a float or an int")
        if max_diff < 0.0:
            raise ValueError("max_diff must be a value greater than or equal to 0.0")

        # Stored as a float so the value is represented uniformly regardless
        # of whether an int or a float was supplied.
        self.max_diff = float(max_diff)
        self.model_name = model_name
        # Remembers the number of valid sentences per reference text so the
        # model is not re-run on the same reference for every candidate.
        self._reference_score_cache = lru.LRU(2 ** 10)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = HuggingFaceModelWrapper(model, tokenizer)

    def clear_cache(self):
        """Drop every cached reference score."""
        self._reference_score_cache.clear()

    def _count_valid_sentences(self, text):
        """Return how many sentences of ``text`` the model labels as valid.

        The text is split into sentences before predicting validity; a
        predicted label of 1 indicates the sentence is valid.
        """
        sentences = nltk.sent_tokenize(text)
        if not sentences:
            # Nothing to classify; avoid handing the model an empty batch.
            return 0
        return self.model(sentences).argmax(axis=1).sum()

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if ``transformed_text`` keeps enough valid
        sentences relative to ``reference_text``.

        The transformed text fails only when its number of valid sentences
        drops below the reference count by more than ``max_diff`` (taken as
        a fraction of the reference count when below 1, otherwise as an
        absolute number of sentences).
        """
        if reference_text not in self._reference_score_cache:
            self._reference_score_cache[reference_text] = self._count_valid_sentences(
                reference_text.text
            )
        reference_score = self._reference_score_cache[reference_text]
        num_valid = self._count_valid_sentences(transformed_text.text)

        if self.max_diff < 1.0:
            # Fractional tolerance, relative to the reference count.
            allowed_drop = reference_score * self.max_diff
        else:
            # Absolute tolerance, in number of sentences.
            allowed_drop = self.max_diff

        return not (num_valid < reference_score - allowed_drop)

    def extra_repr_keys(self):
        """Return the attribute names shown in this constraint's repr."""
        return [
            "max_diff",
            "model_name",
        ] + super().extra_repr_keys()

View File

@@ -114,6 +114,7 @@ def _post_install():
nltk.download("omw")
nltk.download("universal_tagset")
nltk.download("wordnet")
nltk.download("punkt")
import stanza