1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

bunch of bug fixes, update tutorials and tests and README

This commit is contained in:
uvafan
2020-05-17 22:40:01 -04:00
parent fcb167b50d
commit 99cea57b0e
14 changed files with 94 additions and 70 deletions

View File

@@ -120,7 +120,7 @@ A `Transformation` takes as input a `TokenizedText` and returns a list of possib
### Search Methods
A search method is currently implemented in an extension of the `Attack` class, through implementing the `attack_one` method. The `get_transformations` function takes as input a `TokenizedText` object and outputs a list of possible transformations, filtered to those that meet all of the attack's constraints. A search consists of successive calls to `get_transformations` until the search succeeds or is exhausted.
A search method takes as input an initial `goal_function` result and returns a final `goal_function` result. The search is given access to the `get_transformations` function, which takes as input a `TokenizedText` object and outputs a list of possible transformations, filtered to those that meet all of the attack's constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined through `get_goal_results`) or is exhausted.
## Contributing to TextAttack

View File

@@ -13,7 +13,7 @@
"The **goal function** determines if the attack is successful or not. One common goal function is **untargeted classification**, where the attack tries to perturb an input to change its classification. \n",
"\n",
"### Search method\n",
"The **search method** explores the space of potential transformations and tries to locate a successful perturbation. Greedy search, beam search, and brute-force search are all examples of search methods. Since the search method is the backbone of the attack, the term \"search\" is often substituted with \"attack method\" or just \"attack\". In TextAttack, all three of those terms (search, attack, attack method) mean the same thing.\n",
"The **search method** explores the space of potential transformations and tries to locate a successful perturbation. Greedy search, beam search, and brute-force search are all examples of search methods.\n",
"\n",
"### Transformation\n",
"A **transformation** takes a text input and transforms it, replacing words or phrases with similar ones, while trying not to change the meaning. Paraphrase and synonym substitution are two broad classes of transformations.\n",
@@ -28,7 +28,7 @@
"source": [
"### A custom transformation\n",
"\n",
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\". (These don't intend to preserve the grammaticality of inputs.) Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
"\n",
"\n",
"### Banana word swap 🍌\n",
@@ -38,7 +38,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -111,14 +111,19 @@
"metadata": {},
"outputs": [],
"source": [
"from textattack.search_methods import GreedyWordSwap\n",
"from textattack.search_methods import GreedySearch\n",
"from textattack.constraints.semantics import RepeatModification, StopwordModification\n",
"from textattack.shared import Attack\n",
"\n",
"# We're going to use our Banana word swap class as the attack transformation.\n",
"transformation = BananaWordSwap() \n",
"# And, we don't want to use any constraints.\n",
"constraints = []\n",
"# Now, let's make the attack using these parameters:\n",
"attack = GreedyWordSwap(goal_function, transformation, constraints)"
"# We'll constrain modificaiton of already modified indices and stopwords\n",
"constraints = [RepeatModification(),\n",
" StopwordModification()]\n",
"# We'll use the Greedy search method\n",
"search_method = GreedySearch()\n",
"# Now, let's make the attack from the 4 components:\n",
"attack = Attack(goal_function, constraints, transformation, search_method)"
]
},
{
@@ -137,12 +142,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"GreedyWordSwap(\n",
"Attack(\n",
" (search_method): GreedySearch\n",
" (goal_function): UntargetedClassification\n",
" (transformation): BananaWordSwap(\n",
" (replace_stopwords): False\n",
" )\n",
" (constraints): None\n",
" (transformation): BananaWordSwap\n",
" (constraints): \n",
" (0): RepeatModification\n",
" (1): StopwordModification\n",
" (is_black_box): True\n",
")\n"
]
@@ -170,7 +176,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"12it [00:01, 10.31it/s] \n"
"12it [00:00, 13.71it/s] \n"
]
}
],
@@ -198,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -306,9 +312,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"display_name": "Python 3",
"language": "python",
"name": "build_central"
"name": "python3"
},
"language_info": {
"codemirror_mode": {

View File

@@ -52,19 +52,19 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /u/jm8wx/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /u/edl9cy/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n",
"[nltk_data] Downloading package maxent_ne_chunker to\n",
"[nltk_data] /u/jm8wx/nltk_data...\n",
"[nltk_data] Package maxent_ne_chunker is already up-to-date!\n",
"[nltk_data] Downloading package words to /u/jm8wx/nltk_data...\n",
"[nltk_data] /u/edl9cy/nltk_data...\n",
"[nltk_data] Unzipping chunkers/maxent_ne_chunker.zip.\n",
"[nltk_data] Downloading package words to /u/edl9cy/nltk_data...\n",
"[nltk_data] Package words is already up-to-date!\n"
]
},
@@ -74,7 +74,7 @@
"True"
]
},
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -97,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -151,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -188,7 +188,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -214,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -236,7 +236,7 @@
" ('.', '.')]"
]
},
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -264,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -273,7 +273,7 @@
"class NamedEntityConstraint(Constraint):\n",
" \"\"\" A constraint that ensures `x_adv` only substitutes named entities from `x` with other named entities.\n",
" \"\"\"\n",
" def __call__(self, x, x_adv, original_text=None):\n",
" def _check_constraint(self, x, x_adv, original_text=None):\n",
" x_entities = get_entities(x.text)\n",
" x_adv_entities = get_entities(x_adv.text)\n",
" # If there aren't named entities, let's return False (the attack\n",
@@ -316,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -324,7 +324,7 @@
"output_type": "stream",
"text": [
"\u001b[34;1mtextattack\u001b[0m: Downloading https://textattack.s3.amazonaws.com/models/classification/lstm/yelp_polarity.\n",
"100%|██████████| 297M/297M [00:06<00:00, 48.3MB/s] \n",
"100%|██████████| 297M/297M [00:10<00:00, 28.5MB/s] \n",
"\u001b[34;1mtextattack\u001b[0m: Unzipping file path_to_zip_file to unzipped_folder_path.\n",
"\u001b[34;1mtextattack\u001b[0m: Successfully saved models/classification/lstm/yelp_polarity to cache.\n",
"\u001b[34;1mtextattack\u001b[0m: Goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'> matches model LSTMForYelpSentimentClassification.\n"
@@ -344,22 +344,24 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GreedyWordSwap(\n",
"Attack(\n",
" (search_method): GreedySearch\n",
" (goal_function): UntargetedClassification\n",
" (transformation): WordSwapEmbedding(\n",
" (max_candidates): 15\n",
" (embedding_type): paragramcf\n",
" (replace_stopwords): False\n",
" )\n",
" (constraints): \n",
" (0): NamedEntityConstraint\n",
" (1): RepeatModification\n",
" (2): StopwordModification\n",
" (is_black_box): True\n",
")\n"
]
@@ -367,20 +369,31 @@
],
"source": [
"from textattack.transformations import WordSwapEmbedding\n",
"from textattack.search_methods import GreedyWordSwap\n",
"from textattack.search_methods import GreedySearch\n",
"from textattack.constraints.semantics import RepeatModification, StopwordModification\n",
"from textattack.shared import Attack\n",
"\n",
"# We're going to the `WordSwapEmbedding` transformation. Using the default settings, this\n",
"# will try substituting words with their neighbors in the counter-fitted embedding space. \n",
"transformation = WordSwapEmbedding(max_candidates=15) \n",
"# Now, let's make the attack using these parameters. And add one constraint: our \n",
"# custom NamedEntityConstraint.\n",
"attack = GreedyWordSwap(goal_function, transformation, constraints=[NamedEntityConstraint()])\n",
"\n",
"# We'll use the greedy search method again\n",
"search_method = GreedySearch()\n",
"\n",
"# Our constraints will be the same as Tutorial 1, plus the named entity constraint\n",
"constraints = [RepeatModification(),\n",
" StopwordModification(),\n",
" NamedEntityConstraint()]\n",
"\n",
"# Now, let's make the attack using these parameters. \n",
"attack = Attack(goal_function, constraints, transformation, search_method)\n",
"\n",
"print(attack)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -389,7 +402,7 @@
"True"
]
},
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -534,9 +547,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"display_name": "Python 3",
"language": "python",
"name": "build_central"
"name": "python3"
},
"language_info": {
"codemirror_mode": {

View File

@@ -51,7 +51,7 @@ register_test('python -m textattack --model lstm-mr --recipe deepwordbug --num-e
#
register_test(('python -m textattack --attack-n --goal-function targeted-classification:target_class=2 '
'--enable-csv --model bert-mnli --num-examples 4 --transformation word-swap-wordnet '
'--constraints lang-tool --attack beam-search:beam_width=2'),
'--constraints lang-tool repeat stopword --search beam-search:beam_width=2'),
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
output_file='local_tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_4.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
@@ -67,10 +67,11 @@ register_test(('python -m textattack --attack-n --goal-function targeted-classif
#
register_test(('python -m textattack --attack-n --goal-function non-overlapping-output '
'--model t5-en2de --num-examples 6 --transformation word-swap-random-char-substitution '
'--constraints edit-distance:12 words-perturbed:max_percent=0.75 --attack greedy-word'),
'--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword '
'--search greedy'),
name='run_attack_nonoverlapping_t5en2de_randomcharsub_editdistance_wordsperturbed_greedyword',
output_file='local_tests/sample_outputs/run_attack_nonoverlapping_t5ende_editdistance_bleu.txt',
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with'
'enable_csv and attack_n set, using the WordNet transformation and beam '
'search with beam width 2, using language tool constraint, on 10 samples')
)
)

View File

@@ -3,7 +3,7 @@ import nltk
from textattack.constraints import Constraint
from textattack.shared import TokenizedText
from textattack.transformations import WordSwap
from textattack.shared.validators import is_word_swap
class PartOfSpeech(Constraint):
""" Constraints word swaps to only swap words with the same part of speech.
@@ -17,7 +17,7 @@ class PartOfSpeech(Constraint):
self._pos_tag_cache = lru.LRU(2**14)
def check_compatibility(self, transformation):
return isinstance(transformation, WordSwap)
return transformation.consists_of(is_word_swap)
def _can_replace_pos(self, pos_a, pos_b):
# Two tags are interchangeable when they are equal, or -- when
# `self.allow_verb_noun_swap` is truthy -- when both belong to {'NOUN', 'VERB'}.
return (pos_a == pos_b) or (self.allow_verb_noun_swap and set([pos_a,pos_b]) <= set(['NOUN','VERB']))

View File

@@ -2,8 +2,8 @@
"""
from textattack.shared.utils import default_class_repr
from textattack.transformations import WordSwap
from textattack.constraints import ModificationConstraint
from textattack.shared.validators import is_word_swap
class StopwordModification(ModificationConstraint):
"""
@@ -28,7 +28,4 @@ class StopwordModification(ModificationConstraint):
Args:
transformation: The transformation to check compatibility with.
"""
return isinstance(transformation, WordSwap)
def extra_repr_keys(self):
return ['textfooler_stopwords']
return transformation.consists_of(is_word_swap)

View File

@@ -6,7 +6,7 @@ import torch
from textattack.shared import utils
from textattack.constraints import Constraint
from textattack.shared import TokenizedText
from textattack.transformations import WordSwap
from textattack.shared.validators import is_word_swap
class WordEmbeddingDistance(Constraint):
"""
@@ -99,7 +99,7 @@ class WordEmbeddingDistance(Constraint):
return mse_dist
def check_compatibility(self, transformation):
return isinstance(transformation, WordSwap)
return transformation.consists_of(is_word_swap)
def _check_constraint(self, x, x_adv, original_text=None):
""" Returns true if (x, x_adv) are closer than `self.min_cos_sim`

View File

@@ -27,6 +27,8 @@ def word_difference_score(s1, s2):
s1_words = get_words_cached(s1)
s2_words = get_words_cached(s2)
min_length = min(len(s1_words), len(s2_words))
if min_length == 0:
return 0
s1_words = s1_words[:min_length]
s2_words = s2_words[:min_length]
return (s1_words != s2_words).sum()
return (s1_words != s2_words).sum()

View File

@@ -9,7 +9,7 @@ import torch
from copy import deepcopy
from textattack.search_methods import SearchMethod
from textattack.transformations import WordSwap
from textattack.shared.validators import is_word_swap
class GeneticAlgorithm(SearchMethod):
"""
@@ -27,7 +27,7 @@ class GeneticAlgorithm(SearchMethod):
self.give_up_if_no_improvement = give_up_if_no_improvement
def check_transformation_compatibility(self, transformation):
return transformation.consists_of(WordSwap)
return transformation.consists_of(is_word_swap)
def _replace_at_index(self, pop_member, idx):
"""

View File

@@ -1,7 +1,7 @@
import numpy as np
from textattack.search_methods import SearchMethod
from textattack.transformations import WordSwap
from textattack.shared.validators import is_word_swap
class GreedyWordSwapWIR(SearchMethod):
"""
@@ -31,7 +31,7 @@ class GreedyWordSwapWIR(SearchMethod):
raise KeyError(f'Word Importance Ranking method {wir_method} not recognized.')
def check_transformation_compatibility(self, transformation):
return transformation.consists_of(WordSwap)
return transformation.consists_of(is_word_swap)
def __call__(self, initial_result):
tokenized_text = initial_result.tokenized_text

View File

@@ -226,8 +226,9 @@ class Attack:
)
# self.constraints
constraints_lines = []
if len(self.constraints):
for i, constraint in enumerate(self.constraints):
constraints = self.constraints + self.modification_constraints
if len(constraints):
for i, constraint in enumerate(constraints):
constraints_lines.append(utils.add_indent(f'({i}): {constraint}', 2))
constraints_str = utils.add_indent('\n' + '\n'.join(constraints_lines), 2)
else:

View File

@@ -72,4 +72,8 @@ def validate_model_gradient_word_swap_compatibility(model):
if isinstance(model, textattack.models.helpers.LSTMForClassification):
return True
else:
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')
raise ValueError(f'Cannot perform GradientBasedWordSwap on model {model}.')
def is_word_swap(transformation):
    """Return True if `transformation` is a word-swap transformation.

    Intended to be passed as the validator to `Transformation.consists_of`.

    Args:
        transformation: The transformation object to inspect.

    Returns:
        bool: True when `transformation` is an instance of `WordSwap` or
        `WordSwapGradientBased`, False otherwise.
    """
    # Imported at call time rather than at module load; presumably this avoids
    # a circular import with `textattack.transformations` -- TODO confirm.
    from textattack.transformations import WordSwap, WordSwapGradientBased

    word_swap_types = (WordSwap, WordSwapGradientBased)
    return isinstance(transformation, word_swap_types)

View File

@@ -17,8 +17,8 @@ class CompositeTransformation(Transformation):
)
return list(new_tokenized_texts)
def consists_of(self, subclass):
def consists_of(self, validator):
for transformation in self.transformations:
if not transformation.consists_of(subclass):
if not transformation.consists_of(validator):
return False
return True

View File

@@ -26,7 +26,7 @@ class Transformation:
def extra_repr_keys(self):
return []
def consists_of(self, subclass):
return isinstance(self, subclass)
def consists_of(self, validator):
return validator(self)
__repr__ = __str__ = default_class_repr