Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

Commit: merge & downgrade test (capture_output) for py36
3  .gitignore (vendored)
@@ -19,9 +19,6 @@ docs/_build/
 # Files from IDES
 .*.py
 
-# CSVs to upload to MTurk
-*.csv
-
 # TF Hub modules
 tensorflow-hub
 
@@ -1,8 +1,8 @@
 language: python
 python:
-- '3.8'
-- '3.7'
-- '3.6'
+- '3.8.3'
+- '3.7.7'
+- '3.6.10'
 cache: pip
 before_install:
 - python --version
78  README.md
@@ -34,6 +34,9 @@ You should be running Python 3.6+ to use this package. A CUDA-compatible GPU is
 pip install textattack
 ```
 
+Once TextAttack is installed, you can run it via command-line (`textattack ...`)
+or via the python module (`python -m textattack ...`).
+
 ### Configuration
 TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
 dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
@@ -41,27 +44,30 @@ environment variable `TA_CACHE_DIR`.
 
 ## Usage
 
+TextAttack's main features can all be accessed via the `textattack` command. Two very
+common commands are `textattack attack <args>` and `textattack augment <args>`. You can see more
+information about all commands using `textattack --help`, or about a specific command using, for example,
+`textattack attack --help`.
+
 ### Running Attacks
 
 The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
 
-We also have a command-line interface for running attacks. See help info and list of arguments with `python -m textattack --help`.
-
 #### Sample Attack Commands
+
+The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
 
 *TextFooler on BERT fine-tuned on the MR sentiment classification dataset*:
 ```bash
-python -m textattack --recipe textfooler --model bert-base-uncased-mr --num-examples 100
+textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 100
 ```
 
 *DeepWordBug on DistilBERT trained on the Quora Question Pairs paraphrase identification dataset*:
 ```bash
-python -m textattack --model distilbert-base-uncased-qqp --recipe deepwordbug --num-examples 100
+textattack attack --model distilbert-base-uncased-qqp --recipe deepwordbug --num-examples 100
 ```
 
 *Beam search with beam width 4 and word embedding transformation and untargeted goal function on an LSTM*:
 ```bash
-python -m textattack --model lstm-mr --num-examples 20 \
+textattack attack --model lstm-mr --num-examples 20 \
  --search-method beam-search:beam_width=4 --transformation word-swap-embedding \
  --constraints repeat stopword max-words-perturbed:max_num_words=2 embedding:min_cos_sim=0.8 part-of-speech \
  --goal-function untargeted-classification
@@ -69,7 +75,7 @@ python -m textattack --model lstm-mr --num-examples 20 \
 
 *Non-overlapping output attack using a greedy word swap and WordNet word substitutions on T5 English-to-German translation:*
 ```bash
-python -m textattack --attack-n --goal-function non-overlapping-output \
+textattack attack --attack-n --goal-function non-overlapping-output \
  --model t5-en2de --num-examples 10 --transformation word-swap-wordnet \
  --constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
  --search greedy
@@ -79,7 +85,9 @@ python -m textattack --attack-n --goal-function non-overlapping-output \
 
 ### Attacks and Papers Implemented ("Attack Recipes")
 
-We include attack recipes which build an attack such that only one command line argument has to be passed. To run an attack recipes, run `python -m textattack --recipe [recipe_name]`
+We include attack recipes which implement attacks from the literature. You can list attack recipes using `textattack list attack-recipes`.
+
+To run an attack recipe: `textattack attack --recipe [recipe_name]`
 
 The first are for classification tasks, like sentiment classification and entailment:
 - **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
@@ -98,12 +106,12 @@ Here are some examples of testing attacks from the literature from the command-line:
 
 *TextFooler against BERT fine-tuned on SST-2:*
 ```bash
-python -m textattack --model bert-base-uncased-sst2 --recipe textfooler --num-examples 10
+textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examples 10
 ```
 
 *seq2sick (black-box) against T5 fine-tuned for English-German translation:*
 ```bash
-python -m textattack --recipe seq2sick --model t5-en2de --num-examples 100
+textattack attack --recipe seq2sick --model t5-en2de --num-examples 100
 ```
 
 ### Augmenting Text
|
||||
### Augmenting Text
|
||||
@@ -115,8 +123,48 @@ for data augmentation:
|
||||
- `textattack.EmbeddingAugmenter` augments text by replacing words with neighbors in the counter-fitted embedding space, with a constraint to ensure their cosine similarity is at least 0.8
|
||||
- `textattack.CharSwapAugmenter` augments text by substituting, deleting, inserting, and swapping adjacent characters
|
||||
|
||||
All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
|
||||
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter`:
|
||||
#### Augmentation Command-Line Interface
|
||||
The easiest way to use our data augmentation tools is with `textattack augment <args>`. `textattack augment`
|
||||
takes an input CSV file and text column to augment, along with the number of words to change per augmentation
|
||||
and the number of augmentations per input example. It outputs a CSV in the same format with all the augmentation
|
||||
examples corresponding to the proper columns.
|
||||
|
||||
For example, given the following as `examples.csv`:
|
||||
|
||||
```csv
|
||||
"text",label
|
||||
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
|
||||
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
|
||||
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
|
||||
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
|
||||
"it's a mystery how the movie could be released in this condition .", 0
|
||||
```
|
||||
|
||||
The command `textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 --transformations-per-example 2 --exclude-original`
|
||||
will augment the `text` column with four swaps per augmentation, twice as many augmentations as original inputs, and exclude the original inputs from the
|
||||
output CSV. (All of this will be saved to `augment.csv` by default.)
|
||||
|
||||
After augmentation, here are the contents of `augment.csv`:
|
||||
```csv
|
||||
text,label
|
||||
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
|
||||
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
|
||||
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
|
||||
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
|
||||
take care of my cat offerings a pleasantly several slice of asia cinema .,1
|
||||
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
|
||||
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
|
||||
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
|
||||
it's a enigma how the film wo be releases in this condition .,0
|
||||
it's a enigma how the filmmaking wo be publicized in this condition .,0
|
||||
```
|
||||
|
||||
The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data.
|
||||
|
||||
#### Augmentation Python Interface
|
||||
In addition to the command-line interface, you can augment text dynamically by importing the
|
||||
`Augmenter` in your own code. All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
|
||||
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter` in a python script:
|
||||
|
||||
```python
|
||||
>>> from textattack.augmentation import EmbeddingAugmenter
|
||||
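The hunk cuts off at the import line; for reference, typical usage of that interface looks like the following sketch (the exact augmentations returned depend on the counter-fitted embedding's nearest neighbors):

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter()
>>> augmenter.augment("What I cannot create, I do not understand.")
['What I cannot create, I do not comprehend.']  # actual neighbor swaps vary
```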
@@ -142,12 +190,12 @@ TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that
 TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct
 dataset to the correct model. We include various pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/)
 tasks, as well as some common classification datasets, translation, and summarization. You can
-see the full list of provided models & datasets via `python -m textattack --help`.
+see the full list of provided models & datasets via `textattack attack --help`.
 
 Here's an example of using one of the built-in models:
 
 ```bash
-python -m textattack --model roberta-base-sst2 --recipe textfooler --num-examples 10
+textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10
 ```
 
 #### HuggingFace support: `transformers` models and `nlp` datasets
@@ -157,7 +205,7 @@ and datasets from the [`nlp` package](https://github.com/huggingface/nlp)! Here'
 and attacking a pre-trained model and dataset:
 
 ```bash
-python -m textattack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
+textattack attack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
 ```
 
 You can explore other pre-trained models using the `--model_from_huggingface` argument, or other datasets by changing
@@ -3,7 +3,7 @@ Attack Recipes
 
 We provide a number of pre-built attack recipes. To run an attack recipe, run::
 
-    python -m textattack --recipe [recipe_name]
+    textattack attack --recipe [recipe_name]
 
 Alzantot
 ###########
@@ -8,7 +8,7 @@ To use TextAttack, you must be running Python 3.6+. A CUDA-compatible GPU is opt
 
 You're now all set to use TextAttack! Try running an attack from the command line::
 
-    python -m textattack --recipe textfooler --model bert-mr --num-examples 10
+    textattack attack --recipe textfooler --model bert-mr --num-examples 10
 
 This will run an attack using the TextFooler_ recipe, attacking BERT fine-tuned on the MR dataset. It will attack the first 10 samples. Once everything downloads and starts running, you should see attack results print to ``stdout``.
1  examples/augmentation/.gitignore (vendored, new file)
@@ -0,0 +1 @@
augment.csv # Don't commit the output file of this command
11  examples/augmentation/augment.csv (new file)
@@ -0,0 +1,11 @@
text,label
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
take care of my cat offerings a pleasantly several slice of asia cinema .,1
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
it's a enigma how the film wo be releases in this condition .,0
it's a enigma how the filmmaking wo be publicized in this condition .,0
2  examples/augmentation/augment.sh (new executable file)
@@ -0,0 +1,2 @@
#!/bin/bash
textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 --transformations-per-example 2 --exclude-original
6  examples/augmentation/examples.csv (new file)
@@ -0,0 +1,6 @@
"text",label
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
"it's a mystery how the movie could be released in this condition .", 0
@@ -20,3 +20,4 @@ terminaltables
 tqdm
 visdom
 wandb
+flair
4  setup.py
@@ -28,7 +28,9 @@ setuptools.setup(
             "wandb*",
         ]
     ),
-    entry_points={"console_scripts": ["textattack=textattack.__main__:main"],},
+    entry_points={
+        "console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
+    },
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
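For readers unfamiliar with setuptools entry points: `"textattack=textattack.commands.textattack_cli:main"` tells pip to generate a `textattack` executable that imports `textattack.commands.textattack_cli` and calls its `main()`. The generated wrapper behaves roughly like this sketch (not the literal file setuptools writes):

```python
import sys

from textattack.commands.textattack_cli import main

if __name__ == "__main__":
    sys.exit(main())
```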
4  tests/sample_inputs/augment.csv.txt (new file)
@@ -0,0 +1,4 @@
text,label,another_column
"For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory", 5, "some text that needs to be preserved"
"Signer of the Dec. of Indep., framer of the Constitution of Mass., second President of the United States", -13, "the answer to this question is John Adams"
"In the title of an Aesop fable, this insect shared billing with a grasshopper", 1111, "these are from a jeopardy question, and this one is the ant"
10  tests/sample_outputs/augment_test.csv.txt (new file)
@@ -0,0 +1,10 @@
text,label,another_column
"For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory",5,some text that needs to be preserved
"For the last 8 years of his lives, Galileo was under habitation arrest for espousing this man's theory",5,some text that needs to be preserved
"For the last 8 yr of his life, Galileo was under house apprehending for espousing this man's theory",5,some text that needs to be preserved
"Signer of the Dec. of Indep., framer of the Constitution of Mass., second President of the United States",-13,the answer to this question is John Adams
"Signer of the Dec. of Indep., framer of the Constitution of Masse., seconds President of the United States",-13,the answer to this question is John Adams
"Signer of the Dec. of Indep., framer of the Constitutions of Mass., second Chairwoman of the United States",-13,the answer to this question is John Adams
"In the title of an Aesop fable, this insect shared billing with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"
"Among the title of an Aesop fable, this beetle shared billing with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"
"In the title of an Aesop fable, this insect exchanging invoice with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"
@@ -15,6 +15,7 @@ Attack(
     (include_unknown_words):  True
   )
   (1): PartOfSpeech(
+    (tagger_type):  nltk
     (tagset):  universal
     (allow_verb_noun_swap):  True
   )
3  tests/sample_outputs/list_augmentation_recipes.txt (new file)
@@ -0,0 +1,3 @@
[94mcharswap[0m (textattack.augmentation.CharSwapAugmenter)
[94membedding[0m (textattack.augmentation.EmbeddingAugmenter)
[94mwordnet[0m (textattack.augmentation.WordNetAugmenter)
66  tests/sample_outputs/run_attack_flair_pos_tagger.txt (new file)
@@ -0,0 +1,66 @@
/.*/Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  unk
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapEmbedding(
    (max_candidates):  15
    (embedding_type):  paragramcf
  )
  (constraints): 
    (0): WordEmbeddingDistance(
        (embedding_type):  paragramcf
        (min_cos_sim):  0.8
        (cased):  False
        (include_unknown_words):  True
      )
    (1): PartOfSpeech(
        (tagger_type):  flair
        (tagset):  universal
        (allow_verb_noun_swap):  True
      )
    (2): RepeatModification
    (3): StopwordModification
  (is_black_box):  True
)

--------------------------------------------- Result 1 ---------------------------------------------
[92m1[0m-->[91m[FAILED][0m

this is a film well worth seeing , talking and singing heads and all .


--------------------------------------------- Result 2 ---------------------------------------------
[92m1[0m-->[91m0[0m

what really [92msurprises[0m about wisegirls is its low-[92mkey[0m quality and [92mgenuine[0m [92mtenderness[0m .

what really [91mdumbfounded[0m about wisegirls is its low-[91mvital[0m quality and [91mveritable[0m [91msensibility[0m .


--------------------------------------------- Result 3 ---------------------------------------------
[92m1[0m-->[91m[FAILED][0m

( wendigo is ) why we go to the cinema : to be fed through the eye , the heart , the mind .


--------------------------------------------- Result 4 ---------------------------------------------
[92m1[0m-->[91m[FAILED][0m

one of the greatest family-oriented , fantasy-adventure movies ever .



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 1      |
| Number of failed attacks:     | 3      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 75.0%  |
| Attack success rate:          | 25.0%  |
| Average perturbed word %:     | 30.77% |
| Average num. words per input: | 13.5   |
| Avg num queries:              | 45.0   |
+-------------------------------+--------+
@@ -7,7 +7,7 @@ Attack(
   (constraints): None
   (is_black_box):  True
 )
 /.*/
 
 --------------------------------------------- Result 1 ---------------------------------------------
 [92m1[0m-->[37m[SKIPPED][0m
@@ -17,6 +17,7 @@ Attack(
     (include_unknown_words):  True
   )
   (2): PartOfSpeech(
+    (tagger_type):  nltk
     (tagset):  universal
     (allow_verb_noun_swap):  True
   )
38  tests/test_command_line/helpers.py (new file)
@@ -0,0 +1,38 @@
import shlex
import subprocess


def run_command_and_get_result(command):
    """ Runs a command in the console and gets the result.

    Command can be a string (single command) or a tuple of strings (multiple
    commands). In the multi-command setting, commands will be joined
    together with a pipe, and the output of the last command will be
    returned.
    """
    from subprocess import PIPE

    # run command
    if isinstance(command, tuple):
        # Support pipes via tuple of commands
        procs = []
        for i in range(len(command) - 1):
            if i == 0:
                proc = subprocess.Popen(shlex.split(command[i]), stdout=PIPE)
            else:
                proc = subprocess.Popen(
                    shlex.split(command[i]),
                    stdout=subprocess.PIPE,
                    stdin=procs[-1].stdout,
                )
            procs.append(proc)
        # Run last command
        result = subprocess.run(
            shlex.split(command[-1]), stdin=procs[-1].stdout, stdout=PIPE, stderr=PIPE
        )
        # Wait for all intermediate processes
        for proc in procs:
            proc.wait()
    else:
        result = subprocess.run(shlex.split(command), stdout=PIPE, stderr=PIPE)
    return result
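Note the explicit `stdout=PIPE, stderr=PIPE` arguments: this is the Python 3.6-compatible spelling of `capture_output=True`, which `subprocess.run` only accepts from Python 3.7 onward. That swap is the "downgrade test (capture_output) for py36" in the commit title. The equivalence, in short:

```python
import subprocess
from subprocess import PIPE

# Python 3.7+ only:
#   result = subprocess.run(["echo", "hi"], capture_output=True)
# Python 3.6-compatible equivalent (as used in helpers.py above):
result = subprocess.run(["echo", "hi"], stdout=PIPE, stderr=PIPE)
print(result.stdout.decode())  # prints "hi"
```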
@@ -1,14 +1,16 @@
 import pdb
 import re
 import shlex
 import subprocess
 
 import pytest
 
+from helpers import run_command_and_get_result
+
 DEBUG = False
 
 """
 Attack command-line tests in the format (name, args, sample_output_file)
 """
 
 attack_test_params = [
     #
     # test loading an attack from file
@@ -16,7 +18,7 @@ attack_test_params = [
     (
         "attack_from_file",
         (
-            "python -m textattack --model cnn-imdb "
+            "textattack attack --model cnn-imdb "
             "--attack-from-file tests/sample_inputs/attack_from_file.py:Attack "
             "--num-examples 2 --num-examples-offset 18 --attack-n"
         ),
@@ -29,7 +31,7 @@ attack_test_params = [
         "interactive_mode",
         (
             'printf "All that glitters is not gold\nq\n"',
-            "python -m textattack --recipe textfooler --model bert-base-uncased-imdb --interactive",
+            "textattack attack --recipe textfooler --model bert-base-uncased-imdb --interactive",
         ),
         "tests/sample_outputs/interactive_mode.txt",
     ),
@@ -39,7 +41,7 @@ attack_test_params = [
     (
         "attack_from_transformers",
         (
-            "python -m textattack --model-from-huggingface "
+            "textattack attack --model-from-huggingface "
             "distilbert-base-uncased-finetuned-sst-2-english "
            "--dataset-from-nlp glue:sst2:train --recipe deepwordbug --num-examples 3"
         ),
@@ -51,7 +53,8 @@ attack_test_params = [
     (
         "load_model_and_dataset_from_file",
         (
-            "python -m textattack --model-from-file tests/sample_inputs/sst_model_and_dataset.py "
+            "textattack attack "
+            "--model-from-file tests/sample_inputs/sst_model_and_dataset.py "
             "--dataset-from-file tests/sample_inputs/sst_model_and_dataset.py "
             "--recipe deepwordbug --num-examples 3"
         ),
@@ -63,7 +66,7 @@ attack_test_params = [
     (
         "run_attack_hotflip_lstm_mr_4",
         (
-            "python -m textattack --model lstm-mr --recipe hotflip "
+            "textattack attack --model lstm-mr --recipe hotflip "
             "--num-examples 4 --num-examples-offset 3"
         ),
         "tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
@@ -74,7 +77,7 @@ attack_test_params = [
     (
         "run_attack_deepwordbug_lstm_mr_2",
         (
-            "python -m textattack --model lstm-mr --recipe deepwordbug --num-examples 2 --attack-n"
+            "textattack attack --model lstm-mr --recipe deepwordbug --num-examples 2 --attack-n"
         ),
         "tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt",
     ),
@@ -87,21 +90,34 @@ attack_test_params = [
     (
         "run_attack_targeted_mnli_misc",
         (
-            "python -m textattack --attack-n --goal-function targeted-classification:target_class=2 "
+            "textattack attack --attack-n --goal-function targeted-classification:target_class=2 "
             "--enable-csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation word-swap-wordnet "
             "--constraints lang-tool repeat stopword --search beam-search:beam_width=2"
         ),
         "tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt",
     ),
     #
+    # fmt: off
+    # test: run_attack untargeted classification on BERT MR using word embedding transformation and greedy-word-WIR search
+    # using Flair's part-of-speech tagger as constraint.
+    #
+    (
+        "run_attack_flair_pos_tagger",
+        (
+            "textattack attack --model bert-base-uncased-mr --search greedy-word-wir --transformation word-swap-embedding "
+            "--constraints repeat stopword embedding:min_cos_sim=0.8 part-of-speech:tagger_type=\\'flair\\' "
+            "--num-examples 4 --num-examples-offset 10"
+        ),
+        "tests/sample_outputs/run_attack_flair_pos_tagger.txt",
+    ),
+    # fmt: on
+    #
 ]
 
 
 @pytest.mark.parametrize("name, command, sample_output_file", attack_test_params)
 @pytest.mark.slow
-def test_command_line_attack(capsys, name, command, sample_output_file):
+def test_command_line_attack(name, command, sample_output_file):
     """ Runs attack tests and compares their outputs to a reference file.
     """
     # read in file and create regex
@@ -109,29 +125,7 @@ def test_command_line_attack(capsys, name, command, sample_output_file):
     print("desired_output =>", desired_output)
     # regex in sample file look like /.*/
     desired_re = re.escape(desired_output).replace("/\\.\\*/", ".*")
-    # run command
-    if isinstance(command, tuple):
-        # Support pipes via tuple of commands
-        procs = []
-        for i in range(len(command) - 1):
-            if i == 0:
-                proc = subprocess.Popen(shlex.split(command[i]), stdout=subprocess.PIPE)
-            else:
-                proc = subprocess.Popen(
-                    shlex.split(command[i]),
-                    stdout=subprocess.PIPE,
-                    stdin=procs[-1].stdout,
-                )
-            procs.append(proc)
-        # Run last commmand
-        result = subprocess.run(
-            shlex.split(command[-1]), stdin=procs[-1].stdout, capture_output=True
-        )
-        # Wait for all intermittent processes
-        for proc in procs:
-            proc.wait()
-    else:
-        result = subprocess.run(shlex.split(command), capture_output=True)
+    result = run_command_and_get_result(command)
     # get output and check match
     assert result.stdout is not None
     stdout = result.stdout.decode().strip()
@@ -141,7 +135,5 @@ def test_command_line_attack(capsys, name, command, sample_output_file):
     print("stderr =>", stderr)
 
     if DEBUG and not re.match(desired_re, stdout, flags=re.S):
-        import pdb
-
         pdb.set_trace()
     assert re.match(desired_re, stdout, flags=re.S)
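The comparison against files in `tests/sample_outputs/` works by escaping the whole reference file and then re-enabling any `/.*/` tokens as wildcards (used for nondeterministic spans such as timestamps). A self-contained illustration with made-up strings:

```python
import re

desired_output = "/.*/Attack(\n  (is_black_box):  True\n)"
desired_re = re.escape(desired_output).replace("/\\.\\*/", ".*")

stdout = "[some timestamp] Attack(\n  (is_black_box):  True\n)"
assert re.match(desired_re, stdout, flags=re.S)
```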
37  tests/test_command_line/test_augment.py (new file)
@@ -0,0 +1,37 @@
import pytest

from helpers import run_command_and_get_result

augment_test_params = [
    (
        "simple_augment_test",
        "textattack augment --csv tests/sample_inputs/augment.csv.txt --input-column text --outfile augment_test.csv --overwrite",
        "augment_test.csv",
        "tests/sample_outputs/augment_test.csv.txt",
    )
]


@pytest.mark.parametrize(
    "name, command, outfile, sample_output_file", augment_test_params
)
@pytest.mark.slow
def test_command_line_augmentation(name, command, outfile, sample_output_file):
    import os

    desired_text = open(sample_output_file).read().strip()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    stdout = result.stdout.decode().strip()
    assert stdout == ""

    assert result.stderr is not None
    stderr = result.stderr.decode().strip()
    assert "Wrote 9 augmentations to augment_test.csv" in stderr

    # Ensure CSV file exists, then delete it.
    assert os.path.exists(outfile)
    os.remove(outfile)
28  tests/test_command_line/test_list.py (new file)
@@ -0,0 +1,28 @@
import pytest

from helpers import run_command_and_get_result

list_test_params = [
    (
        "list_augmentation_recipes",
        "textattack list augmentation-recipes",
        "tests/sample_outputs/list_augmentation_recipes.txt",
    )
]


@pytest.mark.parametrize("name, command, sample_output_file", list_test_params)
def test_command_line_list(name, command, sample_output_file):
    desired_text = open(sample_output_file).read().strip()

    # Run command and validate outputs.
    result = run_command_and_get_result(command)

    assert result.stdout is not None
    assert result.stderr is not None

    stdout = result.stdout.decode().strip()
    stderr = result.stderr.decode().strip()

    assert stderr == ""
    assert stdout == desired_text
@@ -3,6 +3,7 @@ name = "textattack"
 from . import attack_recipes
 from . import attack_results
 from . import augmentation
+from . import commands
 from . import constraints
 from . import datasets
 from . import goal_functions
@@ -1,26 +1,6 @@
 #!/usr/bin/env python
 
-"""
-The TextAttack main module:
-
-    A command line parser to run an attack from user specifications.
-"""
-
-
-from textattack.shared.scripts.attack_args_parser import get_args
-from textattack.shared.scripts.run_attack_parallel import run as run_parallel
-from textattack.shared.scripts.run_attack_single_threaded import (
-    run as run_single_threaded,
-)
-
-
-def main():
-    args = get_args()
-    if args.parallel:
-        run_parallel(args)
-    else:
-        run_single_threaded(args)
-
-
 if __name__ == "__main__":
-    main()
+    import textattack
+
+    textattack.commands.textattack_cli.main()
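With `__main__.py` reduced to a shim over the CLI, the module and console-script invocations are now equivalent:

```bash
python -m textattack attack --recipe textfooler --model lstm-mr --num-examples 2
textattack attack --recipe textfooler --model lstm-mr --num-examples 2
```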
@@ -87,7 +87,7 @@ class Augmenter:
             if words_swapped == self.num_words_to_swap:
                 break
             all_transformed_texts.add(current_text)
-        return [t.text for t in all_transformed_texts]
+        return sorted([at.printable_text() for at in all_transformed_texts])
 
     def augment_many(self, text_list, show_progress=False):
         """
16  textattack/commands/__init__.py (new file)
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser, HelpFormatter


class TextAttackCommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()


from . import textattack_cli
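Every subcommand follows the same contract: `register_subcommand` wires an argparse subparser, and `run` executes with the parsed arguments. A minimal hypothetical subcommand in the same shape as the real ones below (the `echo` command does not exist in TextAttack; it is illustration only):

```python
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from textattack.commands import TextAttackCommand


class EchoCommand(TextAttackCommand):
    """Hypothetical subcommand, for illustration only."""

    def run(self, args):
        print(args.message)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "echo",
            help="print a message",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument("--message", type=str, default="hello")
        # The CLI later dispatches via args.func.run(args).
        parser.set_defaults(func=EchoCommand())
```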
2  textattack/commands/attack/__init__.py (new file)
@@ -0,0 +1,2 @@
from .attack_command import AttackCommand
from .attack_resume_command import AttackResumeCommand
@@ -1,6 +1,6 @@
 import textattack
 
-RECIPE_NAMES = {
+ATTACK_RECIPE_NAMES = {
     "alzantot": "textattack.attack_recipes.Alzantot2018",
     "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
     "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
@@ -244,7 +244,7 @@ CONSTRAINT_CLASS_NAMES = {
     "stopword": "textattack.constraints.pre_transformation.StopwordModification",
 }
 
-SEARCH_CLASS_NAMES = {
+SEARCH_METHOD_CLASS_NAMES = {
     "beam-search": "textattack.search_methods.BeamSearch",
     "greedy": "textattack.search_methods.GreedySearch",
     "ga-word": "textattack.search_methods.GeneticAlgorithm",
@@ -11,35 +11,15 @@ import numpy as np
 import torch
 
 import textattack
-from textattack.shared.scripts.attack_args import *
+
+from .attack_args import *
 
 
-def set_seed(random_seed):
-    random.seed(random_seed)
-    np.random.seed(random_seed)
-    torch.manual_seed(random_seed)
-
-
-def get_args():
-    # Parser for regular arguments
-    parser = argparse.ArgumentParser(
-        description="A commandline parser for TextAttack",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
-        WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
-    )
-    parser.add_argument(
-        "--transformation",
-        type=str,
-        required=False,
-        default="word-swap-embedding",
-        choices=transformation_names,
-        help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}. Choices: '
-        + str(transformation_names),
-    )
-
+def add_model_args(parser):
+    """ Adds model-related arguments to an argparser. This is useful because we
+    want to load pretrained models using multiple different parsers that
+    share these, but not all, arguments.
+    """
+    model_group = parser.add_mutually_exclusive_group()
 
     model_names = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
@@ -66,6 +46,12 @@
         help="huggingface.co ID of pre-trained model to load",
     )
 
+
+def add_dataset_args(parser):
+    """ Adds dataset-related arguments to an argparser. This is useful because we
+    want to load pretrained models using multiple different parsers that
+    share these, but not all, arguments.
+    """
+    dataset_group = parser.add_mutually_exclusive_group()
     dataset_group.add_argument(
         "--dataset-from-nlp",
@@ -81,48 +67,13 @@
         default=None,
         help="Dataset to load from a file.",
     )
 
-    parser.add_argument(
-        "--constraints",
-        type=str,
-        required=False,
-        nargs="*",
-        default=["repeat", "stopword"],
-        help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
-        + str(CONSTRAINT_CLASS_NAMES.keys()),
-    )
-
-    parser.add_argument(
-        "--out-dir",
-        type=str,
-        required=False,
-        default=None,
-        help="A directory to output results to.",
-    )
-
-    parser.add_argument(
-        "--enable-visdom", action="store_true", help="Enable logging to visdom."
-    )
-
-    parser.add_argument(
-        "--enable-wandb",
+    dataset_group.add_argument(
+        "--shuffle",
         action="store_true",
-        help="Enable logging to Weights & Biases.",
+        required=False,
+        default=False,
+        help="Randomly shuffle the data before attacking",
     )
 
-    parser.add_argument(
-        "--disable-stdout", action="store_true", help="Disable logging to stdout"
-    )
-
-    parser.add_argument(
-        "--enable-csv",
-        nargs="?",
-        default=None,
-        const="fancy",
-        type=str,
-        help="Enable logging to csv. Use --enable-csv plain to remove [[]] around words.",
-    )
-
     parser.add_argument(
         "--num-examples",
         "-n",
@@ -141,162 +92,20 @@
         help="The offset to start at in the dataset.",
     )
 
-    parser.add_argument(
-        "--shuffle",
-        action="store_true",
-        required=False,
-        default=False,
-        help="Randomly shuffle the data before attacking",
-    )
-
-    parser.add_argument(
-        "--interactive",
-        action="store_true",
-        default=False,
-        help="Whether to run attacks interactively.",
-    )
-
-    parser.add_argument(
-        "--attack-n",
-        action="store_true",
-        default=False,
-        help="Whether to run attack until `n` examples have been attacked (not skipped).",
-    )
-
-    parser.add_argument(
-        "--parallel",
-        action="store_true",
-        default=False,
-        help="Run attack using multiple GPUs.",
-    )
-
-    goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
-    parser.add_argument(
-        "--goal-function",
-        "-g",
-        default="untargeted-classification",
-        help=f"The goal function to use. choices: {goal_function_choices}",
-    )
-
-    def str_to_int(s):
-        return sum((ord(c) for c in s))
-
-    parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"))
-
-    parser.add_argument(
-        "--checkpoint-dir",
-        required=False,
-        type=str,
-        default=default_checkpoint_dir(),
-        help="The directory to save checkpoint files.",
-    )
-
-    parser.add_argument(
-        "--checkpoint-interval",
-        required=False,
-        type=int,
-        help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
-    )
-
-    parser.add_argument(
-        "--query-budget",
-        "-q",
-        type=int,
-        default=float("inf"),
-        help="The maximum number of model queries allowed per example attacked.",
-    )
-
-    attack_group = parser.add_mutually_exclusive_group(required=False)
-    search_choices = ", ".join(SEARCH_CLASS_NAMES.keys())
-    attack_group.add_argument(
-        "--search",
-        "--search-method",
-        "-s",
-        type=str,
-        required=False,
-        default="greedy-word-wir",
-        help=f"The search method to use. choices: {search_choices}",
-    )
-    attack_group.add_argument(
-        "--recipe",
-        "--attack-recipe",
-        "-r",
-        type=str,
-        required=False,
-        default=None,
-        help="full attack recipe (overrides provided goal function, transformation & constraints)",
-        choices=RECIPE_NAMES.keys(),
-    )
-    attack_group.add_argument(
-        "--attack-from-file",
-        type=str,
-        required=False,
-        default=None,
-        help="attack to load from file (overrides provided goal function, transformation & constraints)",
-    )
-
-    # Parser for parsing args for resume
-    resume_parser = argparse.ArgumentParser(
-        description="A commandline parser for TextAttack",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    resume_parser.add_argument(
-        "--checkpoint-file",
-        "-f",
-        type=str,
-        required=True,
-        help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'
-        "recover latest checkpoint from either current path or specified directory.",
-    )
-
-    resume_parser.add_argument(
-        "--checkpoint-dir",
-        "-d",
-        required=False,
-        type=str,
-        default=None,
-        help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
-    )
-
-    resume_parser.add_argument(
-        "--checkpoint-interval",
-        "-i",
-        required=False,
-        type=int,
-        help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
-    )
-
-    resume_parser.add_argument(
-        "--parallel",
-        action="store_true",
-        default=False,
-        help="Run attack using multiple GPUs.",
-    )
-
-    # Resume attack from checkpoint.
-    if sys.argv[1:] and sys.argv[1].lower() == "resume":
-        args = resume_parser.parse_args(sys.argv[2:])
-        setattr(args, "checkpoint_resume", True)
-    else:
-        command_line_args = (
-            None if sys.argv[1:] else ["-h"]
-        )  # Default to help with empty arguments.
-        args = parser.parse_args(command_line_args)
-        setattr(args, "checkpoint_resume", False)
-
-    if args.checkpoint_interval and args.shuffle:
-        # Not allowed b/c we cannot recover order of shuffled data
-        raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")
-
-    set_seed(args.random_seed)
-
-    # Shortcuts for huggingface models using --model.
-    if not args.checkpoint_resume and args.model in HUGGINGFACE_DATASET_BY_MODEL:
-        _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
-    elif not args.checkpoint_resume and args.model in TEXTATTACK_DATASET_BY_MODEL:
-        _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
-
-    return args
+
+def load_module_from_file(file_path):
+    """ Uses ``importlib`` to dynamically open a file and load an object from
+    it.
+    """
+    temp_module_name = f"temp_{time.time()}"
+    colored_file_path = textattack.shared.utils.color_text(
+        file_path, color="blue", method="ansi"
+    )
+    textattack.shared.logger.info(f"Loading module from `{colored_file_path}`.")
+    spec = importlib.util.spec_from_file_location(temp_module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
 
 
 def parse_transformation_from_args(args, model):
@@ -374,11 +183,11 @@ def parse_attack_from_args(args):
     if args.recipe:
         if ":" in args.recipe:
             recipe_name, params = args.recipe.split(":")
-            if recipe_name not in RECIPE_NAMES:
+            if recipe_name not in ATTACK_RECIPE_NAMES:
                 raise ValueError(f"Error: unsupported recipe {recipe_name}")
-            recipe = eval(f"{RECIPE_NAMES[recipe_name]}(model, {params})")
-        elif args.recipe in RECIPE_NAMES:
-            recipe = eval(f"{RECIPE_NAMES[args.recipe]}(model)")
+            recipe = eval(f"{ATTACK_RECIPE_NAMES[recipe_name]}(model, {params})")
+        elif args.recipe in ATTACK_RECIPE_NAMES:
+            recipe = eval(f"{ATTACK_RECIPE_NAMES[args.recipe]}(model)")
         else:
             raise ValueError(f"Invalid recipe {args.recipe}")
         recipe.goal_function.query_budget = args.query_budget
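To make the `eval` dispatch concrete: for a command such as `textattack attack --recipe deepwordbug:some_param=2` (the parameter name here is purely illustrative), the parser splits on `:` and evaluates the assembled constructor call:

```python
# recipe_name == "deepwordbug", params == "some_param=2"
# ATTACK_RECIPE_NAMES["deepwordbug"] == "textattack.attack_recipes.DeepWordBugGao2018"
recipe = eval("textattack.attack_recipes.DeepWordBugGao2018(model, some_param=2)")
```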
@@ -388,8 +197,11 @@ def parse_attack_from_args(args):
             attack_file, attack_name = args.attack_from_file.split(":")
         else:
             attack_file, attack_name = args.attack_from_file, "attack"
-        attack_file = attack_file.replace(".py", "").replace("/", ".")
-        attack_module = importlib.import_module(attack_file)
+        attack_module = load_module_from_file(attack_file)
+        if not hasattr(attack_module, attack_name):
+            raise ValueError(
+                f"Loaded `{attack_file}` but could not find `{attack_name}`."
+            )
         attack_func = getattr(attack_module, attack_name)
         return attack_func(model)
     else:
@@ -398,11 +210,11 @@ def parse_attack_from_args(args):
         constraints = parse_constraints_from_args(args)
         if ":" in args.search:
             search_name, params = args.search.split(":")
-            if search_name not in SEARCH_CLASS_NAMES:
+            if search_name not in SEARCH_METHOD_CLASS_NAMES:
                 raise ValueError(f"Error: unsupported search {search_name}")
-            search_method = eval(f"{SEARCH_CLASS_NAMES[search_name]}({params})")
-        elif args.search in SEARCH_CLASS_NAMES:
-            search_method = eval(f"{SEARCH_CLASS_NAMES[args.search]}()")
+            search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")
+        elif args.search in SEARCH_METHOD_CLASS_NAMES:
+            search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[args.search]}()")
         else:
             raise ValueError(f"Error: unsupported attack {args.search}")
         return textattack.shared.Attack(
@@ -427,12 +239,9 @@ def parse_model_from_args(args):
             "tokenizer",
         )
         try:
-            model_file = args.model_from_file.replace(".py", "").replace("/", ".")
-            model_module = importlib.import_module(model_file)
+            model_module = load_module_from_file(args.model_from_file)
         except:
-            raise ValueError(
-                f"Failed to import model or tokenizer from file {args.model_from_file}"
-            )
+            raise ValueError(f"Failed to import file {args.model_from_file}")
         try:
             model = getattr(model_module, model_name)
         except AttributeError:
@@ -492,6 +301,14 @@ def parse_model_from_args(args):
 
 
 def parse_dataset_from_args(args):
+    # Automatically detect dataset for huggingface & textattack models.
+    # This allows us to use the --model shortcut without specifying a dataset.
+    if args.model in HUGGINGFACE_DATASET_BY_MODEL:
+        _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
+    elif args.model in TEXTATTACK_DATASET_BY_MODEL:
+        _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
+
+    # Get dataset from args.
     if args.dataset_from_file:
         textattack.shared.logger.info(
             f"Loading model and tokenizer from file: {args.model_from_file}"
@@ -501,8 +318,7 @@ def parse_dataset_from_args(args):
     else:
         dataset_file, dataset_name = args.dataset_from_file, "dataset"
     try:
-        dataset_file = dataset_file.replace(".py", "").replace("/", ".")
-        dataset_module = importlib.import_module(dataset_file)
+        dataset_module = load_module_from_file(dataset_file)
     except:
         raise ValueError(
             f"Failed to import dataset from file {args.dataset_from_file}"
184  textattack/commands/attack/attack_command.py (new file)
@@ -0,0 +1,184 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.attack.attack_args_helpers import *


class AttackCommand(TextAttackCommand):
    """
    The TextAttack attack module:

        A command line parser to run an attack from user specifications.
    """

    def run(self, args):
        if args.checkpoint_interval and args.shuffle:
            # Not allowed b/c we cannot recover order of shuffled data
            raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")

        textattack.shared.utils.set_seed(args.random_seed)
        args.checkpoint_resume = False

        from textattack.commands.attack.run_attack_parallel import run as run_parallel
        from textattack.commands.attack.run_attack_single_threaded import (
            run as run_single_threaded,
        )

        if args.parallel:
            run_parallel(args)
        else:
            run_single_threaded(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "attack",
            help="run an attack on an NLP model",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
            WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default="word-swap-embedding",
            choices=transformation_names,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}. Choices: '
            + str(transformation_names),
        )

        add_model_args(parser)
        add_dataset_args(parser)

        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=["repeat", "stopword"],
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(CONSTRAINT_CLASS_NAMES.keys()),
        )

        parser.add_argument(
            "--out-dir",
            type=str,
            required=False,
            default=None,
            help="A directory to output results to.",
        )

        parser.add_argument(
            "--enable-visdom", action="store_true", help="Enable logging to visdom."
        )

        parser.add_argument(
            "--enable-wandb",
            action="store_true",
            help="Enable logging to Weights & Biases.",
        )

        parser.add_argument(
            "--disable-stdout", action="store_true", help="Disable logging to stdout"
        )

        parser.add_argument(
            "--enable-csv",
            nargs="?",
            default=None,
            const="fancy",
            type=str,
            help="Enable logging to csv. Use --enable-csv plain to remove [[]] around words.",
        )

        parser.add_argument(
            "--interactive",
            action="store_true",
            default=False,
            help="Whether to run attacks interactively.",
        )

        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=False,
            help="Whether to run attack until `n` examples have been attacked (not skipped).",
        )

        parser.add_argument(
            "--parallel",
            action="store_true",
            default=False,
            help="Run attack using multiple GPUs.",
        )

        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default="untargeted-classification",
            help=f"The goal function to use. choices: {goal_function_choices}",
        )

        def str_to_int(s):
            return sum((ord(c) for c in s))

        parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"))
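        # For reference, this default evaluates to
        # sum(ord(c) for c in "TEXTATTACK") == 765.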

        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_checkpoint_dir(),
            help="The directory to save checkpoint files.",
        )

        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )

        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=float("inf"),
            help="The maximum number of model queries allowed per example attacked.",
        )

        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search",
            "--search-method",
            "-s",
            type=str,
            required=False,
            default="greedy-word-wir",
            help=f"The search method to use. choices: {search_choices}",
        )
        attack_group.add_argument(
            "--recipe",
            "--attack-recipe",
            "-r",
            type=str,
            required=False,
            default=None,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=None,
            help="attack to load from file (overrides provided goal function, transformation & constraints)",
        )

        parser.set_defaults(func=AttackCommand())
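Given each command's `set_defaults(func=...)`, the top-level dispatcher in `textattack/commands/textattack_cli.py` (not shown in this diff) presumably follows the standard argparse subcommand pattern, roughly:

```python
# Sketch only; the real textattack_cli.py is not part of this diff.
from argparse import ArgumentParser

from textattack.commands.attack import AttackCommand, AttackResumeCommand


def main():
    parser = ArgumentParser("TextAttack CLI")
    subparsers = parser.add_subparsers(help="textattack command helpers")
    AttackCommand.register_subcommand(subparsers)
    AttackResumeCommand.register_subcommand(subparsers)
    args = parser.parse_args()
    # Each subcommand registered an instance of itself via set_defaults(func=...).
    args.func.run(args)
```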
70  textattack/commands/attack/attack_resume_command.py (new file)
@@ -0,0 +1,70 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from textattack.commands import TextAttackCommand


class AttackResumeCommand(TextAttackCommand):
    """
    The TextAttack attack resume module:

        A command line parser to resume a checkpointed attack from user specifications.
    """

    def run(self):
        textattack.shared.utils.set_seed(self.random_seed)
        self.checkpoint_resume = True

        # Run attack from checkpoint.
        from textattack.commands.attack.run_attack_parallel import run as run_parallel
        from textattack.commands.attack.run_attack_single_threaded import (
            run as run_single_threaded,
        )

        if self.parallel:
            run_parallel(self)
        else:
            run_single_threaded(self)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        resume_parser = main_parser.add_parser(
            "attack-resume",
            help="resume a checkpointed attack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        # Parser for parsing args for resume
        resume_parser.add_argument(
            "--checkpoint-file",
            "-f",
            type=str,
            required=True,
            help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'
            "recover latest checkpoint from either current path or specified directory.",
        )

        resume_parser.add_argument(
            "--checkpoint-dir",
            "-d",
            required=False,
            type=str,
            default=None,
            help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
        )

        resume_parser.add_argument(
            "--checkpoint-interval",
            "-i",
            required=False,
            type=int,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )

        resume_parser.add_argument(
            "--parallel",
            action="store_true",
            default=False,
            help="Run attack using multiple GPUs.",
        )

        resume_parser.set_defaults(func=AttackResumeCommand())
@@ -11,7 +11,7 @@ import tqdm
 
 import textattack
 
-from .attack_args_parser import *
+from .attack_args_helpers import *
 
 logger = textattack.shared.logger
 
@@ -11,7 +11,7 @@ import tqdm
 
 import textattack
 
-from .attack_args_parser import *
+from .attack_args_helpers import *
 
 logger = textattack.shared.logger
 
textattack/commands/augment.py (new file, 143 lines)
@@ -0,0 +1,143 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import csv
import os
import time

import tqdm

import textattack
from textattack.commands import TextAttackCommand

AUGMENTATION_RECIPE_NAMES = {
    "wordnet": "textattack.augmentation.WordNetAugmenter",
    "embedding": "textattack.augmentation.EmbeddingAugmenter",
    "charswap": "textattack.augmentation.CharSwapAugmenter",
}


class AugmentCommand(TextAttackCommand):
    """
    The TextAttack augment module:

        A command line parser to run data augmentation from user specifications.
    """

    def run(self, args):
        """ Reads in a CSV, performs augmentation, and outputs an augmented CSV.

            Preserves all columns except for the input (augmented) column.
        """
        textattack.shared.utils.set_seed(args.random_seed)
        start_time = time.time()
        # Validate input/output paths.
        if not os.path.exists(args.csv):
            raise FileNotFoundError(f"Can't find CSV at location {args.csv}")
        if os.path.exists(args.outfile):
            if args.overwrite:
                textattack.shared.logger.info(f"Preparing to overwrite {args.outfile}.")
            else:
                raise OSError(f"Outfile {args.outfile} exists and --overwrite not set.")
        # Read in CSV file as a list of dictionaries. Use the CSV sniffer to
        # try and automatically infer the correct CSV format.
        with open(args.csv, "r") as csv_file:
            dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
            csv_file.seek(0)
            rows = [
                row
                for row in csv.DictReader(csv_file, dialect=dialect, skipinitialspace=True)
            ]
        # Validate input column.
        row_keys = set(rows[0].keys())
        if args.input_column not in row_keys:
            raise ValueError(
                f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
            )
        textattack.shared.logger.info(
            f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}."
        )
        # Augment all examples.
        augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
            num_words_to_swap=args.num_words_to_swap,
            transformations_per_example=args.transformations_per_example,
        )
        output_rows = []
        for row in tqdm.tqdm(rows, desc="Augmenting rows"):
            text_input = row[args.input_column]
            if not args.exclude_original:
                output_rows.append(row)
            for augmentation in augmenter.augment(text_input):
                augmented_row = row.copy()
                augmented_row[args.input_column] = augmentation
                output_rows.append(augmented_row)
        # Print to file.
        with open(args.outfile, "w") as outfile:
            csv_writer = csv.writer(
                outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
            )
            # Write header.
            csv_writer.writerow(output_rows[0].keys())
            # Write rows.
            for row in output_rows:
                csv_writer.writerow(row.values())
        textattack.shared.logger.info(
            f"Wrote {len(output_rows)} augmentations to {args.outfile} in {time.time() - start_time}s."
        )

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "augment",
            help="augment text data",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--csv", help="input csv file to augment", type=str, required=True
        )
        parser.add_argument(
            "--input-column",
            "--i",
            help="csv input column to be augmented",
            type=str,
            required=True,
        )
        parser.add_argument(
            "--recipe",
            "--r",
            help="recipe for augmentation",
            type=str,
            default="embedding",
            choices=AUGMENTATION_RECIPE_NAMES.keys(),
        )
        parser.add_argument(
            "--num-words-to-swap",
            "--n",
            help="words to swap out for each augmented example",
            type=int,
            default=2,
        )
        parser.add_argument(
            "--transformations-per-example",
            "--t",
            help="number of augmentations to return for each input",
            type=int,
            default=2,
        )
        parser.add_argument(
            "--outfile", "--o", help="path to outfile", type=str, default="augment.csv"
        )
        parser.add_argument(
            "--exclude-original",
            default=False,
            action="store_true",
            help="exclude original example from augmented CSV",
        )
        parser.add_argument(
            "--overwrite",
            default=False,
            action="store_true",
            help="overwrite output file, if it exists",
        )
        parser.add_argument(
            "--random-seed", default=42, type=int, help="random seed to set"
        )
        parser.set_defaults(func=AugmentCommand())
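The `augment` command is a thin wrapper around the recipes in `AUGMENTATION_RECIPE_NAMES`. A minimal sketch of the same flow from Python, assuming the package from this commit is installed (the sample sentence is illustrative; the keyword arguments are exactly the ones the command passes through from `--num-words-to-swap` and `--transformations-per-example`):

```python
from textattack.augmentation import EmbeddingAugmenter

# Same knobs the CLI exposes as --num-words-to-swap and
# --transformations-per-example.
augmenter = EmbeddingAugmenter(num_words_to_swap=2, transformations_per_example=2)

# augment() returns the perturbed strings for one input example.
for augmented in augmenter.augment("TextAttack augments text by swapping words."):
    print(augmented)
```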
textattack/commands/benchmark_model.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import torch

import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.attack.attack_args_helpers import *


def _cb(s):
    return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")


class BenchmarkModelCommand(TextAttackCommand):
    """
    The TextAttack model benchmarking module:

        A command line parser to benchmark a model from user specifications.
    """

    def get_num_successes(self, model, ids, true_labels):
        with torch.no_grad():
            preds = textattack.shared.utils.model_predict(model, ids)
        true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
        guess_labels = preds.argmax(dim=1)
        successes = (guess_labels == true_labels).sum().item()
        return successes, true_labels, guess_labels

    def test_model_on_dataset(self, args):
        model = parse_model_from_args(args)
        dataset = parse_dataset_from_args(args)
        succ = 0
        fail = 0
        batch_ids = []
        batch_labels = []
        all_true_labels = []
        all_guess_labels = []
        for i, (text_input, label) in enumerate(dataset):
            if i >= args.num_examples:
                break
            attacked_text = textattack.shared.AttackedText(text_input)
            ids = model.tokenizer.encode(attacked_text.tokenizer_input)
            batch_ids.append(ids)
            batch_labels.append(label)
            if len(batch_ids) == args.batch_size:
                batch_succ, true_labels, guess_labels = self.get_num_successes(
                    model, batch_ids, batch_labels
                )
                batch_fail = args.batch_size - batch_succ
                succ += batch_succ
                fail += batch_fail
                batch_ids = []
                batch_labels = []
                all_true_labels.extend(true_labels.tolist())
                all_guess_labels.extend(guess_labels.tolist())
        if len(batch_ids) > 0:
            batch_succ, true_labels, guess_labels = self.get_num_successes(
                model, batch_ids, batch_labels
            )
            batch_fail = len(batch_ids) - batch_succ
            succ += batch_succ
            fail += batch_fail
            all_true_labels.extend(true_labels.tolist())
            all_guess_labels.extend(guess_labels.tolist())

        perc = float(succ) / (succ + fail) * 100.0
        perc = "{:.2f}%".format(perc)
        print(f"Successes {succ}/{succ+fail} ({_cb(perc)})")
        return perc

    def run(self, args):
        # Default to 'all' if no model chosen.
        if not (args.model or args.model_from_huggingface or args.model_from_file):
            for model_name in list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
                TEXTATTACK_DATASET_BY_MODEL.keys()
            ):
                args.model = model_name
                self.test_model_on_dataset(args)
        else:
            self.test_model_on_dataset(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "benchmark-model",
            help="evaluate a model with TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        add_model_args(parser)
        add_dataset_args(parser)

        parser.add_argument(
            "--batch-size",
            type=int,
            default=256,
            help="Batch size for model inference.",
        )
        parser.set_defaults(func=BenchmarkModelCommand())
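This gives a quick batched accuracy check, e.g. `textattack benchmark-model --model lstm-mr --num-examples 1000`. With no model specified, `run` loops over every built-in model, which replaces the deleted benchmark-all-models script further down in this diff.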
textattack/commands/benchmark_recipe.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from textattack.commands import TextAttackCommand


class BenchmarkRecipeCommand(TextAttackCommand):
    """
    The TextAttack benchmark recipe module:

        A command line parser to benchmark a recipe from user specifications.
    """

    def run(self, args):
        raise NotImplementedError("Cannot benchmark recipes yet - stay tuned!!")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "benchmark-recipe",
            help="benchmark a recipe",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.set_defaults(func=BenchmarkRecipeCommand())
textattack/commands/list_things.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES


def _cb(s):
    return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")


class ListThingsCommand(TextAttackCommand):
    """
    The list module:

        List default things in textattack.
    """

    def _list(self, list_of_things):
        """ Prints a list or dict of things. """
        if isinstance(list_of_things, list):
            list_of_things = sorted(list_of_things)
            for thing in list_of_things:
                print(_cb(thing))
        elif isinstance(list_of_things, dict):
            for thing in sorted(list_of_things.keys()):
                thing_long_description = list_of_things[thing]
                print(f"{_cb(thing)} ({thing_long_description})")
        else:
            raise TypeError(f"Cannot print list of type {type(list_of_things)}")

    @staticmethod
    def things():
        list_dict = {}
        list_dict["models"] = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
            TEXTATTACK_DATASET_BY_MODEL.keys()
        )
        list_dict["search-methods"] = SEARCH_METHOD_CLASS_NAMES
        list_dict["transformations"] = {
            **BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
            **WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
        }
        list_dict["constraints"] = CONSTRAINT_CLASS_NAMES
        list_dict["goal-functions"] = GOAL_FUNCTION_CLASS_NAMES
        list_dict["attack-recipes"] = ATTACK_RECIPE_NAMES
        list_dict["augmentation-recipes"] = AUGMENTATION_RECIPE_NAMES
        return list_dict

    def run(self, args):
        try:
            list_of_things = ListThingsCommand.things()[args.feature]
        except KeyError:
            raise ValueError(f"Unknown list key {args.feature}")
        self._list(list_of_things)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser(
            "list",
            help="list features in TextAttack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "feature", help="the feature to list", choices=ListThingsCommand.things()
        )
        parser.set_defaults(func=ListThingsCommand())
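Because `things()` is a plain static method, the same registry the `textattack list <feature>` command prints from is reachable directly from Python; a small sketch (the exact contents depend on the installed version):

```python
from textattack.commands.list_things import ListThingsCommand

# Iterate over the feature registry backing `textattack list`.
for feature, entries in ListThingsCommand.things().items():
    print(f"{feature}: {len(entries)} entries")
```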
textattack/commands/textattack_cli.py (new file, 43 lines)
@@ -0,0 +1,43 @@
#!/usr/bin/env python
import argparse
import os
import sys

from textattack.commands.attack import AttackCommand, AttackResumeCommand
from textattack.commands.augment import AugmentCommand
from textattack.commands.benchmark_model import BenchmarkModelCommand
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
from textattack.commands.list_things import ListThingsCommand
from textattack.commands.train_model import TrainModelCommand


def main():
    parser = argparse.ArgumentParser(
        "TextAttack CLI",
        usage="[python -m] textattack <command> [<args>]",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    subparsers = parser.add_subparsers(help="textattack command helpers")

    # Register commands
    AttackCommand.register_subcommand(subparsers)
    AttackResumeCommand.register_subcommand(subparsers)
    AugmentCommand.register_subcommand(subparsers)
    BenchmarkModelCommand.register_subcommand(subparsers)
    BenchmarkRecipeCommand.register_subcommand(subparsers)
    ListThingsCommand.register_subcommand(subparsers)
    TrainModelCommand.register_subcommand(subparsers)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    args.func.run(args)


if __name__ == "__main__":
    main()
textattack/commands/train_model.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from argparse import ArgumentParser

from textattack.commands import TextAttackCommand


class TrainModelCommand(TextAttackCommand):
    """
    The TextAttack train module:

        A command line parser to train a model from user specifications.
    """

    def run(self, args):
        raise NotImplementedError("Cannot train models yet - stay tuned!!")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser("train", help="train a model")
        parser.set_defaults(func=TrainModelCommand())
@@ -1,3 +1,5 @@
+from flair.data import Sentence
+from flair.models import SequenceTagger
 import lru
 import nltk
 
@@ -10,13 +12,24 @@ class PartOfSpeech(Constraint):
     """ Constraints word swaps to only swap words with the same part of speech.
         Uses the NLTK universal part-of-speech tagger by default.
         An implementation of `<https://arxiv.org/abs/1907.11932>`_
-        adapted from `<https://github.com/jind11/TextFooler>`_.
+        adapted from `<https://github.com/jind11/TextFooler>`_.
+
+        POS tagger from Flair `<https://github.com/flairNLP/flair>` also available
     """
 
-    def __init__(self, tagset="universal", allow_verb_noun_swap=True):
+    def __init__(
+        self, tagger_type="nltk", tagset="universal", allow_verb_noun_swap=True
+    ):
+        self.tagger_type = tagger_type
         self.tagset = tagset
         self.allow_verb_noun_swap = allow_verb_noun_swap
 
         self._pos_tag_cache = lru.LRU(2 ** 14)
+        if tagger_type == "flair":
+            if tagset == "universal":
+                self._flair_pos_tagger = SequenceTagger.load("upos-fast")
+            else:
+                self._flair_pos_tagger = SequenceTagger.load("pos-fast")
 
     def _can_replace_pos(self, pos_a, pos_b):
         return (pos_a == pos_b) or (
@@ -27,11 +40,24 @@ class PartOfSpeech(Constraint):
         context_words = before_ctx + [word] + after_ctx
         context_key = " ".join(context_words)
         if context_key in self._pos_tag_cache:
-            pos_list = self._pos_tag_cache[context_key]
+            word_list, pos_list = self._pos_tag_cache[context_key]
         else:
-            _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
-            self._pos_tag_cache[context_key] = pos_list
-        return pos_list
+            if self.tagger_type == "nltk":
+                word_list, pos_list = zip(
+                    *nltk.pos_tag(context_words, tagset=self.tagset)
+                )
+
+            if self.tagger_type == "flair":
+                word_list, pos_list = zip_flair_result(
+                    self._flair_pos_tagger.predict(context_key)[0]
+                )
+
+            self._pos_tag_cache[context_key] = (word_list, pos_list)
+
+        # idx of `word` in `context_words`
+        idx = len(before_ctx)
+        assert word_list[idx] == word, "POS list not matched with original word list."
+        return pos_list[idx]
 
     def _check_constraint(self, transformed_text, current_text, original_text=None):
         try:
@@ -45,7 +71,7 @@ class PartOfSpeech(Constraint):
             current_word = current_text.words[i]
             transformed_word = transformed_text.words[i]
             before_ctx = current_text.words[max(i - 4, 0) : i]
-            after_ctx = current_text.words[i + 1 : min(i + 5, len(current_text.words))]
+            after_ctx = current_text.words[i + 1 : min(i + 4, len(current_text.words))]
             cur_pos = self._get_pos(before_ctx, current_word, after_ctx)
             replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
             if not self._can_replace_pos(cur_pos, replace_pos):
@@ -57,4 +83,18 @@ class PartOfSpeech(Constraint):
         return transformation_consists_of_word_swaps(transformation)
 
     def extra_repr_keys(self):
-        return ["tagset", "allow_verb_noun_swap"]
+        return ["tagger_type", "tagset", "allow_verb_noun_swap"]
+
+
+def zip_flair_result(pred):
+    if not isinstance(pred, Sentence):
+        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")
+
+    tokens = pred.tokens
+    word_list = []
+    pos_list = []
+    for token in tokens:
+        word_list.append(token.text)
+        pos_list.append(token.annotation_layers["pos"][0]._value)
+
+    return word_list, pos_list
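A short sketch of the new constructor argument. The import path below assumes the constraint is exported from `textattack.constraints.grammaticality`; note that loading the Flair tagger downloads its model on first use:

```python
from textattack.constraints.grammaticality import PartOfSpeech

# Default behavior: NLTK with the universal tagset, as before this change.
nltk_constraint = PartOfSpeech(tagger_type="nltk", tagset="universal")

# New: Flair tagger; tagset="universal" selects the "upos-fast" model.
flair_constraint = PartOfSpeech(tagger_type="flair", tagset="universal")
```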
@@ -1,4 +1,3 @@
-from . import scripts
 from . import utils
 from .utils import logger
 from . import validators
(deleted file)
@@ -1,16 +0,0 @@
-import os
-
-from textattack.shared.scripts.attack_args import (
-    HUGGINGFACE_DATASET_BY_MODEL,
-    TEXTATTACK_DATASET_BY_MODEL,
-)
-
-if __name__ == "__main__":
-
-    dir_path = os.path.dirname(os.path.realpath(__file__))
-    for model in {**TEXTATTACK_DATASET_BY_MODEL, **HUGGINGFACE_DATASET_BY_MODEL}:
-        print(model)
-        os.system(
-            f'python {os.path.join(dir_path, "benchmark_model.py")} --model {model} --num-examples 1000'
-        )
-        print()
(deleted file)
@@ -1,73 +0,0 @@
-import argparse
-import collections
-import sys
-
-import torch
-
-from attack_args_parser import get_args, parse_dataset_from_args, parse_model_from_args
-import textattack
-
-
-def _cb(s):
-    return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
-
-
-def get_num_successes(args, model, ids, true_labels):
-    with torch.no_grad():
-        preds = textattack.shared.utils.model_predict(model, ids)
-    true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
-    guess_labels = preds.argmax(dim=1)
-    successes = (guess_labels == true_labels).sum().item()
-    return successes, true_labels, guess_labels
-
-
-def test_model_on_dataset(args, model, dataset, num_examples=100, batch_size=128):
-    num_examples = args.num_examples
-    succ = 0
-    fail = 0
-    batch_ids = []
-    batch_labels = []
-    all_true_labels = []
-    all_guess_labels = []
-    for i, (text_input, label) in enumerate(dataset):
-        if i >= num_examples:
-            break
-        attacked_text = textattack.shared.AttackedText(text_input)
-        ids = model.tokenizer.encode(attacked_text.tokenizer_input)
-        batch_ids.append(ids)
-        batch_labels.append(label)
-        if len(batch_ids) == batch_size:
-            batch_succ, true_labels, guess_labels = get_num_successes(
-                args, model, batch_ids, batch_labels
-            )
-            batch_fail = batch_size - batch_succ
-            succ += batch_succ
-            fail += batch_fail
-            batch_ids = []
-            batch_labels = []
-            all_true_labels.extend(true_labels.tolist())
-            all_guess_labels.extend(guess_labels.tolist())
-    if len(batch_ids) > 0:
-        batch_succ, true_labels, guess_labels = get_num_successes(
-            args, model, batch_ids, batch_labels
-        )
-        batch_fail = len(batch_ids) - batch_succ
-        succ += batch_succ
-        fail += batch_fail
-        all_true_labels.extend(true_labels.tolist())
-        all_guess_labels.extend(guess_labels.tolist())
-
-    perc = float(succ) / (succ + fail) * 100.0
-    perc = "{:.2f}%".format(perc)
-    print(f"Successes {succ}/{succ+fail} ({_cb(perc)})")
-    return perc
-
-
-if __name__ == "__main__":
-    args = get_args()
-
-    model = parse_model_from_args(args)
-    dataset = parse_dataset_from_args(args)
-
-    with torch.no_grad():
-        test_model_on_dataset(args, model, dataset, num_examples=args.num_examples)
@@ -1,3 +1,6 @@
+import random
+
+import numpy as np
 import torch
 
 import textattack
@@ -77,3 +80,9 @@ def load_textattack_model_from_path(model_name, model_path):
     else:
         raise ValueError(f"Unknown textattack model {model_path}")
     return model
+
+
+def set_seed(random_seed):
+    random.seed(random_seed)
+    np.random.seed(random_seed)
+    torch.manual_seed(random_seed)
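The new helper gives every command a single call to seed all three RNG sources. A minimal sketch of its use, matching the calls added in the command modules above:

```python
import textattack

# Seeds Python's `random`, NumPy, and PyTorch in one call.
textattack.shared.utils.set_seed(42)
```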