mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
bunch of bug fixes, update tutorials and tests and README
This commit is contained in:
@@ -120,7 +120,7 @@ A `Transformation` takes as input a `TokenizedText` and returns a list of possib
|
||||
|
||||
### Search Methods
|
||||
|
||||
A search method is currently implemented in an extension of the `Attack` class, through implementing the `attack_one` method. The `get_transformations` function takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds or is exhausted.
|
||||
A search method takes as input an initial `goal_function` result and returns a final `goal_function` result. The search is given access to the `get_transformations` function, which takes as input a `TokenizedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined through `get_goal_results`) or is exhausted.
|
||||
|
||||
## Contributing to TextAttack
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
"The **goal function** determines if the attack is successful or not. One common goal function is **untargeted classification**, where the attack tries to perturb an input to change its classification. \n",
|
||||
"\n",
|
||||
"### Search method\n",
|
||||
"The **search method** explores the space of potential transformations and tries to locate a successful perturbation. Greedy search, beam search, and brute-force search are all examples of search methods. Since the search method is the backbone of the attack, the term \"search\" is often substituted with \"attack method\" or just \"attack\". In TextAttack, all three of those terms (search, attack, attack method) mean the same thing.\n",
|
||||
"The **search method** explores the space of potential transformations and tries to locate a successful perturbation. Greedy search, beam search, and brute-force search are all examples of search methods.\n",
|
||||
"\n",
|
||||
"### Transformation\n",
|
||||
"A **transformation** takes a text input and transforms it, replacing words or phrases with similar ones, while trying not to change the meaning. Paraphrase and synonym substitution are two broad classes of transformations.\n",
|
||||
@@ -28,7 +28,7 @@
|
||||
"source": [
|
||||
"### A custom transformation\n",
|
||||
"\n",
|
||||
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\". (These don't intend to preserve the grammaticality of inputs.) Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
|
||||
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Banana word swap 🍌\n",
|
||||
@@ -38,7 +38,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -75,7 +75,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -111,14 +111,19 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from textattack.search_methods import GreedyWordSwap\n",
|
||||
"from textattack.search_methods import GreedySearch\n",
|
||||
"from textattack.constraints.semantics import RepeatModification, StopwordModification\n",
|
||||
"from textattack.shared import Attack\n",
|
||||
"\n",
|
||||
"# We're going to use our Banana word swap class as the attack transformation.\n",
|
||||
"transformation = BananaWordSwap() \n",
|
||||
"# And, we don't want to use any constraints.\n",
|
||||
"constraints = []\n",
|
||||
"# Now, let's make the attack using these parameters:\n",
|
||||
"attack = GreedyWordSwap(goal_function, transformation, constraints)"
|
||||
"# We'll constrain modification of already modified indices and stopwords\n",
|
||||
"constraints = [RepeatModification(),\n",
|
||||
" StopwordModification()]\n",
|
||||
"# We'll use the Greedy search method\n",
|
||||
"search_method = GreedySearch()\n",
|
||||
"# Now, let's make the attack from the 4 components:\n",
|
||||
"attack = Attack(goal_function, constraints, transformation, search_method)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -137,12 +142,13 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GreedyWordSwap(\n",
|
||||
"Attack(\n",
|
||||
" (search_method): GreedySearch\n",
|
||||
" (goal_function): UntargetedClassification\n",
|
||||
" (transformation): BananaWordSwap(\n",
|
||||
" (replace_stopwords): False\n",
|
||||
" )\n",
|
||||
" (constraints): None\n",
|
||||
" (transformation): BananaWordSwap\n",
|
||||
" (constraints): \n",
|
||||
" (0): RepeatModification\n",
|
||||
" (1): StopwordModification\n",
|
||||
" (is_black_box): True\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -170,7 +176,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"12it [00:01, 10.31it/s] \n"
|
||||
"12it [00:00, 13.71it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -198,7 +204,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -306,9 +312,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "build_central"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@@ -52,19 +52,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[nltk_data] Downloading package punkt to /u/jm8wx/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n",
|
||||
"[nltk_data] Downloading package punkt to /u/edl9cy/nltk_data...\n",
|
||||
"[nltk_data] Unzipping tokenizers/punkt.zip.\n",
|
||||
"[nltk_data] Downloading package maxent_ne_chunker to\n",
|
||||
"[nltk_data] /u/jm8wx/nltk_data...\n",
|
||||
"[nltk_data] Package maxent_ne_chunker is already up-to-date!\n",
|
||||
"[nltk_data] Downloading package words to /u/jm8wx/nltk_data...\n",
|
||||
"[nltk_data] /u/edl9cy/nltk_data...\n",
|
||||
"[nltk_data] Unzipping chunkers/maxent_ne_chunker.zip.\n",
|
||||
"[nltk_data] Downloading package words to /u/edl9cy/nltk_data...\n",
|
||||
"[nltk_data] Package words is already up-to-date!\n"
|
||||
]
|
||||
},
|
||||
@@ -74,7 +74,7 @@
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -97,7 +97,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -151,7 +151,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -188,7 +188,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -214,7 +214,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -236,7 +236,7 @@
|
||||
" ('.', '.')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -264,7 +264,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -273,7 +273,7 @@
|
||||
"class NamedEntityConstraint(Constraint):\n",
|
||||
" \"\"\" A constraint that ensures `x_adv` only substitutes named entities from `x` with other named entities.\n",
|
||||
" \"\"\"\n",
|
||||
" def __call__(self, x, x_adv, original_text=None):\n",
|
||||
" def _check_constraint(self, x, x_adv, original_text=None):\n",
|
||||
" x_entities = get_entities(x.text)\n",
|
||||
" x_adv_entities = get_entities(x_adv.text)\n",
|
||||
" # If there aren't named entities, let's return False (the attack\n",
|
||||
@@ -316,7 +316,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -324,7 +324,7 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[34;1mtextattack\u001b[0m: Downloading https://textattack.s3.amazonaws.com/models/classification/lstm/yelp_polarity.\n",
|
||||
"100%|██████████| 297M/297M [00:06<00:00, 48.3MB/s] \n",
|
||||
"100%|██████████| 297M/297M [00:10<00:00, 28.5MB/s] \n",
|
||||
"\u001b[34;1mtextattack\u001b[0m: Unzipping file path_to_zip_file to unzipped_folder_path.\n",
|
||||
"\u001b[34;1mtextattack\u001b[0m: Successfully saved models/classification/lstm/yelp_polarity to cache.\n",
|
||||
"\u001b[34;1mtextattack\u001b[0m: Goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'> matches model LSTMForYelpSentimentClassification.\n"
|
||||
@@ -344,22 +344,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GreedyWordSwap(\n",
|
||||
"Attack(\n",
|
||||
" (search_method): GreedySearch\n",
|
||||
" (goal_function): UntargetedClassification\n",
|
||||
" (transformation): WordSwapEmbedding(\n",
|
||||
" (max_candidates): 15\n",
|
||||
" (embedding_type): paragramcf\n",
|
||||
" (replace_stopwords): False\n",
|
||||
" )\n",
|
||||
" (constraints): \n",
|
||||
" (0): NamedEntityConstraint\n",
|
||||
" (1): RepeatModification\n",
|
||||
" (2): StopwordModification\n",
|
||||
" (is_black_box): True\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -367,20 +369,31 @@
|
||||
],
|
||||
"source": [
|
||||
"from textattack.transformations import WordSwapEmbedding\n",
|
||||
"from textattack.search_methods import GreedyWordSwap\n",
|
||||
"from textattack.search_methods import GreedySearch\n",
|
||||
"from textattack.constraints.semantics import RepeatModification, StopwordModification\n",
|
||||
"from textattack.shared import Attack\n",
|
||||
"\n",
|
||||
"# We're going to use the `WordSwapEmbedding` transformation. Using the default settings, this\n",
|
||||
"# will try substituting words with their neighbors in the counter-fitted embedding space. \n",
|
||||
"transformation = WordSwapEmbedding(max_candidates=15) \n",
|
||||
"# Now, let's make the attack using these parameters. And add one constraint: our \n",
|
||||
"# custom NamedEntityConstraint.\n",
|
||||
"attack = GreedyWordSwap(goal_function, transformation, constraints=[NamedEntityConstraint()])\n",
|
||||
"\n",
|
||||
"# We'll use the greedy search method again\n",
|
||||
"search_method = GreedySearch()\n",
|
||||
"\n",
|
||||
"# Our constraints will be the same as Tutorial 1, plus the named entity constraint\n",
|
||||
"constraints = [RepeatModification(),\n",
|
||||
" StopwordModification(),\n",
|
||||
" NamedEntityConstraint()]\n",
|
||||
"\n",
|
||||
"# Now, let's make the attack using these parameters. \n",
|
||||
"attack = Attack(goal_function, constraints, transformation, search_method)\n",
|
||||
"\n",
|
||||
"print(attack)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -389,7 +402,7 @@
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 41,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -534,9 +547,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "build_central"
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
||||
@@ -51,7 +51,7 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-e
|
||||
#
|
||||
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
|
||||
'--enable-csv --model bert-mnli --num-examples 4 --transformation word-swap-wordnet '
|
||||
'--constraints lang-tool --attack beam-search:beam_width=2'),
|
||||
'--constraints lang-tool repeat stopword --search beam-search:beam_width=2'),
|
||||
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
|
||||
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_4.txt',
|
||||
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
|
||||
@@ -67,10 +67,11 @@ register_test(('python -m textattack --attack-n --goal-function targeted-classif
|
||||
#
|
||||
register_test(('python -m textattack --attack-n --goal-function non-overlapping-output '
|
||||
'--model t5-en2de --num-examples 6 --transformation word-swap-random-char-substitution '
|
||||
'--constraints edit-distance:12 words-perturbed:max_percent=0.75 --attack greedy-word'),
|
||||
'--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword '
|
||||
'--search greedy'),
|
||||
name='run_attack_nonoverlapping_t5en2de_randomcharsub_editdistance_wordsperturbed_greedyword',
|
||||
output_file='local_tests/sample_outputs/run_attack_nonoverlapping_t5ende_editdistance_bleu.txt',
|
||||
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
|
||||
'enable_csv and attack_n set, using the WordNet transformation and beam '
|
||||
'search with beam width 2, using language tool constraint, on 10 samples')
|
||||
)
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import nltk
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import TokenizedText
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.shared.validators import is_word_swap
|
||||
|
||||
class PartOfSpeech(Constraint):
|
||||
""" Constrains word swaps to only swap words with the same part of speech.
|
||||
@@ -17,7 +17,7 @@ class PartOfSpeech(Constraint):
|
||||
self._pos_tag_cache = lru.LRU(2**14)
|
||||
|
||||
def check_compatibility(self, transformation):
|
||||
return isinstance(transformation, WordSwap)
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
def _can_replace_pos(self, pos_a, pos_b):
|
||||
return (pos_a == pos_b) or (self.allow_verb_noun_swap and set([pos_a,pos_b]) <= set(['NOUN','VERB']))
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
"""
|
||||
|
||||
from textattack.shared.utils import default_class_repr
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.constraints import ModificationConstraint
|
||||
from textattack.shared.validators import is_word_swap
|
||||
|
||||
class StopwordModification(ModificationConstraint):
|
||||
"""
|
||||
@@ -28,7 +28,4 @@ class StopwordModification(ModificationConstraint):
|
||||
Args:
|
||||
transformation: The transformation to check compatibility with.
|
||||
"""
|
||||
return isinstance(transformation, WordSwap)
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return ['textfooler_stopwords']
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
from textattack.shared import utils
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared import TokenizedText
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.shared.validators import is_word_swap
|
||||
|
||||
class WordEmbeddingDistance(Constraint):
|
||||
"""
|
||||
@@ -99,7 +99,7 @@ class WordEmbeddingDistance(Constraint):
|
||||
return mse_dist
|
||||
|
||||
def check_compatibility(self, transformation):
|
||||
return isinstance(transformation, WordSwap)
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
def _check_constraint(self, x, x_adv, original_text=None):
|
||||
""" Returns true if (x, x_adv) are closer than `self.min_cos_sim`
|
||||
|
||||
@@ -27,6 +27,8 @@ def word_difference_score(s1, s2):
|
||||
s1_words = get_words_cached(s1)
|
||||
s2_words = get_words_cached(s2)
|
||||
min_length = min(len(s1_words), len(s2_words))
|
||||
if min_length == 0:
|
||||
return 0
|
||||
s1_words = s1_words[:min_length]
|
||||
s2_words = s2_words[:min_length]
|
||||
return (s1_words != s2_words).sum()
|
||||
return (s1_words != s2_words).sum()
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.shared.validators import is_word_swap
|
||||
|
||||
class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
@@ -27,7 +27,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
return transformation.consists_of(WordSwap)
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
def _replace_at_index(self, pop_member, idx):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.shared.validators import is_word_swap
|
||||
|
||||
class GreedyWordSwapWIR(SearchMethod):
|
||||
"""
|
||||
@@ -31,7 +31,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
raise KeyError(f'Word Importance Ranking method {wir_method} not recognized.')
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
return transformation.consists_of(WordSwap)
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
def __call__(self, initial_result):
|
||||
tokenized_text = initial_result.tokenized_text
|
||||
|
||||
@@ -226,8 +226,9 @@ class Attack:
|
||||
)
|
||||
# self.constraints
|
||||
constraints_lines = []
|
||||
if len(self.constraints):
|
||||
for i, constraint in enumerate(self.constraints):
|
||||
constraints = self.constraints + self.modification_constraints
|
||||
if len(constraints):
|
||||
for i, constraint in enumerate(constraints):
|
||||
constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))
|
||||
constraints_str = utils.add_indent('\n' + '\n'.join(constraints_lines), 2)
|
||||
else:
|
||||
|
||||
@@ -72,4 +72,8 @@ def validate_model_gradient_word_swap_compatibility(model):
|
||||
if isinstance(model, textattack.models.helpers.LSTMForClassification):
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')
|
||||
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')
|
||||
|
||||
def is_word_swap(transformation):
|
||||
from textattack.transformations import WordSwap, WordSwapGradientBased
|
||||
return isinstance(transformation, WordSwap) or isinstance(transformation, WordSwapGradientBased)
|
||||
|
||||
@@ -17,8 +17,8 @@ class CompositeTransformation(Transformation):
|
||||
)
|
||||
return list(new_tokenized_texts)
|
||||
|
||||
def consists_of(self, subclass):
|
||||
def consists_of(self, validator):
|
||||
for transformation in self.transformations:
|
||||
if not transformation.consists_of(subclass):
|
||||
if not transformation.consists_of(validator):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -26,7 +26,7 @@ class Transformation:
|
||||
def extra_repr_keys(self):
|
||||
return []
|
||||
|
||||
def consists_of(self, subclass):
|
||||
return isinstance(self, subclass)
|
||||
def consists_of(self, validator):
|
||||
return validator(self)
|
||||
|
||||
__repr__ = __str__ = default_class_repr
|
||||
|
||||
Reference in New Issue
Block a user