mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge & downgrade test (capture_output) for py36

Jack Morris
2020-06-20 18:21:52 -04:00
46 changed files with 1095 additions and 426 deletions

.gitignore vendored
View File

@@ -19,9 +19,6 @@ docs/_build/
# Files from IDEs
.*.py
# CSVs to upload to MTurk
*.csv
# TF Hub modules
tensorflow-hub

View File

@@ -1,8 +1,8 @@
language: python
python:
- '3.8'
- '3.7'
- '3.6'
python:
- '3.8.3'
- '3.7.7'
- '3.6.10'
cache: pip
before_install:
- python --version

View File

@@ -34,6 +34,9 @@ You should be running Python 3.6+ to use this package. A CUDA-compatible GPU is
pip install textattack
```
Once TextAttack is installed, you can run it via command-line (`textattack ...`)
or via the python module (`python -m textattack ...`).
### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
@@ -41,27 +44,30 @@ environment variable `TA_CACHE_DIR`.
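For example, a minimal sketch of overriding the cache location from Python (this assumes `TA_CACHE_DIR` is read when `textattack` is first imported; the path below is only illustrative):
```python
import os

# Must be set before textattack is imported for the override to take effect.
os.environ["TA_CACHE_DIR"] = "/tmp/textattack_cache"  # illustrative path

import textattack  # downloads now land under /tmp/textattack_cache
```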
## Usage
TextAttack's main features can all be accessed via the `textattack` command. Two very
common commands are `textattack attack <args>`, and `textattack augment <args>`. You can see more
information about all commands using `textattack --help`, or a specific command using, for example,
`textattack attack --help`.
### Running Attacks
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
We also have a command-line interface for running attacks. See help info and list of arguments with `python -m textattack --help`.
#### Sample Attack Commands
The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
*TextFooler on an LSTM trained on the MR sentiment classification dataset*:
```bash
python -m textattack --recipe textfooler --model bert-base-uncased-mr --num-examples 100
textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 100
```
*DeepWordBug on DistilBERT trained on the Quora Question Pairs paraphrase identification dataset*:
```bash
python -m textattack --model distilbert-base-uncased-qqp --recipe deepwordbug --num-examples 100
textattack attack --model distilbert-base-uncased-qqp --recipe deepwordbug --num-examples 100
```
*Beam search with beam width 4 and word embedding transformation and untargeted goal function on an LSTM*:
```bash
python -m textattack --model lstm-mr --num-examples 20 \
textattack attack --model lstm-mr --num-examples 20 \
--search-method beam-search:beam_width=4 --transformation word-swap-embedding \
--constraints repeat stopword max-words-perturbed:max_num_words=2 embedding:min_cos_sim=0.8 part-of-speech \
--goal-function untargeted-classification
@@ -69,7 +75,7 @@ python -m textattack --model lstm-mr --num-examples 20 \
*Non-overlapping output attack using a greedy word swap and WordNet word substitutions on T5 English-to-German translation:*
```bash
python -m textattack --attack-n --goal-function non-overlapping-output \
textattack attack --attack-n --goal-function non-overlapping-output \
--model t5-en2de --num-examples 10 --transformation word-swap-wordnet \
--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
--search greedy
@@ -79,7 +85,9 @@ python -m textattack --attack-n --goal-function non-overlapping-output \
### Attacks and Papers Implemented ("Attack Recipes")
We include attack recipes which build an attack such that only one command-line argument has to be passed. To run an attack recipe, run `python -m textattack --recipe [recipe_name]`
We include attack recipes which implement attacks from the literature. You can list attack recipes using `textattack list attack-recipes`.
To run an attack recipe: `textattack attack --recipe [recipe_name]`
The first are for classification tasks, like sentiment classification and entailment:
- **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
@@ -98,12 +106,12 @@ Here are some examples of testing attacks from the literature from the command-line:
*TextFooler against BERT fine-tuned on SST-2:*
```bash
python -m textattack --model bert-base-uncased-sst2 --recipe textfooler --num-examples 10
textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examples 10
```
*seq2sick (black-box) against T5 fine-tuned for English-German translation:*
```bash
python -m textattack --recipe seq2sick --model t5-en2de --num-examples 100
textattack attack --recipe seq2sick --model t5-en2de --num-examples 100
```
### Augmenting Text
@@ -115,8 +123,48 @@ for data augmentation:
- `textattack.EmbeddingAugmenter` augments text by replacing words with neighbors in the counter-fitted embedding space, with a constraint to ensure their cosine similarity is at least 0.8
- `textattack.CharSwapAugmenter` augments text by substituting, deleting, inserting, and swapping adjacent characters
All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter`:
#### Augmentation Command-Line Interface
The easiest way to use our data augmentation tools is with `textattack augment <args>`. `textattack augment`
takes an input CSV file and text column to augment, along with the number of words to change per augmentation
and the number of augmentations per input example. It outputs a CSV in the same format, with all of the augmented
examples placed in the proper columns.
For example, given the following as `examples.csv`:
```csv
"text",label
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
"it's a mystery how the movie could be released in this condition .", 0
```
The command `textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 --transformations-per-example 2 --exclude-original`
will augment the `text` column with four swaps per augmentation, twice as many augmentations as original inputs, and exclude the original inputs from the
output CSV. (All of this will be saved to `augment.csv` by default.)
After augmentation, here are the contents of `augment.csv`:
```csv
text,label
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
take care of my cat offerings a pleasantly several slice of asia cinema .,1
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
it's a enigma how the film wo be releases in this condition .,0
it's a enigma how the filmmaking wo be publicized in this condition .,0
```
The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data.
#### Augmentation Python Interface
In addition to the command-line interface, you can augment text dynamically by importing the
`Augmenter` in your own code. All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter` in a python script:
```python
>>> from textattack.augmentation import EmbeddingAugmenter
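>>> # A minimal sketch of how this example continues (the rest of the snippet is
>>> # cut off by this hunk; the output shown is only illustrative).
>>> augmenter = EmbeddingAugmenter()
>>> augmenter.augment('What I cannot create, I do not understand.')
['What I notable create, I do not understand.', '...']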
@@ -142,12 +190,12 @@ TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that
TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct
dataset to the correct model. We include various pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/)
tasks, as well as some common classification datasets, translation, and summarization. You can
see the full list of provided models & datasets via `python -m textattack --help`.
see the full list of provided models & datasets via `textattack attack --help`.
Here's an example of using one of the built-in models:
```bash
python -m textattack --model roberta-base-sst2 --recipe textfooler --num-examples 10
textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10
```
#### HuggingFace support: `transformers` models and `nlp` datasets
@@ -157,7 +205,7 @@ and datasets from the [`nlp` package](https://github.com/huggingface/nlp)! Here'
and attacking a pre-trained model and dataset:
```bash
python -m textattack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
textattack attack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
```
You can explore other pre-trained models using the `--model_from_huggingface` argument, or other datasets by changing

View File

@@ -3,7 +3,7 @@ Attack Recipes
We provide a number of pre-built attack recipes. To run an attack recipe, run::
python -m textattack --recipe [recipe_name]
textattack attack --recipe [recipe_name]
Alzantot
###########

View File

@@ -8,7 +8,7 @@ To use TextAttack, you must be running Python 3.6+. A CUDA-compatible GPU is opt
You're now all set to use TextAttack! Try running an attack from the command line::
python -m textattack --recipe textfooler --model bert-mr --num-examples 10
textattack attack --recipe textfooler --model bert-mr --num-examples 10
This will run an attack using the TextFooler_ recipe, attacking BERT fine-tuned on the MR dataset. It will attack the first 10 samples. Once everything downloads and starts running, you should see attack results print to ``stdout``.

examples/augmentation/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
augment.csv # Don't commit the output file of this command

View File

@@ -0,0 +1,11 @@
text,label
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
take care of my cat offerings a pleasantly several slice of asia cinema .,1
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
it's a enigma how the film wo be releases in this condition .,0
it's a enigma how the filmmaking wo be publicized in this condition .,0

View File

@@ -0,0 +1,2 @@
#!/bin/bash
textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 --transformations-per-example 2 --exclude-original

View File

@@ -0,0 +1,6 @@
"text",label
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
"it's a mystery how the movie could be released in this condition .", 0

View File

@@ -20,3 +20,4 @@ terminaltables
tqdm
visdom
wandb
flair

View File

@@ -28,7 +28,9 @@ setuptools.setup(
"wandb*",
]
),
entry_points={"console_scripts": ["textattack=textattack.__main__:main"],},
entry_points={
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",

View File

@@ -0,0 +1,4 @@
text,label,another_column
"For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory", 5, "some text that needs to be preserved"
"Signer of the Dec. of Indep., framer of the Constitution of Mass., second President of the United States", -13, "the answer to this question is John Adams"
"In the title of an Aesop fable, this insect shared billing with a grasshopper", 1111, "these are from a jeopardy question, and this one is the ant"

View File

@@ -0,0 +1,10 @@
text,label,another_column
"For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory",5,some text that needs to be preserved
"For the last 8 years of his lives, Galileo was under habitation arrest for espousing this man's theory",5,some text that needs to be preserved
"For the last 8 yr of his life, Galileo was under house apprehending for espousing this man's theory",5,some text that needs to be preserved
"Signer of the Dec. of Indep., framer of the Constitution of Mass., second President of the United States",-13,the answer to this question is John Adams
"Signer of the Dec. of Indep., framer of the Constitution of Masse., seconds President of the United States",-13,the answer to this question is John Adams
"Signer of the Dec. of Indep., framer of the Constitutions of Mass., second Chairwoman of the United States",-13,the answer to this question is John Adams
"In the title of an Aesop fable, this insect shared billing with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"
"Among the title of an Aesop fable, this beetle shared billing with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"
"In the title of an Aesop fable, this insect exchanging invoice with a grasshopper",1111,"these are from a jeopardy question, and this one is the ant"

View File

@@ -15,6 +15,7 @@ Attack(
(include_unknown_words): True
)
(1): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
)

View File

@@ -0,0 +1,3 @@
charswap (textattack.augmentation.CharSwapAugmenter)
embedding (textattack.augmentation.EmbeddingAugmenter)
wordnet (textattack.augmentation.WordNetAugmenter)

View File

@@ -0,0 +1,66 @@
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
)
(constraints):
(0): WordEmbeddingDistance(
(embedding_type): paragramcf
(min_cos_sim): 0.8
(cased): False
(include_unknown_words): True
)
(1): PartOfSpeech(
(tagger_type): flair
(tagset): universal
(allow_verb_noun_swap): True
)
(2): RepeatModification
(3): StopwordModification
(is_black_box): True
)
--------------------------------------------- Result 1 ---------------------------------------------
1-->[FAILED]
this is a film well worth seeing , talking and singing heads and all .
--------------------------------------------- Result 2 ---------------------------------------------
1-->0
what really surprises about wisegirls is its low-key quality and genuine tenderness .
what really dumbfounded about wisegirls is its low-vital quality and veritable sensibility .
--------------------------------------------- Result 3 ---------------------------------------------
1-->[FAILED]
( wendigo is ) why we go to the cinema : to be fed through the eye , the heart , the mind .
--------------------------------------------- Result 4 ---------------------------------------------
1-->[FAILED]
one of the greatest family-oriented , fantasy-adventure movies ever .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 1 |
| Number of failed attacks: | 3 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 75.0% |
| Attack success rate: | 25.0% |
| Average perturbed word %: | 30.77% |
| Average num. words per input: | 13.5 |
| Avg num queries: | 45.0 |
+-------------------------------+--------+

View File

@@ -7,7 +7,7 @@ Attack(
(constraints): None
(is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
1-->[SKIPPED]

View File

@@ -17,6 +17,7 @@ Attack(
(include_unknown_words): True
)
(2): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
)

View File

@@ -0,0 +1,38 @@
import shlex
import subprocess
def run_command_and_get_result(command):
""" Runs a command in the console and gets the result.
Command can be a string (single command) or a tuple of strings (multiple
commands). In the multi-command setting, commands will be joined
together with a pipe, and the output of the last command will be
returned.
"""
from subprocess import PIPE
# run command
if isinstance(command, tuple):
# Support pipes via tuple of commands
procs = []
for i in range(len(command) - 1):
if i == 0:
proc = subprocess.Popen(shlex.split(command[i]), stdout=PIPE)
else:
proc = subprocess.Popen(
shlex.split(command[i]),
stdout=subprocess.PIPE,
stdin=procs[-1].stdout,
)
procs.append(proc)
# Run last command
result = subprocess.run(
shlex.split(command[-1]), stdin=procs[-1].stdout, stdout=PIPE, stderr=PIPE
)
# Wait for all intermediate processes
for proc in procs:
proc.wait()
else:
result = subprocess.run(shlex.split(command), stdout=PIPE, stderr=PIPE)
return result
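For context, here is a brief sketch of how this helper can be called from a test; the commands mirror ones used elsewhere in this commit, and the snippet itself is only illustrative:
```python
from helpers import run_command_and_get_result

# Single command: returns a subprocess.CompletedProcess with stdout/stderr captured as bytes.
result = run_command_and_get_result("textattack list augmentation-recipes")
print(result.stdout.decode().strip())

# Tuple of commands: earlier commands are piped into the next, and only the last
# command's output is captured (this mirrors the interactive-mode attack test).
result = run_command_and_get_result(
    (
        'printf "All that glitters is not gold\nq\n"',
        "textattack attack --recipe textfooler --model bert-base-uncased-imdb --interactive",
    )
)
print(result.stdout.decode().strip())
```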

View File

@@ -1,14 +1,16 @@
import pdb
import re
import shlex
import subprocess
import pytest
from helpers import run_command_and_get_result
DEBUG = False
"""
Attack command-line tests in the format (name, args, sample_output_file)
"""
attack_test_params = [
#
# test loading an attack from file
@@ -16,7 +18,7 @@ attack_test_params = [
(
"attack_from_file",
(
"python -m textattack --model cnn-imdb "
"textattack attack --model cnn-imdb "
"--attack-from-file tests/sample_inputs/attack_from_file.py:Attack "
"--num-examples 2 --num-examples-offset 18 --attack-n"
),
@@ -29,7 +31,7 @@ attack_test_params = [
"interactive_mode",
(
'printf "All that glitters is not gold\nq\n"',
"python -m textattack --recipe textfooler --model bert-base-uncased-imdb --interactive",
"textattack attack --recipe textfooler --model bert-base-uncased-imdb --interactive",
),
"tests/sample_outputs/interactive_mode.txt",
),
@@ -39,7 +41,7 @@ attack_test_params = [
(
"attack_from_transformers",
(
"python -m textattack --model-from-huggingface "
"textattack attack --model-from-huggingface "
"distilbert-base-uncased-finetuned-sst-2-english "
"--dataset-from-nlp glue:sst2:train --recipe deepwordbug --num-examples 3"
),
@@ -51,7 +53,8 @@ attack_test_params = [
(
"load_model_and_dataset_from_file",
(
"python -m textattack --model-from-file tests/sample_inputs/sst_model_and_dataset.py "
"textattack attack "
"--model-from-file tests/sample_inputs/sst_model_and_dataset.py "
"--dataset-from-file tests/sample_inputs/sst_model_and_dataset.py "
"--recipe deepwordbug --num-examples 3"
),
@@ -63,7 +66,7 @@ attack_test_params = [
(
"run_attack_hotflip_lstm_mr_4",
(
"python -m textattack --model lstm-mr --recipe hotflip "
"textattack attack --model lstm-mr --recipe hotflip "
"--num-examples 4 --num-examples-offset 3"
),
"tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
@@ -74,7 +77,7 @@ attack_test_params = [
(
"run_attack_deepwordbug_lstm_mr_2",
(
"python -m textattack --model lstm-mr --recipe deepwordbug --num-examples 2 --attack-n"
"textattack attack --model lstm-mr --recipe deepwordbug --num-examples 2 --attack-n"
),
"tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt",
),
@@ -87,21 +90,34 @@ attack_test_params = [
(
"run_attack_targeted_mnli_misc",
(
"python -m textattack --attack-n --goal-function targeted-classification:target_class=2 "
"textattack attack --attack-n --goal-function targeted-classification:target_class=2 "
"--enable-csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation word-swap-wordnet "
"--constraints lang-tool repeat stopword --search beam-search:beam_width=2"
),
"tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt",
),
#
# fmt: off
# test: run_attack untargeted classification on BERT MR using word embedding transformation and greedy-word-WIR search
# using Flair's part-of-speech tagger as constraint.
#
(
"run_attack_flair_pos_tagger",
(
"textattack attack --model bert-base-uncased-mr --search greedy-word-wir --transformation word-swap-embedding "
"--constraints repeat stopword embedding:min_cos_sim=0.8 part-of-speech:tagger_type=\\'flair\\' "
"--num-examples 4 --num-examples-offset 10"
),
"tests/sample_outputs/run_attack_flair_pos_tagger.txt",
),
# fmt: on
#
]
@pytest.mark.parametrize("name, command, sample_output_file", attack_test_params)
@pytest.mark.slow
def test_command_line_attack(capsys, name, command, sample_output_file):
def test_command_line_attack(name, command, sample_output_file):
""" Runs attack tests and compares their outputs to a reference file.
"""
# read in file and create regex
@@ -109,29 +125,7 @@ def test_command_line_attack(capsys, name, command, sample_output_file):
print("desired_output =>", desired_output)
# regex in sample file look like /.*/
desired_re = re.escape(desired_output).replace("/\\.\\*/", ".*")
# run command
if isinstance(command, tuple):
# Support pipes via tuple of commands
procs = []
for i in range(len(command) - 1):
if i == 0:
proc = subprocess.Popen(shlex.split(command[i]), stdout=subprocess.PIPE)
else:
proc = subprocess.Popen(
shlex.split(command[i]),
stdout=subprocess.PIPE,
stdin=procs[-1].stdout,
)
procs.append(proc)
# Run last command
result = subprocess.run(
shlex.split(command[-1]), stdin=procs[-1].stdout, capture_output=True
)
# Wait for all intermediate processes
for proc in procs:
proc.wait()
else:
result = subprocess.run(shlex.split(command), capture_output=True)
result = run_command_and_get_result(command)
# get output and check match
assert result.stdout is not None
stdout = result.stdout.decode().strip()
@@ -141,7 +135,5 @@ def test_command_line_attack(capsys, name, command, sample_output_file):
print("stderr =>", stderr)
if DEBUG and not re.match(desired_re, stdout, flags=re.S):
import pdb
pdb.set_trace()
assert re.match(desired_re, stdout, flags=re.S)

View File

@@ -0,0 +1,37 @@
import pytest
from helpers import run_command_and_get_result
augment_test_params = [
(
"simple_augment_test",
"textattack augment --csv tests/sample_inputs/augment.csv.txt --input-column text --outfile augment_test.csv --overwrite",
"augment_test.csv",
"tests/sample_outputs/augment_test.csv.txt",
)
]
@pytest.mark.parametrize(
"name, command, outfile, sample_output_file", augment_test_params
)
@pytest.mark.slow
def test_command_line_augmentation(name, command, outfile, sample_output_file):
import os
desired_text = open(sample_output_file).read().strip()
# Run command and validate outputs.
result = run_command_and_get_result(command)
assert result.stdout is not None
stdout = result.stdout.decode().strip()
assert stdout == ""
assert result.stderr is not None
stderr = result.stderr.decode().strip()
assert "Wrote 9 augmentations to augment_test.csv" in stderr
# Ensure CSV file exists, then delete it.
assert os.path.exists(outfile)
os.remove(outfile)

View File

@@ -0,0 +1,28 @@
import pytest
from helpers import run_command_and_get_result
list_test_params = [
(
"list_augmentation_recipes",
"textattack list augmentation-recipes",
"tests/sample_outputs/list_augmentation_recipes.txt",
)
]
@pytest.mark.parametrize("name, command, sample_output_file", list_test_params)
def test_command_line_list(name, command, sample_output_file):
desired_text = open(sample_output_file).read().strip()
# Run command and validate outputs.
result = run_command_and_get_result(command)
assert result.stdout is not None
assert result.stderr is not None
stdout = result.stdout.decode().strip()
stderr = result.stderr.decode().strip()
assert stderr == ""
assert stdout == desired_text

View File

@@ -3,6 +3,7 @@ name = "textattack"
from . import attack_recipes
from . import attack_results
from . import augmentation
from . import commands
from . import constraints
from . import datasets
from . import goal_functions

View File

@@ -1,26 +1,6 @@
#!/usr/bin/env python
"""
The TextAttack main module:
A command line parser to run an attack from user specifications.
"""
from textattack.shared.scripts.attack_args_parser import get_args
from textattack.shared.scripts.run_attack_parallel import run as run_parallel
from textattack.shared.scripts.run_attack_single_threaded import (
run as run_single_threaded,
)
def main():
args = get_args()
if args.parallel:
run_parallel(args)
else:
run_single_threaded(args)
if __name__ == "__main__":
main()
import textattack
textattack.commands.textattack_cli.main()

View File

@@ -87,7 +87,7 @@ class Augmenter:
if words_swapped == self.num_words_to_swap:
break
all_transformed_texts.add(current_text)
return [t.text for t in all_transformed_texts]
return sorted([at.printable_text() for at in all_transformed_texts])
def augment_many(self, text_list, show_progress=False):
"""

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser, HelpFormatter
class TextAttackCommand(ABC):
@staticmethod
@abstractmethod
def register_subcommand(parser):
raise NotImplementedError()
@abstractmethod
def run(self):
raise NotImplementedError()
from . import textattack_cli
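As a rough sketch of how this base class is meant to be used, here is a hypothetical command (the class name, subcommand name, and flag are invented for illustration; the pattern follows the concrete commands added in this commit, which take `args` in `run`):
```python
from argparse import ArgumentParser

from textattack.commands import TextAttackCommand


class HelloCommand(TextAttackCommand):
    """A toy command that prints a greeting."""

    def run(self, args):
        print(f"hello, {args.name}!")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser("hello", help="print a greeting")
        parser.add_argument("--name", default="world")
        parser.set_defaults(func=HelloCommand())
```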

View File

@@ -0,0 +1,2 @@
from .attack_command import AttackCommand
from .attack_resume_command import AttackResumeCommand

View File

@@ -1,6 +1,6 @@
import textattack
RECIPE_NAMES = {
ATTACK_RECIPE_NAMES = {
"alzantot": "textattack.attack_recipes.Alzantot2018",
"deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
"hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
@@ -244,7 +244,7 @@ CONSTRAINT_CLASS_NAMES = {
"stopword": "textattack.constraints.pre_transformation.StopwordModification",
}
SEARCH_CLASS_NAMES = {
SEARCH_METHOD_CLASS_NAMES = {
"beam-search": "textattack.search_methods.BeamSearch",
"greedy": "textattack.search_methods.GreedySearch",
"ga-word": "textattack.search_methods.GeneticAlgorithm",

View File

@@ -11,35 +11,15 @@ import numpy as np
import torch
import textattack
from textattack.shared.scripts.attack_args import *
from .attack_args import *
def set_seed(random_seed):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
def get_args():
# Parser for regular arguments
parser = argparse.ArgumentParser(
description="A commandline parser for TextAttack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
)
parser.add_argument(
"--transformation",
type=str,
required=False,
default="word-swap-embedding",
choices=transformation_names,
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}. Choices: '
+ str(transformation_names),
)
def add_model_args(parser):
""" Adds model-related arguments to an argparser. This is useful because we
want to load pretrained models using multiple different parsers that
share these, but not all, arguments.
"""
model_group = parser.add_mutually_exclusive_group()
model_names = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
@@ -66,6 +46,12 @@ def get_args():
help="huggingface.co ID of pre-trained model to load",
)
def add_dataset_args(parser):
""" Adds dataset-related arguments to an argparser. This is useful because we
want to load datasets using multiple different parsers that
share these, but not all, arguments.
"""
dataset_group = parser.add_mutually_exclusive_group()
dataset_group.add_argument(
"--dataset-from-nlp",
@@ -81,48 +67,13 @@ def get_args():
default=None,
help="Dataset to load from a file.",
)
parser.add_argument(
"--constraints",
type=str,
required=False,
nargs="*",
default=["repeat", "stopword"],
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(CONSTRAINT_CLASS_NAMES.keys()),
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
default=None,
help="A directory to output results to.",
)
parser.add_argument(
"--enable-visdom", action="store_true", help="Enable logging to visdom."
)
parser.add_argument(
"--enable-wandb",
dataset_group.add_argument(
"--shuffle",
action="store_true",
help="Enable logging to Weights & Biases.",
required=False,
default=False,
help="Randomly shuffle the data before attacking",
)
parser.add_argument(
"--disable-stdout", action="store_true", help="Disable logging to stdout"
)
parser.add_argument(
"--enable-csv",
nargs="?",
default=None,
const="fancy",
type=str,
help="Enable logging to csv. Use --enable-csv plain to remove [[]] around words.",
)
parser.add_argument(
"--num-examples",
"-n",
@@ -141,162 +92,20 @@ def get_args():
help="The offset to start at in the dataset.",
)
parser.add_argument(
"--shuffle",
action="store_true",
required=False,
default=False,
help="Randomly shuffle the data before attacking",
def load_module_from_file(file_path):
""" Uses ``importlib`` to dynamically open a file and load an object from
it.
"""
temp_module_name = f"temp_{time.time()}"
colored_file_path = textattack.shared.utils.color_text(
file_path, color="blue", method="ansi"
)
parser.add_argument(
"--interactive",
action="store_true",
default=False,
help="Whether to run attacks interactively.",
)
parser.add_argument(
"--attack-n",
action="store_true",
default=False,
help="Whether to run attack until `n` examples have been attacked (not skipped).",
)
parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument(
"--goal-function",
"-g",
default="untargeted-classification",
help=f"The goal function to use. choices: {goal_function_choices}",
)
def str_to_int(s):
return sum((ord(c) for c in s))
parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"))
parser.add_argument(
"--checkpoint-dir",
required=False,
type=str,
default=default_checkpoint_dir(),
help="The directory to save checkpoint files.",
)
parser.add_argument(
"--checkpoint-interval",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
parser.add_argument(
"--query-budget",
"-q",
type=int,
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_CLASS_NAMES.keys())
attack_group.add_argument(
"--search",
"--search-method",
"-s",
type=str,
required=False,
default="greedy-word-wir",
help=f"The search method to use. choices: {search_choices}",
)
attack_group.add_argument(
"--recipe",
"--attack-recipe",
"-r",
type=str,
required=False,
default=None,
help="full attack recipe (overrides provided goal function, transformation & constraints)",
choices=RECIPE_NAMES.keys(),
)
attack_group.add_argument(
"--attack-from-file",
type=str,
required=False,
default=None,
help="attack to load from file (overrides provided goal function, transformation & constraints)",
)
# Parser for parsing args for resume
resume_parser = argparse.ArgumentParser(
description="A commandline parser for TextAttack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
resume_parser.add_argument(
"--checkpoint-file",
"-f",
type=str,
required=True,
help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'
"recover latest checkpoint from either current path or specified directory.",
)
resume_parser.add_argument(
"--checkpoint-dir",
"-d",
required=False,
type=str,
default=None,
help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
)
resume_parser.add_argument(
"--checkpoint-interval",
"-i",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
resume_parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
# Resume attack from checkpoint.
if sys.argv[1:] and sys.argv[1].lower() == "resume":
args = resume_parser.parse_args(sys.argv[2:])
setattr(args, "checkpoint_resume", True)
else:
command_line_args = (
None if sys.argv[1:] else ["-h"]
) # Default to help with empty arguments.
args = parser.parse_args(command_line_args)
setattr(args, "checkpoint_resume", False)
if args.checkpoint_interval and args.shuffle:
# Not allowed b/c we cannot recover order of shuffled data
raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")
set_seed(args.random_seed)
# Shortcuts for huggingface models using --model.
if not args.checkpoint_resume and args.model in HUGGINGFACE_DATASET_BY_MODEL:
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif not args.checkpoint_resume and args.model in TEXTATTACK_DATASET_BY_MODEL:
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
return args
textattack.shared.logger.info(f"Loading module from `{colored_file_path}`.")
spec = importlib.util.spec_from_file_location(temp_module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
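A quick sketch of how this helper gets used for `--attack-from-file` (the file path and attribute name come from the test input added in this commit; the surrounding variables are illustrative):
```python
# Dynamically load a user-supplied module, then pull the named attack builder from it.
attack_module = load_module_from_file("tests/sample_inputs/attack_from_file.py")
attack_builder = getattr(attack_module, "Attack")
attack = attack_builder(model)  # `model` stands in for parse_model_from_args(args)
```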
def parse_transformation_from_args(args, model):
@@ -374,11 +183,11 @@ def parse_attack_from_args(args):
if args.recipe:
if ":" in args.recipe:
recipe_name, params = args.recipe.split(":")
if recipe_name not in RECIPE_NAMES:
if recipe_name not in ATTACK_RECIPE_NAMES:
raise ValueError(f"Error: unsupported recipe {recipe_name}")
recipe = eval(f"{RECIPE_NAMES[recipe_name]}(model, {params})")
elif args.recipe in RECIPE_NAMES:
recipe = eval(f"{RECIPE_NAMES[args.recipe]}(model)")
recipe = eval(f"{ATTACK_RECIPE_NAMES[recipe_name]}(model, {params})")
elif args.recipe in ATTACK_RECIPE_NAMES:
recipe = eval(f"{ATTACK_RECIPE_NAMES[args.recipe]}(model)")
else:
raise ValueError(f"Invalid recipe {args.recipe}")
recipe.goal_function.query_budget = args.query_budget
@@ -388,8 +197,11 @@ def parse_attack_from_args(args):
attack_file, attack_name = args.attack_from_file.split(":")
else:
attack_file, attack_name = args.attack_from_file, "attack"
attack_file = attack_file.replace(".py", "").replace("/", ".")
attack_module = importlib.import_module(attack_file)
attack_module = load_module_from_file(attack_file)
if not hasattr(attack_module, attack_name):
raise ValueError(
f"Loaded `{attack_file}` but could not find `{attack_name}`."
)
attack_func = getattr(attack_module, attack_name)
return attack_func(model)
else:
@@ -398,11 +210,11 @@ def parse_attack_from_args(args):
constraints = parse_constraints_from_args(args)
if ":" in args.search:
search_name, params = args.search.split(":")
if search_name not in SEARCH_CLASS_NAMES:
if search_name not in SEARCH_METHOD_CLASS_NAMES:
raise ValueError(f"Error: unsupported search {search_name}")
search_method = eval(f"{SEARCH_CLASS_NAMES[search_name]}({params})")
elif args.search in SEARCH_CLASS_NAMES:
search_method = eval(f"{SEARCH_CLASS_NAMES[args.search]}()")
search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")
elif args.search in SEARCH_METHOD_CLASS_NAMES:
search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[args.search]}()")
else:
raise ValueError(f"Error: unsupported attack {args.search}")
return textattack.shared.Attack(
@@ -427,12 +239,9 @@ def parse_model_from_args(args):
"tokenizer",
)
try:
model_file = args.model_from_file.replace(".py", "").replace("/", ".")
model_module = importlib.import_module(model_file)
model_module = load_module_from_file(args.model_from_file)
except:
raise ValueError(
f"Failed to import model or tokenizer from file {args.model_from_file}"
)
raise ValueError(f"Failed to import file {args.model_from_file}")
try:
model = getattr(model_module, model_name)
except AttributeError:
@@ -492,6 +301,14 @@ def parse_model_from_args(args):
def parse_dataset_from_args(args):
# Automatically detect dataset for huggingface & textattack models.
# This allows us to use the --model shortcut without specifying a dataset.
if args.model in HUGGINGFACE_DATASET_BY_MODEL:
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
# Get dataset from args.
if args.dataset_from_file:
textattack.shared.logger.info(
f"Loading model and tokenizer from file: {args.model_from_file}"
@@ -501,8 +318,7 @@ def parse_dataset_from_args(args):
else:
dataset_file, dataset_name = args.dataset_from_file, "dataset"
try:
dataset_file = dataset_file.replace(".py", "").replace("/", ".")
dataset_module = importlib.import_module(dataset_file)
dataset_module = load_module_from_file(dataset_file)
except:
raise ValueError(
f"Failed to import dataset from file {args.dataset_from_file}"

View File

@@ -0,0 +1,184 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.attack.attack_args_helpers import *
class AttackCommand(TextAttackCommand):
"""
The TextAttack attack module:
A command line parser to run an attack from user specifications.
"""
def run(self, args):
if args.checkpoint_interval and args.shuffle:
# Not allowed b/c we cannot recover order of shuffled data
raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle=True`")
textattack.shared.utils.set_seed(args.random_seed)
args.checkpoint_resume = False
from textattack.commands.attack.run_attack_parallel import run as run_parallel
from textattack.commands.attack.run_attack_single_threaded import (
run as run_single_threaded,
)
if args.parallel:
run_parallel(args)
else:
run_single_threaded(args)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"attack",
help="run an attack on an NLP model",
formatter_class=ArgumentDefaultsHelpFormatter,
)
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
)
parser.add_argument(
"--transformation",
type=str,
required=False,
default="word-swap-embedding",
choices=transformation_names,
help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}. Choices: '
+ str(transformation_names),
)
add_model_args(parser)
add_dataset_args(parser)
parser.add_argument(
"--constraints",
type=str,
required=False,
nargs="*",
default=["repeat", "stopword"],
help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
+ str(CONSTRAINT_CLASS_NAMES.keys()),
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
default=None,
help="A directory to output results to.",
)
parser.add_argument(
"--enable-visdom", action="store_true", help="Enable logging to visdom."
)
parser.add_argument(
"--enable-wandb",
action="store_true",
help="Enable logging to Weights & Biases.",
)
parser.add_argument(
"--disable-stdout", action="store_true", help="Disable logging to stdout"
)
parser.add_argument(
"--enable-csv",
nargs="?",
default=None,
const="fancy",
type=str,
help="Enable logging to csv. Use --enable-csv plain to remove [[]] around words.",
)
parser.add_argument(
"--interactive",
action="store_true",
default=False,
help="Whether to run attacks interactively.",
)
parser.add_argument(
"--attack-n",
action="store_true",
default=False,
help="Whether to run attack until `n` examples have been attacked (not skipped).",
)
parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument(
"--goal-function",
"-g",
default="untargeted-classification",
help=f"The goal function to use. choices: {goal_function_choices}",
)
def str_to_int(s):
return sum((ord(c) for c in s))
parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"))
parser.add_argument(
"--checkpoint-dir",
required=False,
type=str,
default=default_checkpoint_dir(),
help="The directory to save checkpoint files.",
)
parser.add_argument(
"--checkpoint-interval",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
parser.add_argument(
"--query-budget",
"-q",
type=int,
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
attack_group.add_argument(
"--search",
"--search-method",
"-s",
type=str,
required=False,
default="greedy-word-wir",
help=f"The search method to use. choices: {search_choices}",
)
attack_group.add_argument(
"--recipe",
"--attack-recipe",
"-r",
type=str,
required=False,
default=None,
help="full attack recipe (overrides provided goal function, transformation & constraints)",
choices=ATTACK_RECIPE_NAMES.keys(),
)
attack_group.add_argument(
"--attack-from-file",
type=str,
required=False,
default=None,
help="attack to load from file (overrides provided goal function, transformation & constraints)",
)
parser.set_defaults(func=AttackCommand())

View File

@@ -0,0 +1,70 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import textattack
from textattack.commands import TextAttackCommand
class AttackResumeCommand(TextAttackCommand):
"""
The TextAttack attack resume module:
A command line parser to resume a checkpointed attack from user specifications.
"""
def run(self, args):
textattack.shared.utils.set_seed(args.random_seed)
args.checkpoint_resume = True
# Run attack from checkpoint.
from textattack.commands.attack.run_attack_parallel import run as run_parallel
from textattack.commands.attack.run_attack_single_threaded import (
run as run_single_threaded,
)
if args.parallel:
run_parallel(args)
else:
run_single_threaded(args)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
resume_parser = main_parser.add_parser(
"attack-resume",
help="resume a checkpointed attack",
formatter_class=ArgumentDefaultsHelpFormatter,
)
# Parser for parsing args for resume
resume_parser.add_argument(
"--checkpoint-file",
"-f",
type=str,
required=True,
help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'
"recover latest checkpoint from either current path or specified directory.",
)
resume_parser.add_argument(
"--checkpoint-dir",
"-d",
required=False,
type=str,
default=None,
help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
)
resume_parser.add_argument(
"--checkpoint-interval",
"-i",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
resume_parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
resume_parser.set_defaults(func=AttackResumeCommand())

View File

@@ -11,7 +11,7 @@ import tqdm
import textattack
from .attack_args_parser import *
from .attack_args_helpers import *
logger = textattack.shared.logger

View File

@@ -11,7 +11,7 @@ import tqdm
import textattack
from .attack_args_parser import *
from .attack_args_helpers import *
logger = textattack.shared.logger

View File

@@ -0,0 +1,143 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import csv
import os
import time
import tqdm
import textattack
from textattack.commands import TextAttackCommand
AUGMENTATION_RECIPE_NAMES = {
"wordnet": "textattack.augmentation.WordNetAugmenter",
"embedding": "textattack.augmentation.EmbeddingAugmenter",
"charswap": "textattack.augmentation.CharSwapAugmenter",
}
class AugmentCommand(TextAttackCommand):
"""
The TextAttack augment module:
A command line parser to run data augmentation from user specifications.
"""
def run(self, args):
""" Reads in a CSV, performs augmentation, and outputs an augmented CSV.
Preserves all columns except for the input (augmented) column.
"""
textattack.shared.utils.set_seed(args.random_seed)
start_time = time.time()
# Validate input/output paths.
if not os.path.exists(args.csv):
raise FileNotFoundError(f"Can't find CSV at location {args.csv}")
if os.path.exists(args.outfile):
if args.overwrite:
textattack.shared.logger.info(f"Preparing to overwrite {args.outfile}.")
else:
raise OSError(f"Outfile {args.outfile} exists and --overwrite not set.")
# Read in CSV file as a list of dictionaries. Use the CSV sniffer to
# try and automatically infer the correct CSV format.
csv_file = open(args.csv, "r")
dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
csv_file.seek(0)
rows = [
row
for row in csv.DictReader(csv_file, dialect=dialect, skipinitialspace=True)
]
# Validate input column.
row_keys = set(rows[0].keys())
if args.input_column not in row_keys:
raise ValueError(
f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
)
textattack.shared.logger.info(
f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}."
)
# Augment all examples.
augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
num_words_to_swap=args.num_words_to_swap,
transformations_per_example=args.transformations_per_example,
)
output_rows = []
for row in tqdm.tqdm(rows, desc="Augmenting rows"):
text_input = row[args.input_column]
if not args.exclude_original:
output_rows.append(row)
for augmentation in augmenter.augment(text_input):
augmented_row = row.copy()
augmented_row[args.input_column] = augmentation
output_rows.append(augmented_row)
# Print to file.
with open(args.outfile, "w") as outfile:
csv_writer = csv.writer(
outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
)
# Write header.
csv_writer.writerow(output_rows[0].keys())
# Write rows.
for row in output_rows:
csv_writer.writerow(row.values())
textattack.shared.logger.info(
f"Wrote {len(output_rows)} augmentations to {args.outfile} in {time.time() - start_time}s."
)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"augment",
help="augment text data",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--csv", help="input csv file to augment", type=str, required=True
)
parser.add_argument(
"--input-column",
"--i",
help="csv input column to be augmented",
type=str,
required=True,
)
parser.add_argument(
"--recipe",
"--r",
help="recipe for augmentation",
type=str,
default="embedding",
choices=AUGMENTATION_RECIPE_NAMES.keys(),
)
parser.add_argument(
"--num-words-to-swap",
"--n",
help="words to swap out for each augmented example",
type=int,
default=2,
)
parser.add_argument(
"--transformations-per-example",
"--t",
help="number of augmentations to return for each input",
type=int,
default=2,
)
parser.add_argument(
"--outfile", "--o", help="path to outfile", type=str, default="augment.csv"
)
parser.add_argument(
"--exclude-original",
default=False,
action="store_true",
help="exclude original example from augmented CSV",
)
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="overwrite output file, if it exists",
)
parser.add_argument(
"--random-seed", default=42, type=int, help="random seed to set"
)
parser.set_defaults(func=AugmentCommand())

View File

@@ -0,0 +1,100 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.attack.attack_args_helpers import *
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
class BenchmarkModelCommand(TextAttackCommand):
"""
The TextAttack model benchmarking module:
A command line parser to benchmark a model from user specifications.
"""
def get_num_successes(self, model, ids, true_labels):
with torch.no_grad():
preds = textattack.shared.utils.model_predict(model, ids)
true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
guess_labels = preds.argmax(dim=1)
successes = (guess_labels == true_labels).sum().item()
return successes, true_labels, guess_labels
def test_model_on_dataset(self, args):
model = parse_model_from_args(args)
dataset = parse_dataset_from_args(args)
succ = 0
fail = 0
batch_ids = []
batch_labels = []
all_true_labels = []
all_guess_labels = []
for i, (text_input, label) in enumerate(dataset):
if i >= args.num_examples:
break
attacked_text = textattack.shared.AttackedText(text_input)
ids = model.tokenizer.encode(attacked_text.tokenizer_input)
batch_ids.append(ids)
batch_labels.append(label)
if len(batch_ids) == args.batch_size:
batch_succ, true_labels, guess_labels = self.get_num_successes(
model, batch_ids, batch_labels
)
batch_fail = args.batch_size - batch_succ
succ += batch_succ
fail += batch_fail
batch_ids = []
batch_labels = []
all_true_labels.extend(true_labels.tolist())
all_guess_labels.extend(guess_labels.tolist())
if len(batch_ids) > 0:
batch_succ, true_labels, guess_labels = self.get_num_successes(
model, batch_ids, batch_labels
)
batch_fail = len(batch_ids) - batch_succ
succ += batch_succ
fail += batch_fail
all_true_labels.extend(true_labels.tolist())
all_guess_labels.extend(guess_labels.tolist())
perc = float(succ) / (succ + fail) * 100.0
perc = "{:.2f}%".format(perc)
print(f"Successes {succ}/{succ+fail} ({_cb(perc)})")
return perc
def run(self, args):
# Default to 'all' if no model chosen.
if not (args.model or args.model_from_huggingface or args.model_from_file):
for model_name in list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
TEXTATTACK_DATASET_BY_MODEL.keys()
):
args.model = model_name
self.test_model_on_dataset(args)
else:
self.test_model_on_dataset(args)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"benchmark-model",
help="evaluate a model with TextAttack",
formatter_class=ArgumentDefaultsHelpFormatter,
)
add_model_args(parser)
add_dataset_args(parser)
parser.add_argument(
"--batch-size",
type=int,
default=256,
help="Batch size for model inference.",
)
parser.set_defaults(func=BenchmarkModelCommand())

View File

@@ -0,0 +1,23 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from textattack.commands import TextAttackCommand
class BenchmarkRecipeCommand(TextAttackCommand):
"""
The TextAttack benchmark recipe module:
A command line parser to benchmark a recipe from user specifications.
"""
def run(self, args):
raise NotImplementedError("Cannot benchmark recipes yet - stay tuned!!")
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"benchmark-recipe",
help="benchmark a recipe",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.set_defaults(func=BenchmarkRecipeCommand())

View File

@@ -0,0 +1,66 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args import *
from textattack.commands.augment import AUGMENTATION_RECIPE_NAMES
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
class ListThingsCommand(TextAttackCommand):
"""
The list module:
List default things in textattack.
"""
def _list(self, list_of_things):
""" Prints a list or dict of things. """
if isinstance(list_of_things, list):
list_of_things = sorted(list_of_things)
for thing in list_of_things:
print(_cb(thing))
elif isinstance(list_of_things, dict):
for thing in sorted(list_of_things.keys()):
thing_long_description = list_of_things[thing]
print(f"{_cb(thing)} ({thing_long_description})")
else:
raise TypeError(f"Cannot print list of type {type(list_of_things)}")
@staticmethod
def things():
list_dict = {}
list_dict["models"] = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
TEXTATTACK_DATASET_BY_MODEL.keys()
)
list_dict["search-methods"] = SEARCH_METHOD_CLASS_NAMES
list_dict["transformations"] = {
**BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
**WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
}
list_dict["constraints"] = CONSTRAINT_CLASS_NAMES
list_dict["goal-functions"] = GOAL_FUNCTION_CLASS_NAMES
list_dict["attack-recipes"] = ATTACK_RECIPE_NAMES
list_dict["augmentation-recipes"] = AUGMENTATION_RECIPE_NAMES
return list_dict
def run(self, args):
try:
list_of_things = ListThingsCommand.things()[args.feature]
except KeyError:
raise ValueError(f"Unknown list key {args.feature}")
self._list(list_of_things)
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"list",
help="list features in TextAttack",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"feature", help=f"the feature to list", choices=ListThingsCommand.things()
)
parser.set_defaults(func=ListThingsCommand())
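A quick way to exercise the new subcommand; the feature names are exactly the keys returned by `ListThingsCommand.things()` above:
```bash
# Print every registered attack recipe, colorized in blue
textattack list attack-recipes

# Print the names of all built-in pretrained models
textattack list models
```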

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
import argparse
import os
import sys
from textattack.commands.attack import AttackCommand, AttackResumeCommand
from textattack.commands.augment import AugmentCommand
from textattack.commands.benchmark_model import BenchmarkModelCommand
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
from textattack.commands.list_things import ListThingsCommand
from textattack.commands.train_model import TrainModelCommand
def main():
parser = argparse.ArgumentParser(
"TextAttack CLI",
usage="[python -m] texattack <command> [<args>]",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
subparsers = parser.add_subparsers(help="textattack command helpers")
# Register commands
AttackCommand.register_subcommand(subparsers)
AttackResumeCommand.register_subcommand(subparsers)
AugmentCommand.register_subcommand(subparsers)
BenchmarkModelCommand.register_subcommand(subparsers)
BenchmarkRecipeCommand.register_subcommand(subparsers)
ListThingsCommand.register_subcommand(subparsers)
TrainModelCommand.register_subcommand(subparsers)
# Let's go
args = parser.parse_args()
if not hasattr(args, "func"):
parser.print_help()
sys.exit(1)
# Run
args.func.run(args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,18 @@
from argparse import ArgumentParser
from textattack.commands import TextAttackCommand
class TrainModelCommand(TextAttackCommand):
"""
The TextAttack train module:
A command line parser to train a model from user specifications.
"""
def run(self, args):
raise NotImplementedError("Cannot train models yet - stay tuned!!")
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser("train", help="train a model")

View File

@@ -1,3 +1,5 @@
from flair.data import Sentence
from flair.models import SequenceTagger
import lru
import nltk
@@ -10,13 +12,24 @@ class PartOfSpeech(Constraint):
""" Constraints word swaps to only swap words with the same part of speech.
Uses the NLTK universal part-of-speech tagger by default.
An implementation of `<https://arxiv.org/abs/1907.11932>`_
adapted from `<https://github.com/jind11/TextFooler>`_.
POS tagger from Flair `<https://github.com/flairNLP/flair>`_ is also available.
"""
def __init__(self, tagset="universal", allow_verb_noun_swap=True):
def __init__(
self, tagger_type="nltk", tagset="universal", allow_verb_noun_swap=True
):
self.tagger_type = tagger_type
self.tagset = tagset
self.allow_verb_noun_swap = allow_verb_noun_swap
self._pos_tag_cache = lru.LRU(2 ** 14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
else:
self._flair_pos_tagger = SequenceTagger.load("pos-fast")
def _can_replace_pos(self, pos_a, pos_b):
return (pos_a == pos_b) or (
@@ -27,11 +40,24 @@ class PartOfSpeech(Constraint):
context_words = before_ctx + [word] + after_ctx
context_key = " ".join(context_words)
if context_key in self._pos_tag_cache:
pos_list = self._pos_tag_cache[context_key]
word_list, pos_list = self._pos_tag_cache[context_key]
else:
_, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
self._pos_tag_cache[context_key] = pos_list
return pos_list
if self.tagger_type == "nltk":
word_list, pos_list = zip(
*nltk.pos_tag(context_words, tagset=self.tagset)
)
if self.tagger_type == "flair":
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(context_key)[0]
)
self._pos_tag_cache[context_key] = (word_list, pos_list)
# idx of `word` in `context_words`
idx = len(before_ctx)
assert word_list[idx] == word, "POS list not matched with original word list."
return pos_list[idx]
def _check_constraint(self, transformed_text, current_text, original_text=None):
try:
@@ -45,7 +71,7 @@ class PartOfSpeech(Constraint):
current_word = current_text.words[i]
transformed_word = transformed_text.words[i]
before_ctx = current_text.words[max(i - 4, 0) : i]
after_ctx = current_text.words[i + 1 : min(i + 5, len(current_text.words))]
after_ctx = current_text.words[i + 1 : min(i + 4, len(current_text.words))]
cur_pos = self._get_pos(before_ctx, current_word, after_ctx)
replace_pos = self._get_pos(before_ctx, transformed_word, after_ctx)
if not self._can_replace_pos(cur_pos, replace_pos):
@@ -57,4 +83,18 @@ class PartOfSpeech(Constraint):
return transformation_consists_of_word_swaps(transformation)
def extra_repr_keys(self):
return ["tagset", "allow_verb_noun_swap"]
return ["tagger_type", "tagset", "allow_verb_noun_swap"]
def zip_flair_result(pred):
if not isinstance(pred, Sentence):
raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")
tokens = pred.tokens
word_list = []
pos_list = []
for token in tokens:
word_list.append(token.text)
pos_list.append(token.annotation_layers["pos"][0]._value)
return word_list, pos_list
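A minimal usage sketch of the new `tagger_type` option; the import path is an assumption based on the package layout rather than something shown in this diff:
```python
# Assumed import path for the constraint
from textattack.constraints.grammaticality import PartOfSpeech

# Default NLTK universal tagger vs. the new Flair-based tagger ("upos-fast" model)
nltk_constraint = PartOfSpeech(tagset="universal", allow_verb_noun_swap=True)
flair_constraint = PartOfSpeech(tagger_type="flair", tagset="universal")
```
Either object can then be added to an attack's constraint list as usual.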

View File

@@ -1,4 +1,3 @@
from . import scripts
from . import utils
from .utils import logger
from . import validators

View File

@@ -1,16 +0,0 @@
import os
from textattack.shared.scripts.attack_args import (
HUGGINGFACE_DATASET_BY_MODEL,
TEXTATTACK_DATASET_BY_MODEL,
)
if __name__ == "__main__":
dir_path = os.path.dirname(os.path.realpath(__file__))
for model in {**TEXTATTACK_DATASET_BY_MODEL, **HUGGINGFACE_DATASET_BY_MODEL}:
print(model)
os.system(
f'python {os.path.join(dir_path, "benchmark_model.py")} --model {model} --num-examples 1000'
)
print()

View File

@@ -1,73 +0,0 @@
import argparse
import collections
import sys
import torch
from attack_args_parser import get_args, parse_dataset_from_args, parse_model_from_args
import textattack
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
def get_num_successes(args, model, ids, true_labels):
with torch.no_grad():
preds = textattack.shared.utils.model_predict(model, ids)
true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
guess_labels = preds.argmax(dim=1)
successes = (guess_labels == true_labels).sum().item()
return successes, true_labels, guess_labels
def test_model_on_dataset(args, model, dataset, num_examples=100, batch_size=128):
num_examples = args.num_examples
succ = 0
fail = 0
batch_ids = []
batch_labels = []
all_true_labels = []
all_guess_labels = []
for i, (text_input, label) in enumerate(dataset):
if i >= num_examples:
break
attacked_text = textattack.shared.AttackedText(text_input)
ids = model.tokenizer.encode(attacked_text.tokenizer_input)
batch_ids.append(ids)
batch_labels.append(label)
if len(batch_ids) == batch_size:
batch_succ, true_labels, guess_labels = get_num_successes(
args, model, batch_ids, batch_labels
)
batch_fail = batch_size - batch_succ
succ += batch_succ
fail += batch_fail
batch_ids = []
batch_labels = []
all_true_labels.extend(true_labels.tolist())
all_guess_labels.extend(guess_labels.tolist())
if len(batch_ids) > 0:
batch_succ, true_labels, guess_labels = get_num_successes(
args, model, batch_ids, batch_labels
)
batch_fail = len(batch_ids) - batch_succ
succ += batch_succ
fail += batch_fail
all_true_labels.extend(true_labels.tolist())
all_guess_labels.extend(guess_labels.tolist())
perc = float(succ) / (succ + fail) * 100.0
perc = "{:.2f}%".format(perc)
print(f"Successes {succ}/{succ+fail} ({_cb(perc)})")
return perc
if __name__ == "__main__":
args = get_args()
model = parse_model_from_args(args)
dataset = parse_dataset_from_args(args)
with torch.no_grad():
test_model_on_dataset(args, model, dataset, num_examples=args.num_examples)

View File

@@ -1,3 +1,6 @@
import random
import numpy as np
import torch
import textattack
@@ -77,3 +80,9 @@ def load_textattack_model_from_path(model_name, model_path):
else:
raise ValueError(f"Unknown textattack model {model_path}")
return model
def set_seed(random_seed):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
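Finally, a sketch of how the new seeding helper might be called before a run; the module path is an assumption, since the file name is not visible in this view:
```python
# Assumed location of the helper defined above
from textattack.shared.utils import set_seed

set_seed(42)  # seeds `random`, `numpy`, and `torch` so repeated runs are comparable
```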