1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
This commit is contained in:
Jack Morris
2020-07-06 14:57:08 -04:00
159 changed files with 14211 additions and 1472 deletions

34
.github/workflows/check-formatting.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Formatting with black & isort

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quote versions: unquoted entries are YAML floats, so a future
        # "3.10" would silently become 3.1.
        python-version: ["3.8"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install black flake8 isort # Testing packages
          python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
          pip install -e .
      - name: Check code format with black and isort
        run: |
          make lint

37
.github/workflows/make-docs.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Build documentation with Sphinx

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quote versions: unquoted entries are YAML floats, so a future
        # "3.10" would silently become 3.1.
        python-version: ["3.8"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          sudo sed -i 's/azure\.//' /etc/apt/sources.list # workaround for flaky pandoc install
          sudo apt-get update # from here https://github.com/actions/virtual-environments/issues/675
          sudo apt-get install pandoc -o Acquire::Retries=3 # install pandoc
          python -m pip install --upgrade pip setuptools wheel # update python
          pip install ipython --upgrade # needed for Github for whatever reason
          python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
          pip install -e . ".[dev]" # This should install all packages for development
      - name: Build docs with Sphinx and check for errors
        run: |
          sphinx-build -b html docs docs/_build/html -W

31
.github/workflows/publish-to-pypi.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
# This workflows will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Upload Python Package to PyPI

# Publish only when a GitHub release is created.
on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          # Any current 3.x interpreter is sufficient for building the sdist/wheel.
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install setuptools wheel twine
      - name: Build and publish
        env:
          # PyPI credentials are injected from repository secrets.
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
        run: |
          python setup.py sdist bdist_wheel
          twine upload dist/*

34
.github/workflows/run-pytest.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Test with PyTest

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quote versions: unquoted entries are YAML floats, so a future
        # "3.10" would silently become 3.1.
        python-version: ["3.6", "3.7", "3.8"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install pytest pytest-xdist # Testing packages
          python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
          pip install -e .
      - name: Test with pytest
        run: |
          pytest tests -vx --dist=loadfile -n auto

View File

@@ -9,6 +9,10 @@ It also helps us if you spread the word: reference the library from blog posts
on the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply star the repo to say "thank you".
## Slack Channel
For help and realtime updates related to TextAttack, please [join the TextAttack Slack](https://join.slack.com/t/textattack/shared_invite/zt-ez3ts03b-Nr55tDiqgAvCkRbbz8zz9g)!
## Ways to contribute
There are lots of ways you can contribute to TextAttack:
@@ -113,7 +117,7 @@ Follow these steps to start contributing:
```bash
$ cd TextAttack
$ pip install -e .
$ pip install -e . ".[dev]"
$ pip install black isort pytest pytest-xdist
```
@@ -175,11 +179,25 @@ Follow these steps to start contributing:
$ git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
6. Add documentation.
Our docs are in the `docs/` folder. Thanks to `sphinx-automodule`, this
should just be two lines. Our docs will automatically generate from the
comments you added to your code. If you're adding an attack recipe, add a
reference in `attack_recipes.rst`. If you're adding a transformation, add
a reference in `transformation.rst`, etc.
You can build the docs and view the updates using `make docs`. If you're
adding a tutorial or something where you want to update the docs multiple
times, you can run `make docs-auto`. This will run a server using
`sphinx-autobuild` that should automatically reload whenever you change
a file.
7. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors
8. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.

View File

@@ -1,22 +1,26 @@
format: FORCE ## Run black and isort (rewriting files)
black .
isort --atomic --recursive tests textattack
isort --atomic tests textattack
lint: FORCE ## Run black (in check mode)
lint: FORCE ## Run black, isort, flake8 (in check mode)
black . --check
isort --check-only --recursive tests textattack
isort --check-only tests textattack
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8
test: FORCE ## Run tests using pytest
python -m pytest -qx --dist=loadfile -n auto
python -m pytest --dist=loadfile -n auto
docs: FORCE ## Build docs using Sphinx.
sphinx-build -b html docs docs/_build/html
docs-check: FORCE ## Builds docs using Sphinx. If there is an error, exit with an error code (instead of warning & continuing).
sphinx-build -b html docs docs/_build/html -W
docs-auto: FORCE ## Build docs using Sphinx and run hotreload server using Sphinx autobuild.
sphinx-autobuild docs docs/_build/html -H 0.0.0.0 -p 8765
all: format lint test ## Format, lint, and test.
all: format lint docs-check test ## Format, lint, and test.
.PHONY: help

110
README.md
View File

@@ -1,28 +1,41 @@
<h1 align="center">TextAttack 🐙</h1>
<p align="center">Generating adversarial examples for NLP models</p>
<p align="center">
<a href="https://textattack.readthedocs.io/">Docs</a>
<a href="https://textattack.readthedocs.io/">[TextAttack Documentation on ReadTheDocs]</a>
<br> <br>
<a href="#about">About</a>
<a href="#setup">Setup</a>
<a href="#usage">Usage</a>
<a href="#design">Design</a>
<br> <br>
<a target="_blank" href="https://travis-ci.org/QData/TextAttack">
<img src="https://travis-ci.org/QData/TextAttack.svg?branch=master" alt="Coverage Status">
<a target="_blank">
<img src="https://github.com/QData/TextAttack/workflows/Github%20PyTest/badge.svg" alt="GitHub Runner Coverage Status">
</a>
<a href="https://badge.fury.io/py/textattack">
<img src="https://badge.fury.io/py/textattack.svg" alt="PyPI version" height="18">
</a>
</p>
<img src="http://jackxmorris.com/files/textattack.gif" alt="TextAttack Demo GIF" style="display: block; margin: 0 auto;" />
## About
TextAttack is a Python framework for running adversarial attacks against NLP models. TextAttack builds attacks from four components: a search method, goal function, transformation, and set of constraints. TextAttack's modular design makes it easily extensible to new NLP tasks, models, and attack strategies. TextAttack currently supports attacks on models trained for classification, entailment, and translation.
TextAttack is a Python framework for adversarial attacks, data augmentation, and model training in NLP.
## Slack Channel
For help and realtime updates related to TextAttack, please [join the TextAttack Slack](https://join.slack.com/t/textattack/shared_invite/zt-ez3ts03b-Nr55tDiqgAvCkRbbz8zz9g)!
### *Why TextAttack?*
There are lots of reasons to use TextAttack:
1. **Understand NLP models better** by running different adversarial attacks on them and examining the output
2. **Research and develop different NLP adversarial attacks** using the TextAttack framework and library of components
3. **Augment your dataset** to increase model generalization and robustness downstream
4. **Train NLP models** using just a single command (all downloads included!)
## Setup
@@ -35,12 +48,11 @@ pip install textattack
```
Once TextAttack is installed, you can run it via command-line (`textattack ...`)
or via the python module (`python -m textattack ...`).
or via python module (`python -m textattack ...`).
### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
environment variable `TA_CACHE_DIR`.
> **Tip**: TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
> environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).
## Usage
@@ -49,11 +61,15 @@ common commands are `textattack attack <args>`, and `textattack augment <args>`.
information about all commands using `textattack --help`, or a specific command using, for example,
`textattack attack --help`.
The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The [documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint.
### Running Attacks
The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
The easiest way to try out an attack is via the command-line interface, `textattack attack`.
The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.
Here are some concrete examples:
*TextFooler on an LSTM trained on the MR sentiment classification dataset*:
```bash
@@ -73,15 +89,7 @@ textattack attack --model lstm-mr --num-examples 20 \
--goal-function untargeted-classification
```
*Non-overlapping output attack using a greedy word swap and WordNet word substitutions on T5 English-to-German translation:*
```bash
textattack attack --attack-n --goal-function non-overlapping-output \
--model t5-en2de --num-examples 10 --transformation word-swap-wordnet \
--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
--search greedy
```
> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.
> **Tip:** Instead of specifying a dataset and number of examples, you can pass `--interactive` to attack samples inputted by the user.
### Attacks and Papers Implemented ("Attack Recipes")
@@ -89,16 +97,19 @@ We include attack recipes which implement attacks from the literature. You can l
To run an attack recipe: `textattack attack --recipe [recipe_name]`
The first are for classification tasks, like sentiment classification and entailment:
Attacks on classification tasks, like sentiment classification and entailment:
- **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
- **bae**: BERT masked language model transformation attack from (["BAE: BERT-based Adversarial Examples for Text Classification" (Garg & Ramakrishnan, 2019)](https://arxiv.org/abs/2004.01970)).
- **bert-attack**: BERT masked language model transformation attack with subword replacements (["BERT-ATTACK: Adversarial Attack Against BERT Using BERT" (Li et al., 2020)](https://arxiv.org/abs/2004.09984)).
- **deepwordbug**: Greedy replace-1 scoring and multi-transformation character-swap attack (["Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers" (Gao et al., 2018)](https://arxiv.org/abs/1801.04354)).
- **hotflip**: Beam search and gradient-based word swap (["HotFlip: White-Box Adversarial Examples for Text Classification" (Ebrahimi et al., 2017)](https://arxiv.org/abs/1712.06751)).
- **input-reduction**: Reducing the input while maintaining the prediction through word importance ranking (["Pathologies of Neural Models Make Interpretation Difficult" (Feng et al., 2018)](https://arxiv.org/pdf/1804.07781.pdf)).
- **kuleshov**: Greedy search and counterfitted embedding swap (["Adversarial Examples for Natural Language Classification Problems" (Kuleshov et al., 2018)](https://openreview.net/pdf?id=r1QZ3zbAZ)).
- **pwws**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)).
- **textbugger**: Greedy attack with word importance ranking and character-based swaps (["TextBugger: Generating Adversarial Text Against Real-world Applications" (Li et al., 2018)](https://arxiv.org/abs/1812.05271)).
- **textfooler**: Greedy attack with word importance ranking and counter-fitted embedding swap (["Is Bert Really Robust?" (Jin et al., 2019)](https://arxiv.org/abs/1907.11932)).
- **PWWS**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)).
The final is for sequence-to-sequence models:
Attacks on sequence-to-sequence models:
- **seq2sick**: Greedy attack with goal of changing every word in the output translation. Currently implemented as black-box with plans to change to white-box as done in paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples" (Cheng et al., 2018)](https://arxiv.org/abs/1803.01128)).
#### Recipe Usage Examples
@@ -112,7 +123,7 @@ textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examp
*seq2sick (black-box) against T5 fine-tuned for English-German translation:*
```bash
textattack attack --recipe seq2sick --model t5-en2de --num-examples 100
textattack attack --model t5-en-de --recipe seq2sick --num-examples 100
```
### Augmenting Text
@@ -175,6 +186,32 @@ of a string or a list of strings. Here's an example of how to use the `Embedding
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```
### Training Models
Our model training code is available via `textattack train` to help you train LSTMs,
CNNs, and `transformers` models using TextAttack out-of-the-box. Datasets are
automatically loaded using the `nlp` package.
#### Training Examples
*Train our default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50 --learning-rate 1e-5
```
*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs*:
```bash
textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 --epochs 5
```
## `textattack peek-dataset`
To take a closer look at a dataset, use `textattack peek-dataset`. TextAttack will print some cursory statistics about the inputs and outputs from the dataset. For example, `textattack peek-dataset --dataset-from-nlp snli` will show information about the SNLI dataset from the NLP package.
## `textattack list`
There are lots of pieces in TextAttack, and it can be difficult to keep track of all of them. You can use `textattack list` to list components, for example, pretrained models (`textattack list models`) or available search methods (`textattack list search-methods`).
## Design
### AttackedText
@@ -190,10 +227,10 @@ TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that
TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct
dataset to the correct model. We include various pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/)
tasks, as well as some common classification datasets, translation, and summarization. You can
tasks, as well as some common datasets for classification, translation, and summarization. You can
see the full list of provided models & datasets via `textattack attack --help`.
Here's an example of using one of the built-in models:
Here's an example of using one of the built-in models (the SST-2 dataset is automatically loaded):
```bash
textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10
@@ -206,11 +243,11 @@ and datasets from the [`nlp` package](https://github.com/huggingface/nlp)! Here'
and attacking a pre-trained model and dataset:
```bash
textattack attack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
textattack attack --model-from-huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset-from-nlp glue:sst2 --recipe deepwordbug --num-examples 10
```
You can explore other pre-trained models using the `--model_from_huggingface` argument, or other datasets by changing
`--dataset_from_nlp`.
You can explore other pre-trained models using the `--model-from-huggingface` argument, or other datasets by changing
`--dataset-from-nlp`.
#### Loading a model or dataset from a file
@@ -229,7 +266,7 @@ model = load_model()
tokenizer = load_tokenizer()
```
Then, run an attack with the argument `--model_from_file my_model.py`. The model and tokenizer will be loaded automatically.
Then, run an attack with the argument `--model-from-file my_model.py`. The model and tokenizer will be loaded automatically.
#### Dataset from a file
@@ -240,7 +277,7 @@ The following example would load a sentiment classification dataset from file `m
dataset = [('Today was....', 1), ('This movie is...', 0), ...]
```
You can then run attacks on samples from this dataset by adding the argument `--dataset_from_file my_dataset.py`.
You can then run attacks on samples from this dataset by adding the argument `--dataset-from-file my_dataset.py`.
### Attacks
@@ -248,7 +285,7 @@ The `attack_one` method in an `Attack` takes as input an `AttackedText`, and out
### Goal Functions
A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
A `GoalFunction` takes as input an `AttackedText` object, scores it, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
### Constraints
@@ -262,10 +299,13 @@ A `Transformation` takes as input an `AttackedText` and returns a list of possib
A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult`. The search is given access to the `get_transformations` function, which takes as input an `AttackedText` object and outputs a list of possible transformations filtered by meeting all of the attack's constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.
## Contributing to TextAttack
We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. TextAttack is currently in an "alpha" stage in which we are working to improve its capabilities and design.
See [CONTRIBUTING.md](https://github.com/QData/TextAttack/blob/master/CONTRIBUTING.md) for detailed information on contributing.
## Citing TextAttack
If you use TextAttack for your research, please cite [TextAttack: A Framework for Adversarial Attacks in Natural Language Processing](https://arxiv.org/abs/2005.05909).

View File

@@ -2,6 +2,13 @@
Attack
========
TextAttack builds attacks from four components:
- `Goal Functions <../attacks/goal_function.html>`__ stipulate the goal of the attack, like to change the prediction score of a classification model, or to change all of the words in a translation output.
- `Constraints <../attacks/constraint.html>`__ determine if a potential perturbation is valid with respect to the original input.
- `Transformations <../attacks/transformation.html>`__ take a text input and transform it by inserting and deleting characters, words, and/or phrases.
- `Search Methods <../attacks/search_method.html>`__ explore the space of possible **transformations** within the defined **constraints** and attempt to find a successful perturbation which satisfies the **goal function**.
The ``Attack`` class represents an adversarial attack composed of a goal function, search method, transformation, and constraints.
.. automodule:: textattack.shared.attack

View File

@@ -5,53 +5,82 @@ We provide a number of pre-built attack recipes. To run an attack recipe, run::
textattack attack --recipe [recipe_name]
Alzantot
###########
Alzantot Genetic Algorithm (Generating Natural Language Adversarial Examples)
###################################################################################
.. automodule:: textattack.attack_recipes.alzantot_2018
.. automodule:: textattack.attack_recipes.genetic_algorithm_alzantot_2018
:members:
Faster Alzantot Genetic Algorithm (Certified Robustness to Adversarial Word Substitutions)
##############################################################################################
.. automodule:: textattack.attack_recipes.faster_genetic_algorithm_jia_2019
:members:
BAE (BAE: BERT-Based Adversarial Examples)
#############################################
.. automodule:: textattack.attack_recipes.bae_garg_2019
:members:
DeepWordBug
############
BERT-Attack: (BERT-Attack: Adversarial Attack Against BERT Using BERT)
#########################################################################
.. automodule:: textattack.attack_recipes.bert_attack_li_2020
:members:
DeepWordBug (Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers)
######################################################################################################
.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018
:members:
HotFlip
###########
HotFlip (HotFlip: White-Box Adversarial Examples for Text Classification)
##############################################################################
.. automodule:: textattack.attack_recipes.hotflip_ebrahimi_2017
:members:
Input Reduction
################
.. automodule:: textattack.attack_recipes.input_reduction_feng_2018
:members:
Kuleshov
###########
Kuleshov (Adversarial Examples for Natural Language Classification Problems)
##############################################################################
.. automodule:: textattack.attack_recipes.kuleshov_2017
:members:
Seq2Sick
###########
Particle Swarm Optimization (Word-level Textual Adversarial Attacking as Combinatorial Optimization)
#####################################################################################################
.. automodule:: textattack.attack_recipes.PSO_zang_2020
:members:
PWWS (Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency)
###################################################################################################
.. automodule:: textattack.attack_recipes.pwws_ren_2019
:members:
Seq2Sick (Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples)
#########################################################################################################
.. automodule:: textattack.attack_recipes.seq2sick_cheng_2018_blackbox
:members:
TextFooler
###########
TextFooler (Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment)
########################################################################################################################
.. automodule:: textattack.attack_recipes.textfooler_jin_2019
:members:
PWWS
###########
.. automodule:: textattack.attack_recipes.pwws_ren_2019
:members:
TextBugger
###########
TextBugger (TextBugger: Generating Adversarial Text Against Real-world Applications)
########################################################################################
.. automodule:: textattack.attack_recipes.textbugger_li_2018
:members:

View File

@@ -83,6 +83,12 @@ GPT-2
.. automodule:: textattack.constraints.grammaticality.language_models.gpt2
:members:
"Learning To Write" Language Model
************************************
.. automodule:: textattack.constraints.grammaticality.language_models.learning_to_write.learning_to_write
:members:
Google 1-Billion Words Language Model
@@ -136,7 +142,7 @@ Maximum Words Perturbed
.. _pre_transformation:
Pre-Transformation
----------
-------------------------
Pre-transformation constraints determine if a transformation is valid based on
only the original input and the position of the replacement. These constraints
@@ -145,7 +151,7 @@ constraints can prevent search methods from swapping words at the same index
twice, or from replacing stopwords.
Pre-Transformation Constraint
########################
###############################
.. automodule:: textattack.constraints.pre_transformation.pre_transformation_constraint
:special-members: __call__
:private-members:
@@ -160,3 +166,13 @@ Repeat Modification
########################
.. automodule:: textattack.constraints.pre_transformation.repeat_modification
:members:
Input Column Modification
#############################
.. automodule:: textattack.constraints.pre_transformation.input_column_modification
:members:
Max Word Index Modification
###############################
.. automodule:: textattack.constraints.pre_transformation.max_word_index_modification
:members:

View File

@@ -69,7 +69,7 @@ Word Swap by Random Character Insertion
:members:
Word Swap by Random Character Substitution
---------------------------------------
-------------------------------------------
.. automodule:: textattack.transformations.word_swap_random_character_substitution
:members:

View File

@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
author = "UVA QData Lab"
# The full version, including alpha/beta/rc tags
release = "0.0.3.1"
release = "0.1.5"
# Set master doc to `index.rst`.
master_doc = "index"
@@ -37,7 +37,10 @@ extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx_rtd_theme",
# Enable .ipynb doc files
"nbsphinx",
# Enable .md doc files
"recommonmark",
]
# Add any paths that contain templates here, relative to this directory.

View File

@@ -6,40 +6,10 @@ Datasets
:members:
:private-members:
Classification
###############
.. automodule:: textattack.datasets.classification.classification_dataset
.. automodule:: textattack.datasets.huggingface_nlp_dataset
:members:
.. automodule:: textattack.datasets.classification.ag_news
:members:
.. automodule:: textattack.datasets.classification.imdb_sentiment
:members:
.. automodule:: textattack.datasets.classification.kaggle_fake_news
:members:
.. automodule:: textattack.datasets.classification.movie_review_sentiment
:members:
.. automodule:: textattack.datasets.classification.yelp_sentiment
:members:
Entailment
############
.. automodule:: textattack.datasets.entailment.entailment_dataset
:members:
.. automodule:: textattack.datasets.entailment.mnli
:members:
.. automodule:: textattack.datasets.entailment.snli
.. automodule:: textattack.datasets.translation.ted_multi
:members:
Translation
#############
.. automodule:: textattack.datasets.translation.translation_datasets
:members:

View File

@@ -11,7 +11,7 @@ We split models up into two broad categories:
**Classification models:**
:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using transformers_.
:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using ``transformers``.
:ref:`LSTM`: a standard LSTM fine-tuned on various datasets.
@@ -20,85 +20,32 @@ We split models up into two broad categories:
**Text-to-text models:**
:ref:`T5`: ``T5`` fine-tuned on various datasets using transformers_.
:ref:`T5`: ``T5`` fine-tuned on various datasets using ``transformers``.
.. _BERT:
BERT
********
.. _BERT:
.. automodule:: textattack.models.helpers.bert_for_classification
:members:
We provide pre-trained BERT models on the following datasets:
.. automodule:: textattack.models.classification.bert.bert_for_ag_news_classification
:members:
.. automodule:: textattack.models.classification.bert.bert_for_imdb_sentiment_classification
:members:
.. automodule:: textattack.models.classification.bert.bert_for_mr_sentiment_classification
:members:
.. automodule:: textattack.models.classification.bert.bert_for_yelp_sentiment_classification
:members:
.. automodule:: textattack.models.entailment.bert.bert_for_mnli
:members:
.. automodule:: textattack.models.entailment.bert.bert_for_snli
:members:
LSTM
*******
.. _LSTM:
LSTM
*******
.. automodule:: textattack.models.helpers.lstm_for_classification
:members:
We provide pre-trained LSTM models on the following datasets:
.. automodule:: textattack.models.classification.lstm.lstm_for_ag_news_classification
:members:
.. automodule:: textattack.models.classification.lstm.lstm_for_imdb_sentiment_classification
:members:
.. automodule:: textattack.models.classification.lstm.lstm_for_mr_sentiment_classification
:members:
.. automodule:: textattack.models.classification.lstm.lstm_for_yelp_sentiment_classification
:members:
.. _CNN:
Word-CNN
************
.. _CNN:
.. automodule:: textattack.models.helpers.word_cnn_for_classification
:members:
We provide pre-trained CNN models on the following datasets:
.. automodule:: textattack.models.classification.cnn.word_cnn_for_ag_news_classification
:members:
.. automodule:: textattack.models.classification.cnn.word_cnn_for_imdb_sentiment_classification
:members:
.. automodule:: textattack.models.classification.cnn.word_cnn_for_mr_sentiment_classification
:members:
.. automodule:: textattack.models.classification.cnn.word_cnn_for_yelp_sentiment_classification
:members:
.. _T5:
T5
@@ -106,21 +53,3 @@ T5
.. automodule:: textattack.models.helpers.t5_for_text_to_text
:members:
We provide pre-trained T5 models on the following tasks & datasets:
Translation
##############
.. automodule:: textattack.models.translation.t5.t5_models
:members:
Summarization
##############
.. automodule:: textattack.models.summarization.t5_summarization
:members:
.. _transformers: https://github.com/huggingface/transformers

View File

@@ -2,20 +2,14 @@
Tokenizers
===========
.. automodule:: textattack.tokenizers.tokenizer
.. automodule:: textattack.models.tokenizers.auto_tokenizer
:members:
.. automodule:: textattack.tokenizers.auto_tokenizer
.. automodule:: textattack.models.tokenizers.glove_tokenizer
:members:
.. automodule:: textattack.tokenizers.spacy_tokenizer
.. automodule:: textattack.models.tokenizers.t5_tokenizer
:members:
.. automodule:: textattack.tokenizers.t5_tokenizer
.. automodule:: textattack.models.tokenizers.bert_tokenizer
:members:
.. automodule:: textattack.tokenizers.bert_tokenizer
:members:
.. automodule:: textattack.tokenizers.bert_entailment_tokenizer
:members:

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# The TextAttack🐙 ecosystem: search, transformations, and constraints\n",
"# The TextAttack ecosystem: search, transformations, and constraints\n",
"\n",
"An attack in TextAttack consists of four parts.\n",
"\n",
@@ -31,9 +31,9 @@
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
"\n",
"\n",
"### Banana word swap 🍌\n",
"### Banana word swap \n",
"\n",
"As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'."
"As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'. 🍌"
]
},
{
@@ -308,9 +308,9 @@
"collapsed": true
},
"source": [
"### Conclusion 🍌\n",
"### Conclusion n",
"\n",
"We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a single \"banana\", while others needed up to 17 \"banana\" substitutions to change the class score. Wow!"
"We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a couple \"banana\"s, while others needed up to 17 \"banana\" substitutions to change the class score. Wow! 🍌"
]
}
],

View File

@@ -16,7 +16,7 @@ TextAttack provides a framework for constructing and thinking about attacks via
TextAttack provides a set of `Attack Recipes <attacks/attack_recipes.html>`__ that assemble attacks from the literature from these four components.
Data Augmentation
-------------
--------------------
Data augmentation is easy and extremely common in computer vision but harder and less common in NLP. We provide a `Data Augmentation <augmentation/augmenter.html>`__ module using transformations and constraints.
Features
@@ -31,12 +31,14 @@ TextAttack has some other features that make it a pleasure to use:
.. toctree::
:maxdepth: 1
:hidden:
:caption: Quickstart
:caption: Getting Started
quickstart/installation
quickstart/overview
Example 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>
Example 2: Constraints <examples/2_Constraints.ipynb>
Installation <quickstart/installation>
Command-Line Usage <quickstart/command_line_usage>
Tutorial 0: TextAttack End-To-End (Train, Eval, Attack) <examples/0_End_to_End.ipynb>
Tutorial 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>
Tutorial 2: Constraints <examples/2_Constraints.ipynb>
.. toctree::
:maxdepth: 3
@@ -73,7 +75,7 @@ TextAttack has some other features that make it a pleasure to use:
:hidden:
:caption: Miscellaneous
misc/attacked_text
misc/checkpoints
misc/loggers
misc/validators
misc/tokenized_text

View File

@@ -0,0 +1,6 @@
===================
Attacked Text
===================
.. automodule:: textattack.shared.attacked_text
:members:

View File

@@ -1,6 +0,0 @@
===================
Tokenized Text
===================
.. automodule:: textattack.shared.tokenized_text
:members:

View File

@@ -0,0 +1,135 @@
Command-Line Usage
=======================================
The easiest way to use textattack is from the command-line. Installing textattack
will provide you with the handy `textattack` command which will allow you to do
just about anything TextAttack offers in a single bash command.
> *Tip*: If you are for some reason unable to use the `textattack` command, you
> can access all the same functionality by prepending `python -m` to the command
> (`python -m textattack ...`).
To see all available commands, type `textattack --help`. This page explains
some of the most important functionalities of textattack: NLP data augmentation,
adversarial attacks, and training and evaluating models.
## Data Augmentation with `textattack augment`
The easiest way to use our data augmentation tools is with `textattack augment <args>`. `textattack augment`
takes an input CSV file and text column to augment, along with the number of words to change per augmentation
and the number of augmentations per input example. It outputs a CSV in the same format with all the augmentation
examples corresponding to the proper columns.
For example, given the following as `examples.csv`:
```
"text",label
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
"it's a mystery how the movie could be released in this condition .", 0
```
The command:
```
textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 \
--transformations-per-example 2 --exclude-original
```
will augment the `text` column with four swaps per augmentation, twice as many augmentations as original inputs, and exclude the original inputs from the
output CSV. (All of this will be saved to `augment.csv` by default.)
After augmentation, here are the contents of `augment.csv`:
```
text,label
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
take care of my cat offerings a pleasantly several slice of asia cinema .,1
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
it's a enigma how the film wo be releases in this condition .,0
it's a enigma how the filmmaking wo be publicized in this condition .,0
```
The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data.
## Adversarial Attacks with `textattack attack`
The heart of textattack is running adversarial attacks on NLP models with
`textattack attack`. You can build an attack from the command-line in several ways:
1. Use an **attack recipe** to launch an attack from the literature: `textattack attack --recipe deepwordbug`
2. Build your attack from components:
```
textattack attack --model lstm-mr --num-examples 20 --search-method beam-search:beam_width=4 \
--transformation word-swap-embedding \
--constraints repeat stopword max-words-perturbed:max_num_words=2 embedding:min_cos_sim=0.8 part-of-speech \
--goal-function untargeted-classification
```
3. Create a python file that builds your attack and load it: `textattack attack --attack-from-file my_file.py:my_attack_name`
## Training Models with `textattack train`
With textattack, you can train models on any classification or regression task
from [`nlp`](https://github.com/huggingface/nlp/) using a single line.
### Available Models
#### TextAttack Models
TextAttack has two built-in model types: a 1-layer bidirectional LSTM with a hidden
state size of 150 (`lstm`), and a WordCNN with 3 window sizes
(3, 4, 5) and 100 filters per window size (`cnn`). Both models set dropout
to 0.3 and use the 200-dimensional GloVe embeddings as a base.
#### `transformers` Models
Along with the `lstm` and `cnn`, you can theoretically fine-tune any model based
in the huggingface [transformers](https://github.com/huggingface/transformers/)
repo. Just type the model name (like `bert-base-cased`) and it will be automatically
loaded.
Here are some models from transformers that have worked well for us:
- `bert-base-uncased` and `bert-base-cased`
- `distilbert-base-uncased` and `distilbert-base-cased`
- `albert-base-v2`
- `roberta-base`
- `xlnet-base-cased`
## Evaluating Models with `textattack eval-model`
## Other Commands
### Checkpoints and `textattack attack-resume`
Some attacks can take a very long time. Sometimes this is because they're using
a very slow search method (like beam search with a high beam width) or sometimes
they're just attacking a large number of samples. In these cases, it can be
useful to save attack checkpoints throughout the course of the attack. Then,
if the attack crashes for some reason, you can resume without restarting from
scratch.
- To save checkpoints while running an attack, add the argument `--checkpoint-interval X`,
where X is the number of attacks you want to run between checkpoints (for example `textattack attack <args> --checkpoint-interval 5`).
- To load an attack from a checkpoint, use `textattack attack-resume --checkpoint-file <checkpoint-file>`.
### Listing features with `textattack list`
TextAttack has a lot of built-in features (models, search methods, constraints, etc.)
and it can get overwhelming to keep track of all the options. To list all of the
options within a given category, use `textattack list`.
For example:
- list all the built-in models: `textattack list models`
- list all constraints: `textattack list constraints`
- list all search methods: `textattack list search-methods`
### Examining datasets with `textattack peek-dataset`
It can be useful to take a cursory look at and compute some basic statistics of
whatever dataset you're working with. Whether you're loading a dataset of your
own from a file, or one from NLP, you can use `textattack peek-dataset` to
see some basic information about the dataset.
For example, use `textattack peek-dataset --dataset-from-nlp glue:mrpc` to see
information about the MRPC dataset (from the GLUE set of datasets). This will
print statistics like the number of labels, average number of words, etc.

View File

@@ -8,7 +8,7 @@ To use TextAttack, you must be running Python 3.6+. A CUDA-compatible GPU is opt
You're now all set to use TextAttack! Try running an attack from the command line::
textattack attack --recipe textfooler --model bert-mr --num-examples 10
textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 10
This will run an attack using the TextFooler_ recipe, attacking BERT fine-tuned on the MR dataset. It will attack the first 10 samples. Once everything downloads and starts running, you should see attack results print to ``stdout``.

View File

@@ -1,52 +0,0 @@
===========
Overview
===========
TextAttack builds attacks from four components:
- `Goal Functions <../attacks/goal_function.html>`__ stipulate the goal of the attack, like to change the prediction score of a classification model, or to change all of the words in a translation output.
- `Constraints <../attacks/constraint.html>`__ determine if a potential perturbation is valid with respect to the original input.
- `Transformations <../attacks/transformation.html>`__ take a text input and transform it by inserting and deleting characters, words, and/or phrases.
- `Search Methods <../attacks/search_method.html>`__ explore the space of possible **transformations** within the defined **constraints** and attempt to find a successful perturbation which satisfies the **goal function**.
Any model that overrides ``__call__``, takes ``TokenizedText`` as input, and formats output correctly can be used with TextAttack. TextAttack also has built-in datasets and pre-trained models on these datasets. Below is an example of attacking a pre-trained model on the AGNews dataset::
from tqdm import tqdm
from textattack.loggers import FileLogger
from textattack.datasets.classification import AGNews
from textattack.models.classification.lstm import LSTMForAGNewsClassification
from textattack.goal_functions import UntargetedClassification
from textattack.shared import Attack
from textattack.search_methods import GreedySearch
from textattack.transformations import WordSwapEmbedding
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.semantics import RepeatModification, StopwordModification
# Create the model and goal function
model = LSTMForAGNewsClassification()
goal_function = UntargetedClassification(model)
# Use the default WordSwapEmbedding transformation
transformation = WordSwapEmbedding()
# Add a constraint, note that an empty list can be used if no constraints are wanted
constraints = [
RepeatModification(),
StopwordModification(),
PartOfSpeech()
]
# Choose a search method
search = GreedySearch()
# Make an attack with the above parameters
attack = Attack(goal_function, constraints, transformation, search)
# Run the attack on 5 examples and see the results using a logger to output to stdout
results = attack.attack_dataset(AGNews(), num_examples=5, attack_n=True)
logger = FileLogger(stdout=True)
for result in tqdm(results, total=5):
logger.log_attack_result(result)

View File

@@ -1,2 +1,4 @@
recommonmark
nbsphinx
sphinx-autobuild
sphinx-rtd-theme

View File

@@ -0,0 +1,7 @@
#!/bin/bash
# Shows how to build an attack from components and use it on a pre-trained
# model on the Yelp dataset.
textattack attack --attack-n --goal-function untargeted-classification \
--model bert-base-uncased-yelp --num-examples 8 --transformation word-swap-wordnet \
--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
--search greedy

View File

@@ -0,0 +1,4 @@
#!/bin/bash
# Shows how to attack a DistilBERT model fine-tuned on SST2 dataset *from the
# huggingface model repository* using the DeepWordBug recipe and 10 examples.
textattack attack --model-from-huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset-from-nlp glue:sst2 --recipe deepwordbug --num-examples 10

View File

@@ -0,0 +1,4 @@
#!/bin/bash
# Shows how to attack our RoBERTa model fine-tuned on SST-2 using the TextFooler
# recipe and 10 examples.
textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10

View File

@@ -0,0 +1,2 @@
"text",label
"it's a mystery how the movie could be released in this condition .", 0
1 text label
2 it's a mystery how the movie could be released in this condition . 0

View File

@@ -0,0 +1,5 @@
#!/bin/bash
# Trains `albert-base-v2` on the SNLI task for 5 epochs. This is a
# demonstration of how our training script can handle different `transformers`
# models and customize for different datasets.
textattack train --model albert-base-v2 --dataset snli --batch-size 128 --epochs 5 --max-length 128 --learning-rate 1e-5 --allowed-labels 0 1 2

View File

@@ -0,0 +1,4 @@
#!/bin/bash
# Trains `bert-base-cased` on the STS-B task for 3 epochs. This is a demonstration
# of how our training script handles regression.
textattack train --model bert-base-cased --dataset glue:stsb --batch-size 128 --epochs 3 --max-length 128 --learning-rate 1e-5

View File

@@ -0,0 +1,4 @@
#!/bin/bash
# Trains the built-in `lstm` model on the rotten_tomatoes dataset for 50 epochs.
# This is a basic demonstration of our training script and `nlp` integration.
textattack train --model lstm --dataset rotten_tomatoes --batch-size 64 --epochs 50 --learning-rate 1e-5

View File

@@ -1,24 +1,23 @@
click
bert-score
editdistance
flair==0.5.1
filelock
language_tool_python
lru-dict
nlp
nltk
numpy
pandas
pyyaml>=5.1
pandas>=1.0.1
scikit-learn
scipy==1.4.1
sentence_transformers
spacy
sentence_transformers==0.2.6.1
torch
transformers>=2.5.1
transformers>=3
tensorflow>=2
tensorflow_hub
tensorboardX
terminaltables
tokenizers==0.8.0-rc4
tqdm
visdom
wandb
flair
bert-score

View File

@@ -1,9 +1,3 @@
[flake8]
ignore = E203, E266, E501, W503
max-line-length = 120
per-file-ignores = __init__.py:F401
mypy_config = mypy.ini
[isort]
line_length = 88
skip = __init__.py
@@ -14,3 +8,11 @@ multi_line_output = 3
include_trailing_comma = True
use_parentheses = True
force_grid_wrap = 0
[flake8]
exclude = .git,__pycache__,wandb,build,dist
ignore = E203, E266, E501, W503, D203
max-complexity = 10
max-line-length = 120
mypy_config = mypy.ini
per-file-ignores = __init__.py:F401

View File

@@ -6,6 +6,14 @@ from docs import conf as docs_conf
with open("README.md", "r") as fh:
long_description = fh.read()
extras = {}
# Packages required for installing docs.
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
# Packages required for formatting code & running tests.
extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
# For developers, install development tools along with all optional dependencies.
extras["dev"] = extras["docs"] + extras["test"]
setuptools.setup(
name="textattack",
version=docs_conf.release,
@@ -22,12 +30,13 @@ setuptools.setup(
"build*",
"docs*",
"dist*",
"examples*",
"outputs*",
"tests*",
"local_test*",
"wandb*",
]
),
extras_require=extras,
entry_points={
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
},

View File

@@ -28,6 +28,10 @@
)
(3): RepeatModification
(4): StopwordModification
(5): InputColumnModification(
(matching_column_labels): ['premise', 'hypothesis']
(columns_to_ignore): {'premise'}
)
(is_black_box): True
)
/.*/

View File

@@ -0,0 +1,57 @@
/.*/Attack(
(search_method): GreedySearch
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 15
(embedding_type): paragramcf
)
(constraints):
(0): MaxWordsPerturbed(
(max_percent): 0.5
)
(1): ThoughtVector(
(embedding_type): paragramcf
(metric): max_euclidean
(threshold): -0.2
(compare_with_original): False
(window_size): inf
(skip_text_shorter_than_window): False
)
(2): GPT2(
(max_log_prob_diff): 2.0
)
(3): RepeatModification
(4): StopwordModification
(is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> Negative (69%)
it 's a charming and often affecting journey .
it 's a loveable and ordinarily affecting journey .
--------------------------------------------- Result 2 ---------------------------------------------
Negative (83%) --> Positive (90%)
unflinchingly bleak and desperate
unflinchingly bleak and desperation
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 0 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 25.0% |
| Average num. words per input: | 6.0 |
| Avg num queries: | 48.5 |
+-------------------------------+--------+

View File

@@ -14,19 +14,19 @@
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> Negative (73%)
Positive (100%) --> Negative (88%)
lovingly photographed in the manner of a golden book sprung to life , stuart little 2 manages sweetness largely without stickiness .
lovingly photographed in the manner of a golden book sprung to life , stuart little 2 manages sweetness largely without stickiness .
covingly photographed in the manner of a golden book sprung to life , stuart little 2 manages seetness largely without stickiness .
locingly photographed in the manenr of a golden book sprung to lief , stuart little 2 manages sweetness largely without stickiness .
--------------------------------------------- Result 2 ---------------------------------------------
Positive (100%) --> Negative (62%)
Positive (100%) --> Negative (61%)
consistently clever and suspenseful .
consistently clever and suspenseful .
consistently clevger and surpenseful .
cnosistently Mclever and suspensWful .
@@ -39,7 +39,7 @@ consistently clevger and surpenseful .
| Original accuracy: | 100.0% |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 30.26% |
| Average num. words per input: | 11.5 |
| Avg num queries: | 20.5 |
| Average perturbed word %: | 45.0% |
| Average num. words per input: | 12.0 |
| Avg num queries: | 27.0 |
+-------------------------------+--------+

View File

@@ -0,0 +1,66 @@
/.*/Attack(
(search_method): GeneticAlgorithm(
(pop_size): 60
(max_iters): 20
(temp): 0.3
(give_up_if_no_improvement): False
)
(goal_function): UntargetedClassification
(transformation): WordSwapEmbedding(
(max_candidates): 8
(embedding_type): paragramcf
)
(constraints):
(0): MaxWordsPerturbed(
(max_percent): 0.2
)
(1): WordEmbeddingDistance(
(embedding_type): paragramcf
(max_mse_dist): 0.5
(cased): False
(include_unknown_words): True
)
(2): LearningToWriteLanguageModel(
(max_log_prob_diff): 5.0
)
(3): RepeatModification
(4): StopwordModification
(is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> Negative (73%)
this kind of hands-on storytelling is ultimately what makes shanghai ghetto move beyond a good , dry , reliable textbook and what allows it to rank with its worthy predecessors .
this kind of hands-on tale is ultimately what makes shanghai ghetto move beyond a good , secs , credible textbook and what allows it to rank with its worthy predecessors .
--------------------------------------------- Result 2 ---------------------------------------------
Positive (80%) --> Negative (97%)
making such a tragedy the backdrop to a love story risks trivializing it , though chouraqui no doubt intended the film to affirm love's power to help people endure almost unimaginable horror .
making such a tragedy the backdrop to a love story risks trivializing it , notwithstanding chouraqui no doubt intended the film to affirm love's power to help people endure almost incomprehensible horror .
--------------------------------------------- Result 3 ---------------------------------------------
Positive (92%) --> [FAILED]
grown-up quibbles are beside the point here . the little girls understand , and mccracken knows that's all that matters .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 8.58% |
| Average num. words per input: | 25.67 |
| Avg num queries: |/.*/|
+-------------------------------+--------+

View File

@@ -1,5 +1,4 @@
/.*/
Attack(
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
@@ -25,44 +24,48 @@ Attack(
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> [FAILED]
Positive (100%) --> Negative (98%)
this is a film well worth seeing , talking and singing heads and all .
exposing the ways we fool ourselves is one hour photo's real strength .
exposing the ways we fool ourselves is one stopwatch photo's real kraft .
--------------------------------------------- Result 2 ---------------------------------------------
Positive (100%) --> Negative (57%)
Positive (96%) --> Negative (99%)
what really surprises about wisegirls is its low-key quality and genuine tenderness .
it's up to you to decide whether to admire these people's dedication to their cause or be repelled by their dogmatism , manipulativeness and narrow , fearful view of american life .
what really dumbfounded about wisegirls is its low-vital quality and veritable sensibility .
it's up to you to decide whether to admire these people's dedication to their cause or be rescheduled by their dogmatism , manipulativeness and narrow , shitless view of american life .
--------------------------------------------- Result 3 ---------------------------------------------
Positive (100%) --> Negative (84%)
Positive (100%) --> Negative (96%)
( wendigo is ) why we go to the cinema : to be fed through the eye , the heart , the mind .
mostly , [goldbacher] just lets her complicated characters be unruly , confusing and , through it all , human .
( wendigo is ) why we go to the movie : to be stoked through the eyelids , the coeur , the bother .
mostly , [goldbacher] just lets her complicated characters be haphazard , confusing and , through it all , humanistic .
--------------------------------------------- Result 4 ---------------------------------------------
Positive (99%) --> [FAILED]
Positive (99%) --> Negative (90%)
one of the greatest family-oriented , fantasy-adventure movies ever .
. . . quite good at providing some good old fashioned spooks .
. . . rather good at provision some good old fashioned bugging .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 2 |
| Number of successful attacks: | 4 |
| Number of failed attacks: | 0 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 50.0% |
| Attack success rate: | 50.0% |
| Average perturbed word %: | 29.27% |
| Average num. words per input: | 13.5 |
| Avg num queries: | 63.25 |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 17.56% |
| Average num. words per input: | 16.25 |
| Avg num queries: | 45.5 |
+-------------------------------+--------+

View File

@@ -9,31 +9,19 @@
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (63%) --> [SKIPPED]
Positive (75%) --> Negative (71%)
I was surprised how much I enjoyed this. Sure it is a bit slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
I was surprised how much I enjoyed this. Sure it is a bit slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
I was surprised how much I enjoyed this. Sure it is a bct slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
--------------------------------------------- Result 2 ---------------------------------------------
Positive (87%) --> Negative (54%)
Positive (69%) --> Negative (53%)
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks and grim memories, they eventually get to the bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks and grim memories, they eventually get to the bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks acd gAim memories, they eventually get to thr bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
--------------------------------------------- Result 3 ---------------------------------------------
Positive (83%) --> [SKIPPED]
I have certainly not seen all of Jean Rollin's films, but they mostly seem to be bloody vampire naked women fests, which if you like that sort of thing is not bad, but this is a major departure and could almost be Cronenberg minus the bio-mechanical nightmarish stuff. Except it's in French with subtitles of course. A man driving on the road at night comes across a woman that is in her slippers and bathrobe and picks her up, while in the background yet another woman lingers, wearing nothing. As they drive along it's obvious that there is something not right about the woman, in that she forgets things almost as quickly as they happen. Still though, that doesn't prevent the man from having sex with her once they return to Paris & his apartment. The man leaves for work and some strangers show up at his place and take the woman away to this 'tower block', a huge apartment building referred to as the Black Tower, where others of her kind (for whom the 'no memory' things seems to be the least of their problems) are being held for some reason. Time and events march by in the movie, which involve mostly trying to find what's going on and get out of the building for this woman, and she does manage to call Robert, the guy that picked her up in the first place, to come rescue her. The revelation as to what's going on comes in the last few moments of the movie, which has a rather strange yet touching end to it. In avoiding what seemed to be his "typical" formula, Rollin created, in this, what I feel is his most fascinating and disturbing film. I like this one a lot, check it out. 8 out of 10.
--------------------------------------------- Result 4 ---------------------------------------------
Positive (98%) --> Negative (51%)
Since this cartoon was made in the old days, Felix talks using cartoon bubbles and the animation style is very crude when compared to today. However, compared to its contemporaries, it's a pretty good cartoon and still holds up well. That's because despite its age, the cartoon is very creative and funny.<br /><br />Felix meets a guy whose shoe business is folding because he can't sell any shoes. Well, Felix needs money so he can go to Hollywood, so he tells the guy at the shop he'll get every shoe sold. Felix spreads chewing gum all over town and soon people are stuck and leave their shoes--rushing to buy new ones from the shoe store. In gratitude, the guy gives Felix $500! However, Felix's owner wants to take the money and go alone, so Felix figures out a way to sneak along.<br /><br />Once there, Felix barges into a studio and makes a bit of a nuisance of himself. Along the way, he meets cartoon versions of comics Ben Turpin and Charlie Chaplin. In the end, though, through luck, Felix is discovered and offered a movie contract. Hurray!
Since this cartoon was made in the old days, Felix talks using cartoon bubbles and the animation style is very crude when compared to today. However, compared to its contemporaries, it's a pretty gogd cartoon and still holds up well. That's because despite its age, the cartoon is very creative and funny.<br /><br />Felix meets a guy whose shoe business is folding because he can't sell any shoes. Well, Felix needs money so he can go to Hollywood, so he tells the guy at the shop he'll get every shoe sold. Felix spreads chewing gum lll over town and soon people are stuck and leave their shoes--rushing to buy new ones from the shoe store. In gratitude, the guy gives Felix $500! However, Felix's owner wants to take the money and go alone, so Felix figures out a way to sneak along.<br /><br />Once there, Felix barges into a studio and makes a bit of a nuisance of himself. Along the way, he meets cartoon versions of comics Ben Turpin and Charlie Chaplin. In the end, though, through luck, Felix is discovered and offered a movie contract. Hurray!
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks and grEm memories, they eventually get to the bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
@@ -42,11 +30,11 @@ Since this cartoon was made in the old days, Felix talks using cartoon bubbles a
+-------------------------------+--------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 0 |
| Number of skipped attacks: | 2 |
| Original accuracy: | 50.0% |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 1.59% |
| Average num. words per input: | 207.5 |
| Avg num queries: | 172.0 |
| Average perturbed word %: | 0.62% |
| Average num. words per input: | 164.0 |
| Avg num queries: | 166.0 |
+-------------------------------+--------+

View File

@@ -27,7 +27,7 @@
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> [FAILED]
Positive (97%) --> [FAILED]
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill .
@@ -63,6 +63,6 @@ throws in enough clever and unexpected twists to make the formula feel fresh .
| Accuracy under attack: | 75.0% |
| Attack success rate: | 25.0% |
| Average perturbed word %: | 3.85% |
| Average num. words per input: | 15.75 |
| Average num. words per input: | 15.5 |
| Avg num queries: | 1.25 |
+-------------------------------+--------+

View File

@@ -16,40 +16,43 @@
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Contradiction (99%) --> [SKIPPED]
Entailment (99%) --> [SKIPPED]
Premise: The new rights are nice enough
Hypothesis: Everyone really likes the newest benefits
--------------------------------------------- Result 2 ---------------------------------------------
Entailment (100%) --> [FAILED]
Neutral (100%) --> Entailment (56%)
Premise: This site includes a list of all award winners and a searchable database of Government Executive articles.
Hypothesis: The Government Executive articles housed on the website are not able to be searched.
Hypothesis: The Government Executive articles housed on the website are not able to be searched.
Premise: This site includes a list of all award winners and a searchable database of Government Executive articles.
Hypothesis: The Government Executive articles housed on the website are not able-bodied to be searched.
--------------------------------------------- Result 3 ---------------------------------------------
Neutral (99%) --> Contradiction (100%)
Contradiction (99%) --> Entailment (100%)
Premise: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
Hypothesis: I like him for the most part, but would still enjoy seeing someone beat him.
Hypothesis: I like him for the most part, but would still enjoy seeing someone beat him.
Premise: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
Hypothesis: I like him for the most office, but would still enjoy seeing someone beat him.
Hypothesis: I like him for the most office, but would still enjoy seeing someone beat him.
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 1 |
| Number of failed attacks: | 1 |
| Number of successful attacks: | 2 |
| Number of failed attacks: | 0 |
| Number of skipped attacks: | 1 |
| Original accuracy: | 66.67% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 50.0% |
| Average perturbed word %: | 2.27% |
| Average num. words per input: | 29.0 |
| Avg num queries: | 447.5 |
| Accuracy under attack: | 0.0% |
| Attack success rate: | 100.0% |
| Average perturbed word %: | 2.78% |
| Average num. words per input: | 28.67 |
| Avg num queries: | 182.0 |
+-------------------------------+--------+

View File

@@ -12,12 +12,27 @@ def attacked_text():
return textattack.shared.AttackedText(raw_text)
raw_pokemon_text = "the threat implied in the title pokémon 4ever is terrifying – like locusts in a horde these things will keep coming ."
@pytest.fixture
def pokemon_attacked_text():
return textattack.shared.AttackedText(raw_pokemon_text)
premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes."
hypothesis = "The Patan Museum is down the street from the red brick Royal Palace."
raw_text_pair = collections.OrderedDict(
[("premise", premise), ("hypothesis", hypothesis)]
)
raw_hyphenated_text = "It's a run-of-the-mill kind of farmer's tan."
@pytest.fixture
def hyphenated_text():
return textattack.shared.AttackedText(raw_hyphenated_text)
@pytest.fixture
def attacked_text_pair():
@@ -25,27 +40,13 @@ def attacked_text_pair():
class TestAttackedText:
def test_words(self, attacked_text):
def test_words(self, attacked_text, pokemon_attacked_text):
# fmt: off
assert attacked_text.words == [
"A",
"person",
"walks",
"up",
"stairs",
"into",
"a",
"room",
"and",
"sees",
"beer",
"poured",
"from",
"a",
"keg",
"and",
"people",
"talking",
"A", "person", "walks", "up", "stairs", "into", "a", "room", "and", "sees", "beer", "poured", "from", "a", "keg", "and", "people", "talking",
]
assert pokemon_attacked_text.words == ['the', 'threat', 'implied', 'in', 'the', 'title', 'pokémon', '4ever', 'is', 'terrifying', 'like', 'locusts', 'in', 'a', 'horde', 'these', 'things', 'will', 'keep', 'coming']
# fmt: on
def test_window_around_index(self, attacked_text):
assert attacked_text.text_window_around_index(5, 1) == "into"
@@ -53,6 +54,10 @@ class TestAttackedText:
assert attacked_text.text_window_around_index(5, 3) == "stairs into a"
assert attacked_text.text_window_around_index(5, 4) == "up stairs into a"
assert attacked_text.text_window_around_index(5, 5) == "up stairs into a room"
assert (
attacked_text.text_window_around_index(5, float("inf"))
== "A person walks up stairs into a room and sees beer poured from a keg and people talking"
)
def test_big_window_around_index(self, attacked_text):
assert (
@@ -65,8 +70,9 @@ class TestAttackedText:
def test_window_around_index_end(self, attacked_text):
assert attacked_text.text_window_around_index(17, 3) == "and people talking"
def test_text(self, attacked_text, attacked_text_pair):
def test_text(self, attacked_text, pokemon_attacked_text, attacked_text_pair):
assert attacked_text.text == raw_text
assert pokemon_attacked_text.text == raw_pokemon_text
assert attacked_text_pair.text == "\n".join(raw_text_pair.values())
def test_printable_text(self, attacked_text, attacked_text_pair):
@@ -136,13 +142,13 @@ class TestAttackedText:
+ "\n"
+ "The Patan Museum is down the street from the red brick Royal Palace."
)
new_text = new_text.insert_text_after_word_index(38, "and shapes")
new_text = new_text.insert_text_after_word_index(37, "and shapes")
assert new_text.text == (
"Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
+ "\n"
+ "The Patan Museum is down the street from the red brick Royal Palace."
)
new_text = new_text.insert_text_after_word_index(41, "The")
new_text = new_text.insert_text_after_word_index(40, "The")
assert new_text.text == (
"Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
+ "\n"
@@ -159,7 +165,7 @@ class TestAttackedText:
)
for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]):
assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or (
new_i == -1
new_idx == -1
)
new_text = (
new_text.delete_word_at_index(0)
@@ -176,3 +182,14 @@ class TestAttackedText:
new_text.text
== "person walks a very long way up stairs into a room and sees beer poured and people on the couch."
)
def test_hyphen_apostrophe_words(self, hyphenated_text):
assert hyphenated_text.words == [
"It's",
"a",
"run-of-the-mill",
"kind",
"of",
"farmer's",
"tan",
]

View File

@@ -1,9 +1,8 @@
import pdb
import re
import pytest
from helpers import run_command_and_get_result
import pytest
DEBUG = False
@@ -112,6 +111,24 @@ attack_test_params = [
),
# fmt: on
#
# test: run_attack on LSTM MR using word embedding transformation and genetic algorithm. Simulate alzantot recipe without using expensive LM
(
"run_attack_faster_alzantot_recipe",
(
"textattack attack --model lstm-mr --recipe faster-alzantot --num-examples 3 --num-examples-offset 20"
),
"tests/sample_outputs/run_attack_faster_alzantot_recipe.txt",
),
#
# test: run_attack with kuleshov recipe and sst-2 cnn
#
(
"run_attack_kuleshov_nn",
(
"textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst --attack-n --query-budget 200"
),
"tests/sample_outputs/kuleshov_cnn_sst_2.txt",
),
]
@@ -144,3 +161,5 @@ def test_command_line_attack(name, command, sample_output_file):
if DEBUG and not re.match(desired_re, stdout, flags=re.S):
pdb.set_trace()
assert re.match(desired_re, stdout, flags=re.S)
assert result.returncode == 0

View File

@@ -1,6 +1,5 @@
import pytest
from helpers import run_command_and_get_result
import pytest
augment_test_params = [
(
@@ -37,3 +36,5 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file):
# Ensure CSV file exists, then delete it.
assert os.path.exists(outfile)
os.remove(outfile)
assert result.returncode == 0

View File

@@ -1,6 +1,5 @@
import pytest
from helpers import run_command_and_get_result
import pytest
list_test_params = [
(
@@ -27,3 +26,5 @@ def test_command_line_list(name, command, sample_output_file):
print("stderr =>", stderr)
assert stdout == desired_text
assert result.returncode == 0

View File

@@ -10,18 +10,28 @@ from test_augment import augment_test_params
from test_list import list_test_params
def update_test(command, outfile):
def update_test(command, outfile, add_magic_str=False):
if isinstance(command, str):
command = (command,)
command = command + (f"tee {outfile}",)
print("\n".join(f"> {c}" for c in command))
run_command_and_get_result(command)
print(">", command)
else:
print("\n".join(f"> {c}" for c in command))
result = run_command_and_get_result(command)
stdout = result.stdout.decode().strip()
if add_magic_str:
# add magic string to beginning
magic_str = "/.*/"
stdout = magic_str + stdout
# add magic string after attack
mid_attack_str = "\n--------------------------------------------- Result 1"
stdout.replace(mid_attack_str, magic_str + mid_attack_str)
# write to file
open(outfile, "w").write(stdout + "\n")
def main():
#### `textattack attack` tests ####
for _, command, outfile in attack_test_params:
update_test(command, outfile)
update_test(command, outfile, add_magic_str=True)
#### `textattack augment` tests ####
for _, command, outfile, __ in augment_test_params:
update_test(command, outfile)

View File

@@ -1,7 +1,8 @@
def test_imports():
import textattack
import torch
import textattack
del textattack, torch

123
tests/test_tokenizers.py Normal file
View File

@@ -0,0 +1,123 @@
import collections
import pytest
import textattack
news_article = """The unemployment rate dropped to 8.2% last
month, but the economy only added 120,000 jobs,
when 203,000 new jobs had been predicted,
according to today's jobs report. Reaction on the
Wall Street Journal's MarketBeat Blog was swift:
"Woah!!! Bad number." The unemployment rate,
however, is better news; it had been expected to
hold steady at 8.3%. But the AP notes that the dip
is mostly due to more Americans giving up on
seeking employment. """
question = """What phenomenon makes global winds blow northeast
to southwest or the reverse in the northern
hemisphere and northwest to southeast or the
reverse in the southern hemisphere?"""
premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes."
hypothesis = "The Patan Museum is down the street from the red brick Royal Palace."
raw_text_pair = collections.OrderedDict(
[("premise", premise), ("hypothesis", hypothesis)]
)
@pytest.fixture
def bert_tokenizer():
return textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
@pytest.fixture
def glove_tokenizer():
lstm = textattack.models.helpers.LSTMForClassification()
return lstm.tokenizer
@pytest.fixture
def lstm():
lstm = textattack.models.helpers.LSTMForClassification()
return lstm
# Disable formatting so that the samples don't wrap over thousands of lines
# fmt: off
class TestTokenizer:
def test_bert_encode(self, bert_tokenizer):
assert bert_tokenizer.encode(news_article) == {
"input_ids": [
101, 1996, 12163, 3446, 3333, 2000, 1022, 1012, 1016, 1003, 2197, 3204, 1010, 2021, 1996, 4610, 2069, 2794, 6036, 1010, 2199, 5841, 1010, 2043, 18540, 1010, 2199, 2047, 5841, 2018, 2042, 10173, 1010, 2429, 2000, 2651, 1005, 1055, 5841, 3189, 1012, 4668, 2006, 1996, 2813, 2395, 3485, 1005, 1055, 3006, 19442, 9927, 2001, 9170, 1024, 1000, 24185, 4430, 999, 999, 999, 2919, 2193, 1012, 1000, 1996, 12163, 3446, 1010, 2174, 1010, 2003, 2488, 2739, 1025, 2009, 2018, 2042, 3517, 2000, 2907, 6706, 2012, 1022, 1012, 1017, 1003, 1012, 2021, 1996, 9706, 3964, 2008, 1996, 16510, 2003, 3262, 2349, 2000, 2062, 4841, 3228, 2039, 2006, 6224, 6107, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
}
assert bert_tokenizer.encode(question) == {
"input_ids": [
101, 2054, 9575, 3084, 3795, 7266, 6271, 4794, 2000, 4943, 2030, 1996, 7901, 1999, 1996, 2642, 14130, 1998, 4514, 2000, 4643, 2030, 1996, 7901, 1999, 1996, 2670, 14130, 1029, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
}
def test_bert_multiseq(self, bert_tokenizer):
assert bert_tokenizer.encode((premise, hypothesis)) == {
"input_ids": [
101, 2426, 2122, 2024, 1996, 2417, 5318, 2548, 4186, 1010, 2029, 2085, 3506, 1996, 6986, 2319, 2688, 1006, 8222, 1005, 1055, 10418, 1998, 2087, 2715, 2688, 1007, 1010, 1998, 1010, 5307, 1996, 4186, 2408, 1996, 4867, 5318, 8232, 1010, 2809, 8436, 1997, 2367, 6782, 1998, 10826, 1012, 102, 1996, 6986, 2319, 2688, 2003, 2091, 1996, 2395, 2013, 1996, 2417, 5318, 2548, 4186, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
}
def test_bert_encode_tuple(self, bert_tokenizer):
assert bert_tokenizer.encode(news_article) == bert_tokenizer.encode(
(news_article,)
)
assert bert_tokenizer.encode(question) == bert_tokenizer.encode((question,))
def test_bert_encode_batch(self, bert_tokenizer):
assert bert_tokenizer.batch_encode([premise, hypothesis]) == [
{
'input_ids': [101, 2426, 2122, 2024, 1996, 2417, 5318, 2548, 4186, 1010, 2029, 2085, 3506, 1996, 6986, 2319, 2688, 1006, 8222, 1005, 1055, 10418, 1998, 2087, 2715, 2688, 1007, 1010, 1998, 1010, 5307, 1996, 4186, 2408, 1996, 4867, 5318, 8232, 1010, 2809, 8436, 1997, 2367, 6782, 1998, 10826, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
},
{
'input_ids': [101, 1996, 6986, 2319, 2688, 2003, 2091, 1996, 2395, 2013, 1996, 2417, 5318, 2548, 4186, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
}
]
def test_glove_encode(self, glove_tokenizer):
assert glove_tokenizer.encode(news_article) == [
399999, 2479, 570, 1199, 3, 399999, 75, 399999, 33, 399999, 426, 90, 294, 12277, 399999, 60, 100738, 49, 1051, 39, 50, 399999, 199, 3, 399999, 1051, 399999, 2613, 12, 399999, 1014, 490, 399999, 399999, 7640, 14, 399999, 399999, 977, 399999, 399999, 2479, 399999, 399999, 13, 438, 399999, 19, 39, 50, 286, 3, 801, 3798, 21, 399999, 33, 399999, 1581, 2141, 11, 399999, 10247, 13, 1245, 444, 3, 55, 826, 1226, 59, 12, 1308, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,
]
assert glove_tokenizer.encode(question) == [
101, 6387, 906, 700, 3683, 3314, 2591, 3, 2735, 45, 399999, 4962, 5, 399999, 528, 8686, 4, 2233, 3, 1980, 45, 399999, 4962, 5, 399999, 481, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,
]
def test_glove_multiseq(self, glove_tokenizer):
    # The GloVe tokenizer only understands single sequences; feeding it a
    # (premise, hypothesis) pair must raise a ValueError.
    multi_sequence_input = (premise, hypothesis)
    with pytest.raises(ValueError):
        glove_tokenizer.encode(multi_sequence_input)
def test_glove_encode_tuple(self, glove_tokenizer):
    # As with the BERT tokenizer, a bare string and its 1-tuple wrapper
    # must produce identical encodings.
    for text in (news_article, question):
        assert glove_tokenizer.encode(text) == glove_tokenizer.encode((text,))
def test_glove_encode_batch(self, glove_tokenizer):
assert glove_tokenizer.batch_encode([premise, hypothesis]) == [
[243, 157, 31, 399999, 638, 6061, 1141, 399999, 41, 113, 1630, 399999, 82219, 1132, 399999, 9432, 4, 95, 1193, 399999, 399999, 2094, 399999, 2532, 530, 399999, 3756, 6061, 399999, 501, 9203, 2, 493, 6151, 4, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,], [399999, 82219, 1132, 13, 134, 399999, 490, 24, 399999, 638, 6061, 1141, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, ],]
def test_glove_hotfliph(self, lstm, glove_tokenizer):
    # HotFlip requires three tokenizer capabilities: a pad id, an
    # out-of-vocabulary id, and id -> word conversion.
    assert glove_tokenizer.pad_id == 400000
    assert glove_tokenizer.oov_id == 399999
    # "jack" must round-trip between the tokenizer's id->word mapping and
    # the LSTM model's word->id vocabulary at the same index.
    assert glove_tokenizer.convert_id_to_word(2179) == "jack"
    assert lstm.word2id["jack"] == 2179
# fmt: on

View File

@@ -1,15 +1,17 @@
name = "textattack"
from . import attack_recipes
from . import attack_results
from . import augmentation
from . import commands
from . import constraints
from . import datasets
from . import goal_functions
from . import goal_function_results
from . import loggers
from . import models
from . import search_methods
from . import shared
from . import transformations
from . import (
attack_recipes,
attack_results,
augmentation,
commands,
constraints,
datasets,
goal_function_results,
goal_functions,
loggers,
models,
search_methods,
shared,
transformations,
)

View File

@@ -0,0 +1,56 @@
from textattack.constraints.pre_transformation import (
InputColumnModification,
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import PSOAlgorithm
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding, WordSwapHowNet
def PSOZang2020(model):
    """Particle-swarm-optimization attack recipe.

    Zang, Y., Yang, C., Qi, F., Liu, Z., Zhang, M., Liu, Q., & Sun, M. (2019).
    Word-level Textual Adversarial Attacking as Combinatorial Optimization.
    https://www.aclweb.org/anthology/2020.acl-main.540.pdf

    From the paper: a sememe-based word substitution strategy generates
    candidate adversarial examples with better semantic preservation, and
    particle swarm optimization (Eberhart and Kennedy, 1995) searches over
    them. "Following the settings in Alzantot et al. (2018), we set the max
    iteration time G to 20."

    Returns an ``Attack`` built from the paper's components.
    """
    # Sememe-based word substitution, with candidates drawn from HowNet.
    transformation = WordSwapHowNet()
    # Never modify the same word twice, never touch stopwords, and for
    # entailment tasks edit only the hypothesis (keep the premise intact).
    constraints = [
        RepeatModification(),
        StopwordModification(),
        InputColumnModification(["premise", "hypothesis"], {"premise"}),
    ]
    # Untargeted classification goal (could be swapped for a targeted one).
    goal_function = UntargetedClassification(model)
    # PSO search: a swarm of 60 particles, at most 20 iterations.
    search_method = PSOAlgorithm(pop_size=60, max_iters=20)
    return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -1,9 +1,14 @@
from .alzantot_2018 import Alzantot2018
from .bae_garg_2019 import BAEGarg2019
from .bert_attack_li_2020 import BERTAttackLi2020
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
from .deepwordbug_gao_2018 import DeepWordBugGao2018
from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
from .input_reduction_feng_2018 import InputReductionFeng2018
from .kuleshov_2017 import Kuleshov2017
from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
from .textbugger_li_2018 import TextBuggerLi2018
from .textfooler_jin_2019 import TextFoolerJin2019
from .pwws_ren_2019 import PWWSRen2019
from .pruthi_2019 import Pruthi2019
from .PSO_zang_2020 import PSOZang2020

View File

@@ -0,0 +1,109 @@
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapMaskedLM
def BAEGarg2019(model):
"""
Siddhant Garg and Goutham Ramakrishnan, 2019.
BAE: BERT-based Adversarial Examples for Text Classification.
https://arxiv.org/pdf/2004.01970
This is "attack mode" 1 from the paper, BAE-R, word replacement.
We present 4 attack modes for BAE based on the
R and I operations, where for each token t in S:
• BAE-R: Replace token t (See Algorithm 1)
• BAE-I: Insert a token to the left or right of t
• BAE-R/I: Either replace token t or insert a
token to the left or right of t
• BAE-R+I: First replace token t, then insert a
token to the left or right of t

Returns an ``Attack`` implementing the BAE-R mode against ``model``.
"""
# "In this paper, we present a simple yet novel technique: BAE (BERT-based
# Adversarial Examples), which uses a language model (LM) for token
# replacement to best fit the overall context. We perturb an input sentence
# by either replacing a token or inserting a new token in the sentence, by
# means of masking a part of the input and using a LM to fill in the mask."
#
# We only consider the top K=50 synonyms from the MLM predictions.
#
# [from email correspondance with the author]
# "When choosing the top-K candidates from the BERT masked LM, we filter out
# the sub-words and only retain the whole words (by checking if they are
# present in the GloVE vocabulary)"
#
transformation = WordSwapMaskedLM(method="bae", max_candidates=50)
#
# Don't modify the same word twice or stopwords.
#
constraints = [RepeatModification(), StopwordModification()]
# For the R operations we add an additional check for
# grammatical correctness of the generated adversarial example by filtering
# out predicted tokens that do not form the same part of speech (POS) as the
# original token t_i in the sentence.
constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
# "To ensure semantic similarity on introducing perturbations in the input
# text, we filter the set of top-K masked tokens (K is a pre-defined
# constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE)
# (Cer et al., 2018)-based sentence similarity scorer."
#
# "[We] set a threshold of 0.8 for the cosine similarity between USE-based
# embeddings of the adversarial and input text."
#
# [from email correspondence with the author]
# "For a fair comparison of the benefits of using a BERT-MLM in our paper,
# we retained the majority of TextFooler's specifications. Thus we:
# 1. Use the USE for comparison within a window of size 15 around the word
# being replaced/inserted.
# 2. Set the similarity score threshold to 0.1 for inputs shorter than the
# window size (this translates roughly to almost always accepting the new text).
# 3. Perform the USE similarity thresholding of 0.8 with respect to the text
# just before the replacement/insertion and not the original text (For
# example: at the 3rd R/I operation, we compute the USE score on a window
# of size 15 of the text obtained after the first 2 R/I operations and not
# the original text).
# ...
# To address point (3) from above, compare the USE with the original text
# at each iteration instead of the current one (While doing this change
# for the R-operation is trivial, doing it for the I-operation with the
# window based USE comparison might be more involved)."
use_constraint = UniversalSentenceEncoder(
threshold=0.8,
metric="cosine",
compare_with_original=True,
window_size=15,
skip_text_shorter_than_window=True,
)
constraints.append(use_constraint)
#
# Goal is untargeted classification.
#
goal_function = UntargetedClassification(model)
#
# "We estimate the token importance Ii of each token
# t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
# decrease in probability of predicting the correct label y, similar
# to (Jin et al., 2019).
#
# • "If there are multiple tokens can cause C to misclassify S when they
# replace the mask, we choose the token which makes Sadv most similar to
# the original S based on the USE score."
# • "If no token causes misclassification, we choose the perturbation that
# decreases the prediction probability P(C(Sadv)=y) the most."
#
search_method = GreedyWordSwapWIR(wir_method="delete")
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -0,0 +1,76 @@
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapMaskedLM
def BERTAttackLi2020(model):
"""
Li, L.., Ma, R., Guo, Q., Xiangyang, X., Xipeng, Q. (2020).
BERT-ATTACK: Adversarial Attack Against BERT Using BERT
https://arxiv.org/abs/2004.09984

NOTE(review): an earlier version of this docstring said this was
"attack mode 1 ... BAE-R" — that line appears copy-pasted from the BAE
recipe; this function builds the BERT-Attack word-replacement recipe.

Returns an ``Attack`` implementing BERT-Attack against ``model``.
"""
# [from correspondence with the author]
# Candidate size K is set to 48 for all data-sets.
transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
#
# Don't modify the same word twice or stopwords.
#
constraints = [RepeatModification(), StopwordModification()]
# "We only take ε percent of the most important words since we tend to keep
# perturbations minimum."
#
# [from correspondence with the author]
# "Word percentage allowed to change is set to 0.4 for most data-sets, this
# parameter is trivial since most attacks only need a few changes. This
# epsilon is only used to avoid too much queries on those very hard samples."
constraints.append(MaxWordsPerturbed(max_percent=0.4))
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
# Encoder (Cer et al., 2018) to measure the semantic consistency between the
# adversarial sample and the original sequence. To balance between semantic
# preservation and attack success rate, we set up a threshold of semantic
# similarity score to filter the less similar examples."
#
# [from correspondence with author]
# "Over the full texts, after generating all the adversarial samples, we filter
# out low USE score samples. Thus the success rate is lower but the USE score
# can be higher. (actually USE score is not a golden metric, so we simply
# measure the USE score over the final texts for a comparison with TextFooler).
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
# datasets like MNLI, we set threshold between 0-0.2."
#
# Since the threshold in the real world can't be determined from the training
# data, the TextAttack implementation uses a fixed threshold - determined to
# be 0.2 to be most fair.
use_constraint = UniversalSentenceEncoder(
threshold=0.2, metric="cosine", compare_with_original=True, window_size=None,
)
constraints.append(use_constraint)
#
# Goal is untargeted classification.
#
goal_function = UntargetedClassification(model)
#
# "We first select the words in the sequence which have a high significance
# influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
# the input sentence, and oy(S) denote the logit output by the target model
# for correct label y, the importance score Iwi is defined as
# Iwi = oy(S) oy(S\wi), where S\wi = [w0, ··· , wi1, [MASK], wi+1, ···]
# is the sentence after replacing wi with [MASK]. Then we rank all the words
# according to the ranking score Iwi in descending order to create word list
# L."
search_method = GreedyWordSwapWIR(wir_method="unk")
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -0,0 +1,124 @@
from textattack.constraints.grammaticality.language_models import (
LearningToWriteLanguageModel,
)
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GeneticAlgorithm
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding
def FasterGeneticAlgorithmJia2019(model):
"""
Certified Robustness to Adversarial Word Substitutions.
Robin Jia, Aditi Raghunathan, Kerem Göksel, Percy Liang (2019).
https://arxiv.org/pdf/1909.00986.pdf

NOTE(review): several sentences in the "three modifications" paragraph
below read out of order — possibly a copy/paste artifact; see the paper
for the authoritative text.

Section 5: Experiments
We base our sets of allowed word substitutions S(x, i) on the
substitutions allowed by Alzantot et al. (2018). They demonstrated that
their substitutions lead to adversarial examples that are qualitatively
similar to the original input and retain the original label, as judged
by humans. Alzantot et al. (2018) define the neighbors N(w) of a word w
as the n = 8 nearest neighbors of w in a “counter-fitted” word vector
space where antonyms are far apart (Mrksiˇ c´ et al., 2016). The
neighbors must also lie within some Euclidean distance threshold. They
also use a language model constraint to avoid nonsensical perturbations:
they allow substituting xi with x˜i ∈ N(xi) if and only if it does not
decrease the log-likelihood of the text under a pre-trained language
model by more than some threshold.
We make three modifications to this approach:
First, in Alzantot et al. (2018), the adversary
applies substitutions one at a time, and the
neighborhoods and language model scores are computed.
Equation (4) must be applied before the model
can combine information from multiple words, but it can
be delayed until after processing each word independently.
Note that the model itself classifies using a different
set of pre-trained word vectors; the counter-fitted vectors
are only used to define the set of allowed substitution words.
relative to the current altered version of the input.
This results in a hard-to-define attack surface, as
changing one word can allow or disallow changes
to other words. It also requires recomputing
language model scores at each iteration of the genetic
attack, which is inefficient. Moreover, the same
word can be substituted multiple times, leading
to semantic drift. We define allowed substitutions
relative to the original sentence x, and disallow
repeated substitutions.
Second, we use a faster language model that allows us to query
longer contexts; Alzantot et al. (2018) use a slower language
model and could only query it with short contexts.
Finally, we use the language model constraint only
at test time; the model is trained against all perturbations in N(w). This encourages the model to be
robust to a larger space of perturbations, instead of
specializing for the particular choice of language
model. See Appendix A.3 for further details. [This is a model-specific
adjustment, so does not affect the attack recipe.]
Appendix A.3:
In Alzantot et al. (2018), the adversary applies replacements one at a
time, and the neighborhoods and language model scores are computed
relative to the current altered version of the input. This results in a
hard-to-define attack surface, as the same word can be replaced many
times, leading to semantic drift. We instead pre-compute the allowed
substitutions S(x, i) at index i based on the original x. We define
S(x, i) as the set of x_i ∈ N(x_i) such that where probabilities are
assigned by a pre-trained language model, and the window radius W and
threshold δ are hyperparameters. We use W = 6 and δ = 5.

Returns an ``Attack`` implementing the faster genetic-algorithm recipe.
"""
# # @TODO update all this stuff
# Swap words with their embedding nearest-neighbors.
#
# Embedding: Counter-fitted Paragram Embeddings.
#
# "[We] fix the hyperparameter values to S = 60, N = 8, K = 4, and δ = 0.5"
#
transformation = WordSwapEmbedding(max_candidates=8)
#
# Don't modify the same word twice or stopwords
#
constraints = [RepeatModification(), StopwordModification()]
#
# Maximum words perturbed percentage of 20%
#
constraints.append(MaxWordsPerturbed(max_percent=0.2))
#
# Maximum word embedding euclidean distance of 0.5.
#
constraints.append(WordEmbeddingDistance(max_mse_dist=0.5))
#
# Language Model constraint: substitutions may not lower the windowed
# log-probability under the LearningToWrite LM by more than δ = 5
# (window radius W = 6), computed against the *original* text — this is
# the paper's pre-computed, drift-free substitution set.
#
constraints.append(
LearningToWriteLanguageModel(
window_size=6, max_log_prob_diff=5.0, compare_against_original=True
)
)
# constraints.append(LearningToWriteLanguageModel(window_size=5))
#
# Goal is untargeted classification
#
goal_function = UntargetedClassification(model)
#
# Perform word substitution with a genetic algorithm.
#
search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -3,6 +3,7 @@ from textattack.constraints.grammaticality.language_models import (
)
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
InputColumnModification,
RepeatModification,
StopwordModification,
)
@@ -13,7 +14,7 @@ from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding
def Alzantot2018(model):
def GeneticAlgorithmAlzantot2018(model):
"""
Alzantot, M., Sharma, Y., Elgohary, A., Ho, B., Srivastava, M.B., & Chang, K. (2018).
@@ -34,6 +35,14 @@ def Alzantot2018(model):
#
constraints = [RepeatModification(), StopwordModification()]
#
# During entailment, we should only edit the hypothesis - keep the premise
# the same.
#
input_column_modification = InputColumnModification(
["premise", "hypothesis"], {"premise"}
)
constraints.append(input_column_modification)
#
# Maximum words perturbed percentage of 20%
#
constraints.append(MaxWordsPerturbed(max_percent=0.2))
@@ -52,6 +61,6 @@ def Alzantot2018(model):
#
# Perform word substitution with a genetic algorithm.
#
search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -0,0 +1,40 @@
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import InputReduction
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordDeletion
def InputReductionFeng2018(model):
    """Input-reduction attack recipe.

    Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018).
    Pathologies of Neural Models Make Interpretations Difficult.
    ArXiv, abs/1804.07781.

    At each step the word with the lowest importance value is removed,
    for as long as the model keeps its original prediction.

    Returns an ``Attack`` that deletes words rather than swapping them.
    """
    # Words are deleted outright — no replacement candidates are generated.
    transformation = WordDeletion()
    # Never delete from the same position twice; leave stopwords alone.
    constraints = [RepeatModification(), StopwordModification()]
    # Goal: InputReduction (maximizable) — keep the prediction unchanged
    # while removing as many words as possible. (Not an untargeted
    # misclassification goal.)
    goal_function = InputReduction(model, maximizable=True)
    # "For each word in an input sentence, we measure its importance by the
    # change in the confidence of the original prediction when we remove
    # that word from the sentence." The supposedly *unimportant* words are
    # the ones removed, hence the deletion-based word-importance ranking.
    search_method = GreedyWordSwapWIR(wir_method="delete")
    return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -1,30 +1,42 @@
from textattack.transformations import (
WordSwapNeighboringCharacterSwap,
WordSwapRandomCharacterDeletion,
WordSwapRandomCharacterInsertion,
WordSwapQWERTY,
CompositeTransformation,
)
from textattack.constraints.pre_transformation import StopwordModification, MinWordLength, RepeatModification
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.search_methods import GreedySearch
from textattack.constraints.pre_transformation import (
MinWordLength,
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedySearch
from textattack.shared.attack import Attack
from textattack.transformations import (
CompositeTransformation,
WordSwapNeighboringCharacterSwap,
WordSwapQWERTY,
WordSwapRandomCharacterDeletion,
WordSwapRandomCharacterInsertion,
)
def Pruthi2019(model, max_num_word_swaps=1):
# NOTE(review): this block contains *two* consecutive assignments to
# `transformation` and to `constraints` — it looks like an unmarked diff
# artifact (the pre-format and post-format versions of the same code
# concatenated). At runtime the later assignments win, so behavior matches
# the reformatted version; confirm against the repository before cleaning up.
transformation = CompositeTransformation([
WordSwapNeighboringCharacterSwap(random_one=False, skip_first_char=True, skip_last_char=True),
WordSwapRandomCharacterDeletion(random_one=False, skip_first_char=True, skip_last_char=True),
WordSwapRandomCharacterInsertion(random_one=False, skip_first_char=True, skip_last_char=True),
WordSwapQWERTY(random_one=False, skip_first_char=True, skip_last_char=True),
])
constraints = [MinWordLength(min_length=4), StopwordModification(), MaxWordsPerturbed(max_num_words=max_num_word_swaps), RepeatModification()]
# Character-level perturbations modeling keyboard typos: adjacent-character
# swaps, deletions, insertions, and QWERTY-neighbor substitutions; the
# first and last characters of each word are never touched.
transformation = CompositeTransformation(
[
WordSwapNeighboringCharacterSwap(
random_one=False, skip_first_char=True, skip_last_char=True
),
WordSwapRandomCharacterDeletion(
random_one=False, skip_first_char=True, skip_last_char=True
),
WordSwapRandomCharacterInsertion(
random_one=False, skip_first_char=True, skip_last_char=True
),
WordSwapQWERTY(random_one=False, skip_first_char=True, skip_last_char=True),
]
)
# Only perturb words of length >= 4, skip stopwords, cap the number of
# perturbed words at `max_num_word_swaps`, and never modify a word twice.
constraints = [
MinWordLength(min_length=4),
StopwordModification(),
MaxWordsPerturbed(max_num_words=max_num_word_swaps),
RepeatModification(),
]
# Untargeted misclassification goal with a plain greedy search.
goal_function = UntargetedClassification(model)
search_method = GreedySearch()
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -26,5 +26,5 @@ def PWWSRen2019(model):
constraints = [RepeatModification(), StopwordModification()]
goal_function = UntargetedClassification(model)
# search over words based on a combination of their saliency score, and how efficient the WordSwap transform is
search_method = GreedyWordSwapWIR("pwws", ascending=False)
search_method = GreedyWordSwapWIR("pwws")
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -27,7 +27,7 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
# Goal is non-overlapping output.
#
goal_function = NonOverlappingOutput(model)
# @TODO implement transformation / search method just like they do in
# TODO implement transformation / search method just like they do in
# seq2sick.
transformation = WordSwapEmbedding(max_candidates=50)
#
@@ -42,6 +42,6 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
#
# Greedily swap words with "Word Importance Ranking".
#
search_method = GreedyWordSwapWIR()
search_method = GreedyWordSwapWIR(wir_method="unk")
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -1,5 +1,6 @@
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
InputColumnModification,
RepeatModification,
StopwordModification,
)
@@ -28,278 +29,20 @@ def TextFoolerJin2019(model):
# Don't modify the same word twice or the stopwords defined
# in the TextFooler public implementation.
#
# fmt: off
stopwords = set(
[
"a",
"about",
"above",
"across",
"after",
"afterwards",
"again",
"against",
"ain",
"all",
"almost",
"alone",
"along",
"already",
"also",
"although",
"am",
"among",
"amongst",
"an",
"and",
"another",
"any",
"anyhow",
"anyone",
"anything",
"anyway",
"anywhere",
"are",
"aren",
"aren't",
"around",
"as",
"at",
"back",
"been",
"before",
"beforehand",
"behind",
"being",
"below",
"beside",
"besides",
"between",
"beyond",
"both",
"but",
"by",
"can",
"cannot",
"could",
"couldn",
"couldn't",
"d",
"didn",
"didn't",
"doesn",
"doesn't",
"don",
"don't",
"down",
"due",
"during",
"either",
"else",
"elsewhere",
"empty",
"enough",
"even",
"ever",
"everyone",
"everything",
"everywhere",
"except",
"first",
"for",
"former",
"formerly",
"from",
"hadn",
"hadn't",
"hasn",
"hasn't",
"haven",
"haven't",
"he",
"hence",
"her",
"here",
"hereafter",
"hereby",
"herein",
"hereupon",
"hers",
"herself",
"him",
"himself",
"his",
"how",
"however",
"hundred",
"i",
"if",
"in",
"indeed",
"into",
"is",
"isn",
"isn't",
"it",
"it's",
"its",
"itself",
"just",
"latter",
"latterly",
"least",
"ll",
"may",
"me",
"meanwhile",
"mightn",
"mightn't",
"mine",
"more",
"moreover",
"most",
"mostly",
"must",
"mustn",
"mustn't",
"my",
"myself",
"namely",
"needn",
"needn't",
"neither",
"never",
"nevertheless",
"next",
"no",
"nobody",
"none",
"noone",
"nor",
"not",
"nothing",
"now",
"nowhere",
"o",
"of",
"off",
"on",
"once",
"one",
"only",
"onto",
"or",
"other",
"others",
"otherwise",
"our",
"ours",
"ourselves",
"out",
"over",
"per",
"please",
"s",
"same",
"shan",
"shan't",
"she",
"she's",
"should've",
"shouldn",
"shouldn't",
"somehow",
"something",
"sometime",
"somewhere",
"such",
"t",
"than",
"that",
"that'll",
"the",
"their",
"theirs",
"them",
"themselves",
"then",
"thence",
"there",
"thereafter",
"thereby",
"therefore",
"therein",
"thereupon",
"these",
"they",
"this",
"those",
"through",
"throughout",
"thru",
"thus",
"to",
"too",
"toward",
"towards",
"under",
"unless",
"until",
"up",
"upon",
"used",
"ve",
"was",
"wasn",
"wasn't",
"we",
"were",
"weren",
"weren't",
"what",
"whatever",
"when",
"whence",
"whenever",
"where",
"whereafter",
"whereas",
"whereby",
"wherein",
"whereupon",
"wherever",
"whether",
"which",
"while",
"whither",
"who",
"whoever",
"whole",
"whom",
"whose",
"why",
"with",
"within",
"without",
"won",
"won't",
"would",
"wouldn",
"wouldn't",
"y",
"yet",
"you",
"you'd",
"you'll",
"you're",
"you've",
"your",
"yours",
"yourself",
"yourselves",
]
["a", "about", "above", "across", "after", "afterwards", "again", "against", "ain", "all", "almost", "alone", "along", "already", "also", "although", "am", "among", "amongst", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", "aren", "aren't", "around", "as", "at", "back", "been", "before", "beforehand", "behind", "being", "below", "beside", "besides", "between", "beyond", "both", "but", "by", "can", "cannot", "could", "couldn", "couldn't", "d", "didn", "didn't", "doesn", "doesn't", "don", "don't", "down", "due", "during", "either", "else", "elsewhere", "empty", "enough", "even", "ever", "everyone", "everything", "everywhere", "except", "first", "for", "former", "formerly", "from", "hadn", "hadn't", "hasn", "hasn't", "haven", "haven't", "he", "hence", "her", "here", "hereafter", "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", "how", "however", "hundred", "i", "if", "in", "indeed", "into", "is", "isn", "isn't", "it", "it's", "its", "itself", "just", "latter", "latterly", "least", "ll", "may", "me", "meanwhile", "mightn", "mightn't", "mine", "more", "moreover", "most", "mostly", "must", "mustn", "mustn't", "my", "myself", "namely", "needn", "needn't", "neither", "never", "nevertheless", "next", "no", "nobody", "none", "noone", "nor", "not", "nothing", "now", "nowhere", "o", "of", "off", "on", "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", "ours", "ourselves", "out", "over", "per", "please", "s", "same", "shan", "shan't", "she", "she's", "should've", "shouldn", "shouldn't", "somehow", "something", "sometime", "somewhere", "such", "t", "than", "that", "that'll", "the", "their", "theirs", "them", "themselves", "then", "thence", "there", "thereafter", "thereby", "therefore", "therein", "thereupon", "these", "they", "this", "those", "through", "throughout", "thru", "thus", "to", "too", "toward", "towards", "under", "unless", "until", "up", "upon", "used", "ve", "was", 
"wasn", "wasn't", "we", "were", "weren", "weren't", "what", "whatever", "when", "whence", "whenever", "where", "whereafter", "whereas", "whereby", "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", "who", "whoever", "whole", "whom", "whose", "why", "with", "within", "without", "won", "won't", "would", "wouldn", "wouldn't", "y", "yet", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves"]
)
# fmt: on
constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]
#
# During entailment, we should only edit the hypothesis - keep the premise
# the same.
#
input_column_modification = InputColumnModification(
["premise", "hypothesis"], {"premise"}
)
constraints.append(input_column_modification)
# Minimum word embedding cosine similarity of 0.5.
# (The paper claims 0.7, but analysis of the released code and some empirical
# results show that it's 0.5.)

View File

@@ -1,3 +1,4 @@
from .maximized_attack_result import MaximizedAttackResult
from .failed_attack_result import FailedAttackResult
from .skipped_attack_result import SkippedAttackResult
from .successful_attack_result import SuccessfulAttackResult

View File

@@ -13,11 +13,11 @@ class AttackResult:
perturbed text. May or may not have been successful.
"""
def __init__(self, original_result, perturbed_result, num_queries=0):
def __init__(self, original_result, perturbed_result):
if original_result is None:
raise ValueError("Attack original result cannot be None")
elif not isinstance(original_result, GoalFunctionResult):
raise TypeError(f"Invalid original goal function result: {original_text}")
raise TypeError(f"Invalid original goal function result: {original_result}")
if perturbed_result is None:
raise ValueError("Attack perturbed result cannot be None")
elif not isinstance(perturbed_result, GoalFunctionResult):
@@ -27,7 +27,7 @@ class AttackResult:
self.original_result = original_result
self.perturbed_result = perturbed_result
self.num_queries = num_queries
self.num_queries = perturbed_result.num_queries
# We don't want the AttackedText attributes sticking around clogging up
# space on our devices. Delete them here, if they're still present,
@@ -89,27 +89,34 @@ class AttackResult:
i1 = 0
i2 = 0
while i1 < len(t1.words) and i2 < len(t2.words):
while i1 < t1.num_words or i2 < t2.num_words:
# show deletions
while t2.attack_attrs["original_index_map"][i1] == -1:
while (
i1 < len(t2.attack_attrs["original_index_map"])
and t2.attack_attrs["original_index_map"][i1] == -1
):
words_1.append(utils.color_text(t1.words[i1], color_1, color_method))
words_1_idxs.append(i1)
i1 += 1
# show insertions
while i2 < t2.attack_attrs["original_index_map"][i1]:
while (
i1 < len(t2.attack_attrs["original_index_map"])
and i2 < t2.attack_attrs["original_index_map"][i1]
):
words_2.append(utils.color_text(t1.words[i2], color_2, color_method))
words_2_idxs.append(i2)
i2 += 1
# show swaps
word_1 = t1.words[i1]
word_2 = t2.words[i2]
if word_1 != word_2:
words_1.append(utils.color_text(word_1, color_1, color_method))
words_2.append(utils.color_text(word_2, color_2, color_method))
words_1_idxs.append(i1)
words_2_idxs.append(i2)
i1 += 1
i2 += 1
if i1 < t1.num_words and i2 < t2.num_words:
word_1 = t1.words[i1]
word_2 = t2.words[i2]
if word_1 != word_2:
words_1.append(utils.color_text(word_1, color_1, color_method))
words_2.append(utils.color_text(word_2, color_2, color_method))
words_1_idxs.append(i1)
words_2_idxs.append(i2)
i1 += 1
i2 += 1
t1 = self.original_result.attacked_text.replace_words_at_indices(
words_1_idxs, words_1

View File

@@ -6,9 +6,9 @@ from .attack_result import AttackResult
class FailedAttackResult(AttackResult):
"""The result of a failed attack."""
def __init__(self, original_result, perturbed_result=None, num_queries=0):
def __init__(self, original_result, perturbed_result=None):
perturbed_result = perturbed_result or original_result
super().__init__(original_result, perturbed_result, num_queries)
super().__init__(original_result, perturbed_result)
def str_lines(self, color_method=None):
lines = (

View File

@@ -0,0 +1,5 @@
from .attack_result import AttackResult
class MaximizedAttackResult(AttackResult):
""" The result of a successful attack. """

View File

@@ -122,13 +122,12 @@ class CharSwapAugmenter(Augmenter):
""" Augments words by swapping characters out for other characters. """
def __init__(self, **kwargs):
from textattack.transformations import CompositeTransformation
from textattack.transformations import (
CompositeTransformation,
WordSwapNeighboringCharacterSwap,
WordSwapRandomCharacterDeletion,
WordSwapRandomCharacterInsertion,
WordSwapRandomCharacterSubstitution,
WordSwapNeighboringCharacterSwap,
)
transformation = CompositeTransformation(

View File

@@ -1,2 +1,5 @@
from .attack_command import AttackCommand
from .attack_resume_command import AttackResumeCommand
from .run_attack_single_threaded import run as run_attack_single_threaded
from .run_attack_parallel import run as run_attack_parallel

View File

@@ -1,15 +1,20 @@
import textattack
ATTACK_RECIPE_NAMES = {
"alzantot": "textattack.attack_recipes.Alzantot2018",
"alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
"bae": "textattack.attack_recipes.BAEGarg2019",
"bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
"faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
"deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
"hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
"input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
"kuleshov": "textattack.attack_recipes.Kuleshov2017",
"seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
"textbugger": "textattack.attack_recipes.TextBuggerLi2018",
"textfooler": "textattack.attack_recipes.TextFoolerJin2019",
"pwws": "textattack.attack_recipes.PWWSRen2019",
"pruthi": "textattack.attack_recipes.Pruthi2019",
"pso": "textattack.attack_recipes.PSOZang2020",
}
#
@@ -55,6 +60,14 @@ HUGGINGFACE_DATASET_BY_MODEL = {
"textattack/bert-base-uncased-WNLI",
("glue", "wnli", "validation"),
),
"bert-base-uncased-mr": (
"textattack/bert-base-uncased-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
"bert-base-uncased-snli": (
"textattack/bert-base-uncased-snli",
("snli", None, "test", [1, 2, 0]),
),
#
# distilbert-base-cased
#
@@ -141,6 +154,24 @@ HUGGINGFACE_DATASET_BY_MODEL = {
"textattack/roberta-base-WNLI",
("glue", "wnli", "validation"),
),
"roberta-base-mr": (
"textattack/roberta-base-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
#
# albert-base-v2 (ALBERT is cased by default)
#
"albert-base-v2-mr": (
"textattack/albert-base-v2-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
#
# xlnet-base-cased
#
"xlnet-base-cased-mr": (
"textattack/xlnet-base-cased-rotten-tomatoes",
("rotten_tomatoes", None, "test"),
),
}
@@ -172,10 +203,6 @@ TEXTATTACK_DATASET_BY_MODEL = {
#
# Text classification models
#
"bert-base-uncased-mr": (
("models/classification/bert/mr-uncased", 2),
("rotten_tomatoes", None, "train"),
),
"bert-base-cased-imdb": (
("models/classification/bert/imdb-cased", 2),
("imdb", None, "test"),
@@ -194,11 +221,22 @@ TEXTATTACK_DATASET_BY_MODEL = {
),
#
# Translation models
# TODO add proper `nlp` datasets for translation & summarization
"t5-en-de": (
"english_to_german",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
),
"t5-en-fr": (
"english_to_french",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "fr"),
),
"t5-en-ro": (
"english_to_romanian",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
),
#
# Summarization models
#
#'t5-summ': 'textattack.models.summarization.T5Summarization',
"t5-summarization": ("summarization", ("gigaword", None, "test")),
}
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
@@ -233,6 +271,7 @@ CONSTRAINT_CLASS_NAMES = {
"part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
"goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
"gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
"learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
#
# Overlap constraints
#

View File

@@ -1,6 +1,7 @@
import argparse
import copy
import importlib
import json
import os
import pickle
import random
@@ -30,7 +31,7 @@ def add_model_args(parser):
type=str,
required=False,
default=None,
help='The pre-trained model to attack. Usage: "--model {model}:{arg_1}={value_1},{arg_3}={value_3},...". Choices: '
help='Name of or path to a pre-trained model to attack. Usage: "--model {model}:{arg_1}={value_1},{arg_3}={value_3},...". Choices: '
+ str(model_names),
)
model_group.add_argument(
@@ -153,6 +154,8 @@ def parse_goal_function_from_args(args, model):
else:
raise ValueError(f"Error: unsupported goal_function {goal_function}")
goal_function.query_budget = args.query_budget
goal_function.model_batch_size = args.model_batch_size
goal_function.model_cache_size = args.model_cache_size
return goal_function
@@ -191,6 +194,9 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Invalid recipe {args.recipe}")
recipe.goal_function.query_budget = args.query_budget
recipe.goal_function.model_batch_size = args.model_batch_size
recipe.goal_function.model_cache_size = args.model_cache_size
recipe.constraint_cache_size = args.constraint_cache_size
return recipe
elif args.attack_from_file:
if ":" in args.attack_from_file:
@@ -218,7 +224,11 @@ def parse_attack_from_args(args):
else:
raise ValueError(f"Error: unsupported attack {args.search}")
return textattack.shared.Attack(
goal_function, constraints, transformation, search_method
goal_function,
constraints,
transformation,
search_method,
constraint_cache_size=args.constraint_cache_size,
)
@@ -295,6 +305,22 @@ def parse_model_from_args(args):
model = textattack.shared.utils.load_textattack_model_from_path(
args.model, model_path
)
elif args.model and os.path.exists(args.model):
# If `args.model` is a path/directory, let's assume it was a model
# trained with textattack, and try and load it.
model_args_json_path = os.path.join(args.model, "train_args.json")
if not os.path.exists(model_args_json_path):
raise FileNotFoundError(
f"Tried to load model from path {args.model} - could not find train_args.json."
)
model_train_args = json.loads(open(model_args_json_path).read())
model_train_args["model"] = args.model
num_labels = model_train_args["num_labels"]
from textattack.commands.train_model.train_args_helpers import (
model_from_args,
)
model = model_from_args(argparse.Namespace(**model_train_args), num_labels)
else:
raise ValueError(f"Error: unsupported TextAttack model {args.model}")
return model
@@ -306,7 +332,32 @@ def parse_dataset_from_args(args):
if args.model in HUGGINGFACE_DATASET_BY_MODEL:
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
_, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model]
if dataset[0].startswith("textattack"):
# unsavory way to pass custom dataset classes
# ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de')
dataset = eval(f"{dataset[0]}")(*dataset[1:])
return dataset
else:
args.dataset_from_nlp = dataset
# Automatically detect dataset for models trained with textattack.
elif args.model and os.path.exists(args.model):
model_args_json_path = os.path.join(args.model, "train_args.json")
if not os.path.exists(model_args_json_path):
raise FileNotFoundError(
f"Tried to load model from path {args.model} - could not find train_args.json."
)
model_train_args = json.loads(open(model_args_json_path).read())
try:
args.dataset_from_nlp = (
model_train_args["dataset"],
None,
model_train_args["dataset_dev_split"],
)
except KeyError:
raise KeyError(
f"Tried to load model from path {args.model} but can't initialize dataset from train_args.json."
)
# Get dataset from args.
if args.dataset_from_file:
@@ -331,8 +382,11 @@ def parse_dataset_from_args(args):
)
elif args.dataset_from_nlp:
dataset_args = args.dataset_from_nlp
if ":" in dataset_args:
dataset_args = dataset_args.split(":")
if isinstance(dataset_args, str):
if ":" in dataset_args:
dataset_args = dataset_args.split(":")
else:
dataset_args = (dataset_args,)
dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, shuffle=args.shuffle
)
@@ -350,8 +404,10 @@ def parse_logger_from_args(args):
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(
current_dir, os.pardir, os.pardir, os.pardir, "outputs"
current_dir, os.pardir, os.pardir, os.pardir, "outputs", "attacks"
)
if not os.path.exists(outputs_dir):
os.makedirs(outputs_dir)
args.out_dir = os.path.normpath(outputs_dir)
# if "--log-to-file" specified in terminal command, then save it to a txt file

View File

@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
parser.add_argument(
"--model-batch-size",
type=int,
default=32,
help="The batch size for making calls to the model.",
)
parser.add_argument(
"--model-cache-size",
type=int,
default=2 ** 18,
help="The maximum number of items to keep in the model results cache at once.",
)
parser.add_argument(
"--constraint-cache-size",
type=int,
default=2 ** 18,
help="The maximum number of items to keep in the constraints cache at once.",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())

View File

@@ -20,15 +20,15 @@ def set_env_variables(gpu_id):
# Set sharing strategy to file_system to avoid file descriptor leaks
torch.multiprocessing.set_sharing_strategy("file_system")
# Only use one GPU, if we have one.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# For Tensorflow
# TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# For PyTorch
torch.cuda.set_device(gpu_id)
# Disable tensorflow logs, except in the case of an error.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Cache TensorFlow Hub models here, if not otherwise specified.
if "TFHUB_CACHE_DIR" not in os.environ:
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
def attack_from_queue(args, in_queue, out_queue):
@@ -125,7 +125,10 @@ def run(args, checkpoint=None):
pbar.update()
num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult:
if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1
@@ -170,6 +173,8 @@ def run(args, checkpoint=None):
finish_time = time.time()
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
return attack_log_manager.results
def pytorch_multiprocessing_workaround():
# This is a fix for a known bug

View File

@@ -18,14 +18,12 @@ logger = textattack.shared.logger
def run(args, checkpoint=None):
# Only use one GPU, if we have one.
# TODO: Running Universal Sentence Encoder uses multiple GPUs
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Disable tensorflow logs, except in the case of an error.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Cache TensorFlow Hub models here, if not otherwise specified.
if "TFHUB_CACHE_DIR" not in os.environ:
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
if args.checkpoint_resume:
num_remaining_attacks = checkpoint.num_remaining_attacks
@@ -110,7 +108,10 @@ def run(args, checkpoint=None):
num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult:
if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1
@@ -141,6 +142,8 @@ def run(args, checkpoint=None):
finish_time = time.time()
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
return attack_log_manager.results
if __name__ == "__main__":
run(get_args())

View File

@@ -0,0 +1 @@
from .eval_model_command import EvalModelCommand

View File

@@ -12,11 +12,11 @@ def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
class BenchmarkModelCommand(TextAttackCommand):
class EvalModelCommand(TextAttackCommand):
"""
The TextAttack model benchmarking module:
A command line parser to benchmark a model from user specifications.
A command line parser to evaluatate a model from user specifications.
"""
def get_num_successes(self, model, ids, true_labels):
@@ -83,7 +83,7 @@ class BenchmarkModelCommand(TextAttackCommand):
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"benchmark-model",
"eval",
help="evaluate a model with TextAttack",
formatter_class=ArgumentDefaultsHelpFormatter,
)
@@ -97,4 +97,4 @@ class BenchmarkModelCommand(TextAttackCommand):
default=256,
help="Batch size for model inference.",
)
parser.set_defaults(func=BenchmarkModelCommand())
parser.set_defaults(func=EvalModelCommand())

View File

@@ -0,0 +1,82 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import collections
import re
import numpy as np
import textattack
from textattack.commands import TextAttackCommand
from textattack.commands.attack.attack_args_helpers import (
add_dataset_args,
parse_dataset_from_args,
)
def _cb(s):
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
logger = textattack.shared.logger
class PeekDatasetCommand(TextAttackCommand):
"""
The peek dataset module:
Takes a peek into a dataset in textattack.
"""
def run(self, args):
UPPERCASE_LETTERS_REGEX = re.compile("[A-Z]")
args.model = None # set model to None for parse_dataset_from_args to work
dataset = parse_dataset_from_args(args)
num_words = []
attacked_texts = []
data_all_lowercased = True
outputs = []
for inputs, output in dataset:
at = textattack.shared.AttackedText(inputs)
if data_all_lowercased:
# Test if any of the letters in the string are lowercase.
if re.search(UPPERCASE_LETTERS_REGEX, at.text):
data_all_lowercased = False
attacked_texts.append(at)
num_words.append(len(at.words))
outputs.append(output)
logger.info(f"Number of samples: {_cb(len(attacked_texts))}")
logger.info(f"Number of words per input:")
num_words = np.array(num_words)
logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
mean_words = f"{num_words.mean():.2f}"
logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
std_words = f"{num_words.std():.2f}"
logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
logger.info(f"Dataset lowercased: {_cb(data_all_lowercased)}")
logger.info("First sample:")
print(attacked_texts[0].printable_text(), "\n")
logger.info("Last sample:")
print(attacked_texts[-1].printable_text(), "\n")
logger.info(f"Found {len(set(outputs))} distinct outputs.")
if len(outputs) < 20:
print(sorted(set(outputs)))
logger.info(f"Most common outputs:")
for i, (key, value) in enumerate(collections.Counter(outputs).most_common(20)):
print("\t", str(key)[:5].ljust(5), f" ({value})")
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
"peek-dataset",
help="show main statistics about a dataset",
formatter_class=ArgumentDefaultsHelpFormatter,
)
add_dataset_args(parser)
parser.set_defaults(func=PeekDatasetCommand())

View File

@@ -5,9 +5,10 @@ import sys
from textattack.commands.attack import AttackCommand, AttackResumeCommand
from textattack.commands.augment import AugmentCommand
from textattack.commands.benchmark_model import BenchmarkModelCommand
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
from textattack.commands.eval_model import EvalModelCommand
from textattack.commands.list_things import ListThingsCommand
from textattack.commands.peek_dataset import PeekDatasetCommand
from textattack.commands.train_model import TrainModelCommand
@@ -23,10 +24,11 @@ def main():
AttackCommand.register_subcommand(subparsers)
AttackResumeCommand.register_subcommand(subparsers)
AugmentCommand.register_subcommand(subparsers)
BenchmarkModelCommand.register_subcommand(subparsers)
BenchmarkRecipeCommand.register_subcommand(subparsers)
EvalModelCommand.register_subcommand(subparsers)
ListThingsCommand.register_subcommand(subparsers)
TrainModelCommand.register_subcommand(subparsers)
PeekDatasetCommand.register_subcommand(subparsers)
# Let's go
args = parser.parse_args()

View File

@@ -1,18 +0,0 @@
from argparse import ArgumentParser
from textattack.commands import TextAttackCommand
class TrainModelCommand(TextAttackCommand):
"""
The TextAttack train module:
A command line parser to train a model from user specifications.
"""
def run(self, args):
raise NotImplementedError("Cannot train models yet - stay tuned!!")
@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser("train", help="train a model")

View File

@@ -0,0 +1 @@
from .train_model_command import TrainModelCommand

View File

@@ -0,0 +1,382 @@
import json
import logging
import os
import time
import numpy as np
import scipy
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import tqdm
import transformers
import textattack
from .train_args_helpers import dataset_from_args, model_from_args, write_readme
device = textattack.shared.utils.device
logger = textattack.shared.logger
def make_directories(output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
def batch_encode(tokenizer, text_list):
if hasattr(tokenizer, "batch_encode"):
return tokenizer.batch_encode(text_list)
else:
return [tokenizer.encode(text_input) for text_input in text_list]
def train_model(args):
logger.warn(
"WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues."
)
start_time = time.time()
make_directories(args.output_dir)
num_gpus = torch.cuda.device_count()
# Save logger writes to file
log_txt_path = os.path.join(args.output_dir, "log.txt")
fh = logging.FileHandler(log_txt_path)
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.info(f"Writing logs to {log_txt_path}.")
# Use Weights & Biases, if enabled.
if args.enable_wandb:
global wandb
import wandb
wandb.init(sync_tensorboard=True)
# Get list of text and list of label (integers) from disk.
train_text, train_labels, eval_text, eval_labels = dataset_from_args(args)
# Filter labels
if args.allowed_labels:
logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.")
final_train_text, final_train_labels = [], []
for text, label in zip(train_text, train_labels):
if label in args.allowed_labels:
final_train_text.append(text)
final_train_labels.append(label)
logger.info(
f"Filtered {len(train_text)} train samples to {len(final_train_text)} points."
)
train_text, train_labels = final_train_text, final_train_labels
final_eval_text, final_eval_labels = [], []
for text, label in zip(eval_text, eval_labels):
if label in args.allowed_labels:
final_eval_text.append(text)
final_eval_labels.append(label)
logger.info(
f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points."
)
eval_text, eval_labels = final_eval_text, final_eval_labels
label_id_len = len(train_labels)
label_set = set(train_labels)
args.num_labels = len(label_set)
logger.info(
f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})"
)
if isinstance(train_labels[0], float):
# TODO come up with a more sophisticated scheme for when to do regression
logger.warn(f"Detected float labels. Doing regression.")
args.num_labels = 1
args.do_regression = True
else:
args.do_regression = False
train_examples_len = len(train_text)
if len(train_labels) != train_examples_len:
raise ValueError(
f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})"
)
if len(eval_labels) != len(eval_text):
raise ValueError(
f"Number of teste xamples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
)
model = model_from_args(args, args.num_labels)
tokenizer = model.tokenizer
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
train_text_ids = batch_encode(tokenizer, train_text)
logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
eval_text_ids = batch_encode(tokenizer, eval_text)
load_time = time.time()
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
# multi-gpu training
if num_gpus > 1:
model = torch.nn.DataParallel(model)
logger.info(f"Training model across {num_gpus} GPUs")
num_train_optimization_steps = (
int(train_examples_len / args.batch_size / args.grad_accum_steps)
* args.num_train_epochs
)
if args.model == "lstm" or args.model == "cnn":
need_grad = lambda x: x.requires_grad
optimizer = torch.optim.Adam(
filter(need_grad, model.parameters()), lr=args.learning_rate
)
scheduler = None
else:
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
],
"weight_decay": 0.01,
},
{
"params": [
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = transformers.optimization.AdamW(
optimizer_grouped_parameters, lr=args.learning_rate
)
scheduler = transformers.optimization.get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_proportion,
num_training_steps=num_train_optimization_steps,
)
# Start Tensorboard and log hyperparams.
from tensorboardX import SummaryWriter
tb_writer = SummaryWriter(args.output_dir)
def is_writable_type(obj):
for ok_type in [bool, int, str, float]:
if isinstance(obj, ok_type):
return True
return False
args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
tb_writer.add_hparams(args_dict, {})
# Start training
logger.info("***** Running training *****")
logger.info(f"\tNum examples = {train_examples_len}")
logger.info(f"\tBatch size = {args.batch_size}")
logger.info(f"\tMax sequence length = {args.max_length}")
logger.info(f"\tNum steps = {num_train_optimization_steps}")
logger.info(f"\tNum epochs = {args.num_train_epochs}")
logger.info(f"\tLearning rate = {args.learning_rate}")
train_input_ids = np.array(train_text_ids)
train_labels = np.array(train_labels)
train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels))
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
train_data, sampler=train_sampler, batch_size=args.batch_size
)
eval_input_ids = np.array(eval_text_ids)
eval_labels = np.array(eval_labels)
eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels))
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(
eval_data, sampler=eval_sampler, batch_size=args.batch_size
)
def get_eval_score():
    """Evaluate `model` on `eval_dataloader` and return a single score.

    Returns the Pearson correlation between predictions and labels when
    `args.do_regression` is set, otherwise classification accuracy.
    Switches the model to eval mode for the pass and restores train mode
    before returning. (Removed dead locals `correct = 0` / `total = 0`:
    `total` was never used and `correct` was unconditionally reassigned.)
    """
    model.eval()
    logits = []
    labels = []
    for input_ids, batch_labels in eval_dataloader:
        if isinstance(input_ids, dict):
            ## HACK: dataloader collates dict backwards. This is a temporary
            # workaround to get ids in the right shape
            input_ids = {
                k: torch.stack(v).T.to(device) for k, v in input_ids.items()
            }
        batch_labels = batch_labels.to(device)
        with torch.no_grad():
            batch_logits = textattack.shared.utils.model_predict(model, input_ids)
        logits.extend(batch_logits.cpu().squeeze().tolist())
        labels.extend(batch_labels)
    model.train()
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    if args.do_regression:
        pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
        return pearson_correlation
    else:
        preds = logits.argmax(dim=1)
        correct = (preds == labels).sum()
        return float(correct) / len(labels)
def save_model():
    """Persist the current model's weights (and config, if it has one) to
    `args.output_dir` under `args.weights_name` / `args.config_name` so it
    can be reloaded later with `from_pretrained`. Reads `model` and `args`
    from the enclosing training scope.
    """
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )  # Only save the model itself (unwrap DataParallel, if present)

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(args.output_dir, args.weights_name)
    output_config_file = os.path.join(args.output_dir, args.config_name)

    torch.save(model_to_save.state_dict(), output_model_file)
    try:
        model_to_save.config.to_json_file(output_config_file)
    except AttributeError:
        # no config (e.g. non-transformers LSTM/CNN models) — weights only
        pass
global_step = 0
tr_loss = 0
def save_model_checkpoint():
    """Save a `from_pretrained`-loadable checkpoint of the model, plus the
    training args, to `<output_dir>/checkpoint-<global_step>/`. Reads
    `global_step`, `model`, and `args` from the enclosing training scope.
    """
    # Save model checkpoint
    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Take care of distributed/parallel training
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info(f"Checkpoint saved to {output_dir}.")
model.train()
args.best_eval_score = 0
args.best_eval_score_epoch = 0
args.epochs_since_best_eval_score = 0
def loss_backward(loss):
    """Backpropagate `loss` after normalizing it: averaged across replicas
    when running multi-GPU data-parallel, and divided by
    `args.grad_accum_steps` so accumulated gradients sum to the true batch
    gradient. Returns the scaled loss (used for logging).
    """
    if num_gpus > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training
    if args.grad_accum_steps > 1:
        loss = loss / args.grad_accum_steps
    loss.backward()
    return loss
if args.do_regression:
# TODO integrate with textattack `metrics` package
loss_fct = torch.nn.MSELoss()
else:
loss_fct = torch.nn.CrossEntropyLoss()
for epoch in tqdm.trange(
int(args.num_train_epochs), desc="Epoch", position=0, leave=False
):
prog_bar = tqdm.tqdm(
train_dataloader, desc="Iteration", position=1, leave=False
)
for step, batch in enumerate(prog_bar):
input_ids, labels = batch
labels = labels.to(device)
if isinstance(input_ids, dict):
## HACK: dataloader collates dict backwards. This is a temporary
# workaround to get ids in the right shape
input_ids = {
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
}
logits = textattack.shared.utils.model_predict(model, input_ids)
if args.do_regression:
# TODO integrate with textattack `metrics` package
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
loss = loss_backward(loss)
tr_loss += loss.item()
if global_step % args.tb_writer_step == 0:
tb_writer.add_scalar("loss", loss.item(), global_step)
if scheduler is not None:
tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
else:
tb_writer.add_scalar("lr", args.learning_rate, global_step)
if global_step > 0:
prog_bar.set_description(f"Loss {tr_loss/global_step}")
if (step + 1) % args.grad_accum_steps == 0:
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
# Save model checkpoint to file.
if (
global_step > 0
and (args.checkpoint_steps > 0)
and (global_step % args.checkpoint_steps) == 0
):
save_model_checkpoint()
# Inc step counter.
global_step += 1
# Check accuracy after each epoch.
eval_score = get_eval_score()
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
if args.checkpoint_every_epoch:
save_model_checkpoint()
logger.info(
f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
)
if eval_score > args.best_eval_score:
args.best_eval_score = eval_score
args.best_eval_score_epoch = epoch
args.epochs_since_best_eval_score = 0
save_model()
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
else:
args.epochs_since_best_eval_score += 1
if (args.early_stopping_epochs > 0) and (
args.epochs_since_best_eval_score > args.early_stopping_epochs
):
logger.info(
f"Stopping early since it's been {args.early_stopping_epochs} steps since validation acc increased"
)
break
# read the saved model and report its eval performance
model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
eval_score = get_eval_score()
logger.info(
f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
)
# end of training, save tokenizer
try:
tokenizer.save_pretrained(args.output_dir)
logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
except AttributeError:
logger.warn(
f"Error: could not save tokenizer {tokenizer} to {args.output_dir}."
)
# Save a little readme with model info
write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
# Save args to file
args_save_path = os.path.join(args.output_dir, "train_args.json")
final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
with open(args_save_path, "w", encoding="utf-8") as f:
f.write(json.dumps(final_args_dict, indent=2) + "\n")
logger.info(f"Wrote training args to {args_save_path}.")

View File

@@ -0,0 +1,153 @@
import os
import textattack
logger = textattack.shared.logger
def prepare_dataset_for_training(nlp_dataset):
    """ Changes an `nlp` dataset into the proper format for tokenization.

        Returns a pair of lists `(text, outputs)`: each text is either a bare
        string (single-column examples) or a tuple of strings (multi-column
        inputs, e.g. premise/hypothesis), and `outputs` holds the matching
        labels in the same order.
    """

    def prepare_example_dict(ex):
        """ Returns the values in order corresponding to the data.

            ex:
                'Some text input'
            or in the case of multi-sequence inputs:
                ('The premise', 'the hypothesis',)
            etc.
        """
        values = list(ex.values())
        if len(values) == 1:
            return values[0]
        return tuple(values)

    # Iterate directly instead of `zip(*...)` unpacking: the old form raised
    # a confusing ValueError on an empty dataset; this yields ([], []).
    text = []
    outputs = []
    for example in nlp_dataset:
        text.append(prepare_example_dict(example[0]))
        outputs.append(example[1])
    return text, outputs
def dataset_from_args(args):
    """ Returns `(train_text, train_labels, eval_text, eval_labels)` for
        ``args.dataset``, loaded via ``HuggingFaceNLPDataset``.

        Honors ``args.dataset_train_split`` / ``args.dataset_dev_split`` when
        set; otherwise auto-detects "train" for training and the first of
        "dev", "eval", "validation", "test" for evaluation, recording the
        detected split names back onto ``args``.

        Raises:
            KeyError: if no train split or no recognizable dev split exists.
    """
    dataset_args = args.dataset.split(":")
    # TODO `HuggingFaceNLPDataset` -> `HuggingFaceDataset`
    if args.dataset_train_split:
        train_dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, split=args.dataset_train_split
        )
    else:
        try:
            train_dataset = textattack.datasets.HuggingFaceNLPDataset(
                *dataset_args, split="train"
            )
            args.dataset_train_split = "train"
        except KeyError:
            raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
    train_text, train_labels = prepare_dataset_for_training(train_dataset)

    if args.dataset_dev_split:
        eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, split=args.dataset_dev_split
        )
    else:
        # Try common dev split names in order (replaces the old four-deep
        # nested try/except ladder; same order, same error on failure).
        eval_dataset = None
        for split_name in ("dev", "eval", "validation", "test"):
            try:
                eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
                    *dataset_args, split=split_name
                )
                args.dataset_dev_split = split_name
                break
            except KeyError:
                continue
        if eval_dataset is None:
            raise KeyError(
                f"Could not find `dev`, `eval`, `validation`, or `test` split in dataset {args.dataset}."
            )
    eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)

    return train_text, train_labels, eval_text, eval_labels
def model_from_args(args, num_labels):
    """Build the model named by ``args.model`` and move it to the default
    TextAttack device.

    ``"lstm"`` and ``"cnn"`` load TextAttack's built-in classifiers; any
    other value is treated as a transformers model name/path and loaded
    with ``AutoModelForSequenceClassification``, with a fast AutoTokenizer
    attached as ``model.tokenizer``.

    Args:
        args: training args; reads ``model``, ``max_length``, ``dataset``.
        num_labels (int): number of output classes for the classifier head.
    """
    if args.model == "lstm":
        textattack.shared.logger.info("Loading textattack model: LSTMForClassification")
        model = textattack.models.helpers.LSTMForClassification(
            max_seq_length=args.max_length,
            num_labels=num_labels,
            emb_layer_trainable=False,
        )
    elif args.model == "cnn":
        textattack.shared.logger.info(
            "Loading textattack model: WordCNNForClassification"
        )
        model = textattack.models.helpers.WordCNNForClassification(
            max_seq_length=args.max_length,
            num_labels=num_labels,
            emb_layer_trainable=False,
        )
    else:
        # Deferred import: transformers is only needed for this branch.
        import transformers

        textattack.shared.logger.info(
            f"Loading transformers AutoModelForSequenceClassification: {args.model}"
        )
        config = transformers.AutoConfig.from_pretrained(
            args.model, num_labels=num_labels, finetuning_task=args.dataset
        )
        model = transformers.AutoModelForSequenceClassification.from_pretrained(
            args.model, config=config,
        )
        # NOTE(review): tokenizer attachment assumed to apply only to the
        # transformers branch (LSTM/CNN models carry their own tokenizers;
        # AutoTokenizer(args.model) would fail for "lstm"/"cnn") — confirm.
        tokenizer = textattack.models.tokenizers.AutoTokenizer(
            args.model, use_fast=True, max_length=args.max_length
        )
        setattr(model, "tokenizer", tokenizer)

    model = model.to(textattack.shared.utils.device)

    return model
def write_readme(args, best_eval_score, best_eval_score_epoch):
    """Write a README.md into ``args.output_dir`` summarizing this
    fine-tuning run (model, dataset, hyperparameters, best eval score).

    Args:
        args: training args namespace; reads ``output_dir``, ``model``,
            ``dataset``, ``do_regression``, ``num_train_epochs``,
            ``batch_size``, ``learning_rate``, ``max_length``.
        best_eval_score (float): best score achieved on the eval set.
        best_eval_score_epoch (int): epoch at which that score was found.
    """
    # Save args to file
    readme_save_path = os.path.join(args.output_dir, "README.md")
    dataset_name = args.dataset.split(":")[0] if ":" in args.dataset else args.dataset
    task_name = "regression" if args.do_regression else "classification"
    loss_func = "mean squared error" if args.do_regression else "cross-entropy"
    metric_name = "pearson correlation" if args.do_regression else "accuracy"
    # Pluralize correctly: "1 epoch" but "0 epochs" / "2 epochs".
    # (The old `> 1` check produced the ungrammatical "0 epoch".)
    epoch_info = f"{best_eval_score_epoch} epoch" + (
        "" if best_eval_score_epoch == 1 else "s"
    )
    readme_text = f"""
## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
This `{args.model}` model was fine-tuned for sequence classification using TextAttack
and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
The best score the model achieved on this task was {best_eval_score}, as measured by the
eval set {metric_name}, found after {epoch_info}.
For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
"""
    with open(readme_save_path, "w", encoding="utf-8") as f:
        f.write(readme_text.strip() + "\n")
    logger.info(f"Wrote README to {readme_save_path}.")

View File

@@ -0,0 +1,155 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import datetime
import os
from textattack.commands import TextAttackCommand
class TrainModelCommand(TextAttackCommand):
    """
    The TextAttack train module:

        A command line parser to train a model from user specifications.
    """

    def run(self, args):
        """Pick a dated output directory for this run, then launch training.

        Output directory layout:
        ``<repo_root>/outputs/training/<model>-<dataset>-<date>/``.
        """
        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
        current_dir = os.path.dirname(os.path.realpath(__file__))
        outputs_dir = os.path.join(
            current_dir, os.pardir, os.pardir, os.pardir, "outputs", "training"
        )
        outputs_dir = os.path.normpath(outputs_dir)

        args.output_dir = os.path.join(
            outputs_dir, f"{args.model}-{args.dataset}-{date_now}/"
        )

        # Deferred import keeps `textattack --help` fast: training deps are
        # only loaded when this subcommand actually runs.
        from .run_training import train_model

        train_model(args)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the `train` subcommand and all of its CLI options."""
        parser = main_parser.add_parser(
            "train",
            help="train a model for sequence classification",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--model", type=str, required=True, help="directory of model to train",
        )
        # NOTE: removed the old `default="yelp"` — a default on a
        # `required=True` argument can never take effect.
        parser.add_argument(
            "--dataset",
            type=str,
            required=True,
            help="dataset for training; will be loaded from "
            "`nlp` library. if dataset has a subset, separate with a colon. "
            " ex: `glue:sst2` or `rotten_tomatoes`",
        )
        parser.add_argument(
            "--dataset-train-split",
            "--train-split",
            type=str,
            default="",
            help="train dataset split, if non-standard "
            "(can automatically detect 'train')",
        )
        parser.add_argument(
            "--dataset-dev-split",
            "--dataset-eval-split",
            "--dev-split",
            type=str,
            default="",
            help="val dataset split, if non-standard "
            "(can automatically detect 'dev', 'validation', 'eval')",
        )
        parser.add_argument(
            "--tb-writer-step",
            type=int,
            default=1000,
            help="Number of steps before writing to tensorboard",
        )
        parser.add_argument(
            "--checkpoint-steps",
            type=int,
            default=-1,
            help="save model after this many steps (-1 for no checkpointing)",
        )
        # Kebab-case spelling is primary for consistency with every other
        # flag; the old mixed `--checkpoint-every_epoch` spelling is kept as
        # an alias for backwards compatibility. Both map to the same
        # `args.checkpoint_every_epoch` dest.
        parser.add_argument(
            "--checkpoint-every-epoch",
            "--checkpoint-every_epoch",
            action="store_true",
            default=False,
            help="save model checkpoint after each epoch",
        )
        parser.add_argument(
            "--num-train-epochs",
            "--epochs",
            type=int,
            default=100,
            help="Total number of epochs to train for",
        )
        parser.add_argument(
            "--allowed-labels",
            type=int,
            nargs="*",
            default=[],
            help="Labels allowed for training (examples with other labels will be discarded)",
        )
        parser.add_argument(
            "--early-stopping-epochs",
            type=int,
            default=-1,
            help="Number of epochs validation must increase"
            " before stopping early (-1 for no early stopping)",
        )
        parser.add_argument(
            "--batch-size", type=int, default=128, help="Batch size for training"
        )
        parser.add_argument(
            "--max-length",
            type=int,
            default=512,
            help="Maximum length of a sequence (anything beyond this will "
            "be truncated)",
        )
        parser.add_argument(
            "--learning-rate",
            "--lr",
            type=float,
            default=2e-5,
            help="Learning rate for Adam Optimization",
        )
        parser.add_argument(
            "--grad-accum-steps",
            type=int,
            default=1,
            help="Number of steps to accumulate gradients before optimizing, "
            "advancing scheduler, etc.",
        )
        parser.add_argument(
            "--warmup-proportion",
            type=float,
            default=0.1,
            help="Warmup proportion for linear scheduling",
        )
        parser.add_argument(
            "--config-name",
            type=str,
            default="config.json",
            help="Filename to save BERT config as",
        )
        parser.add_argument(
            "--weights-name",
            type=str,
            default="pytorch_model.bin",
            help="Filename to save model weights as",
        )
        parser.add_argument(
            "--enable-wandb",
            default=False,
            action="store_true",
            help="log metrics to Weights & Biases",
        )

        parser.set_defaults(func=TrainModelCommand())

View File

@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
import textattack
from textattack.shared.utils import default_class_repr
class Constraint:
class Constraint(ABC):
"""
An abstract class that represents constraints on adversial text examples.
Constraints evaluate whether transformations from a ``AttackedText`` to another
@@ -68,9 +71,9 @@ class Constraint:
current_text: The current ``AttackedText``.
original_text: The original ``AttackedText`` from which the attack began.
"""
if not isinstance(transformed_text, AttackedText):
if not isinstance(transformed_text, textattack.shared.AttackedText):
raise TypeError("transformed_text must be of type AttackedText")
if not isinstance(current_text, AttackedText):
if not isinstance(current_text, textattack.shared.AttackedText):
raise TypeError("current_text must be of type AttackedText")
try:
@@ -86,6 +89,7 @@ class Constraint:
transformed_text, current_text, original_text=original_text
)
@abstractmethod
def _check_constraint(self, transformed_text, current_text, original_text=None):
"""
Returns True if the constraint is fulfilled, False otherwise. Must be overridden by

View File

@@ -1,3 +1,5 @@
from .language_model_constraint import LanguageModelConstraint
from .google_language_model import Google1BillionWordsLanguageModel
from .gpt2 import GPT2
from .language_model_constraint import LanguageModelConstraint
from .learning_to_write import LearningToWriteLanguageModel

View File

@@ -51,12 +51,15 @@ class GoogleLanguageModel(Constraint):
def get_probs(current_text, transformed_texts):
word_swap_index = current_text.first_word_diff_index(transformed_texts[0])
if word_swap_index is None:
return []
prefix = current_text.words[word_swap_index - 1]
swapped_words = np.array(
[t.words[word_swap_index] for t in transformed_texts]
)
if self.print_step:
print(prefix, swapped_words, suffix)
print(prefix, swapped_words)
probs = self.lm.get_words_probs(prefix, swapped_words)
return probs
@@ -104,6 +107,11 @@ class GoogleLanguageModel(Constraint):
return [transformed_texts[i] for i in max_el_indices]
def _check_constraint(self, transformed_text, current_text, original_text=None):
return self._check_constraint_many(
[transformed_text], current_text, original_text=original_text
)
def __call__(self, x, x_adv):
raise NotImplementedError()

View File

@@ -52,7 +52,7 @@ class GPT2(LanguageModelConstraint):
probs = []
for attacked_text in text_list:
nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
next_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
next_word_prob = predictions[0, -1, next_word_ids[0]]
probs.append(next_word_prob)

View File

@@ -1,25 +1,27 @@
import math
import torch
from abc import ABC, abstractmethod
from textattack.constraints import Constraint
class LanguageModelConstraint(Constraint):
class LanguageModelConstraint(Constraint, ABC):
"""
Determines if two sentences have a swapped word that has a similar
probability according to a language model.
Args:
max_log_prob_diff (float): the maximum difference in log-probability
between x and x_adv
max_log_prob_diff (float): the maximum decrease in log-probability
in swapped words from x to x_adv
compare_against_original (bool): whether to compare against the original
text or the most recent
"""
def __init__(self, max_log_prob_diff=None):
def __init__(self, max_log_prob_diff=None, compare_against_original=True):
if max_log_prob_diff is None:
raise ValueError("Must set max_log_prob_diff")
self.max_log_prob_diff = max_log_prob_diff
self.compare_against_original = compare_against_original
@abstractmethod
def get_log_probs_at_index(self, text_list, word_index):
""" Gets the log-probability of items in `text_list` at index
`word_index` according to a language model.
@@ -27,6 +29,9 @@ class LanguageModelConstraint(Constraint):
raise NotImplementedError()
def _check_constraint(self, transformed_text, current_text, original_text=None):
if self.compare_against_original:
current_text = original_text
try:
indices = transformed_text.attack_attrs["newly_modified_indices"]
except KeyError:
@@ -41,9 +46,7 @@ class LanguageModelConstraint(Constraint):
f"Error: get_log_probs_at_index returned {len(probs)} values for 2 inputs"
)
cur_prob, transformed_prob = probs
if self.max_log_prob_diff is None:
cur_prob, transformed_prob = math.log(p1), math.log(p2)
if abs(cur_prob - transformed_prob) > self.max_log_prob_diff:
if transformed_prob <= cur_prob - self.max_log_prob_diff:
return False
return True

View File

@@ -0,0 +1 @@
from .learning_to_write import LearningToWriteLanguageModel

View File

@@ -0,0 +1,103 @@
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.functional import log_softmax
import textattack
class AdaptiveSoftmax(nn.Module):
    """Adaptive softmax output layer for very large vocabularies.

    The vocabulary is split by frequency at `cutoffs`: word ids below
    `cutoffs[0]` form the "head", and each subsequent interval forms a
    "tail" bucket. The head linear layer scores head words directly plus
    one cluster logit per tail bucket; each tail bucket has its own
    low-rank, bias-free classifier. (Grave et al.-style scheme —
    presumably; confirm against the original L2W code.)
    """

    def __init__(self, input_size, cutoffs, scale_down=4):
        # input_size: hidden dimension fed into the classifiers.
        # cutoffs: increasing vocab-id boundaries; cutoffs[-1] = vocab size.
        # scale_down: bottleneck factor for the tail projections.
        super().__init__()
        self.input_size = input_size
        self.cutoffs = cutoffs
        # Head predicts cutoffs[0] frequent words + one slot per tail bucket.
        self.output_size = cutoffs[0] + len(cutoffs) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoffs) - 1):
            # Low-rank factorization: project down to input_size//scale_down,
            # then up to the bucket's vocabulary slice. Bias-free (False).
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // scale_down, False),
                nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False),
            )
            self.tail.append(seq)

    def reset(self, init=0.1):
        # Re-initialize every classifier weight uniformly in [-init, init].
        self.head.weight.data.uniform_(-init, init)
        for tail in self.tail:
            for layer in tail:
                layer.weight.data.uniform_(-init, init)

    def set_target(self, target):
        # Record, per tail bucket, which batch rows have a target id inside
        # that bucket, so `forward` only evaluates the tails that are needed.
        # Must be called before `forward`.
        self.id = []
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                self.id.append(Variable(mask.float().nonzero().squeeze(1)))
            else:
                self.id.append(None)

    def forward(self, inp):
        # inp: (batch, input_size). Returns [head_logits, tail_0, tail_1, ...],
        # where a tail entry is None when no target fell in that bucket
        # (per the preceding `set_target` call).
        assert len(inp.size()) == 2
        output = [self.head(inp)]
        for i in range(len(self.id)):
            if self.id[i] is not None:
                output.append(self.tail[i](inp.index_select(0, self.id[i])))
            else:
                output.append(None)
        return output

    def log_prob(self, inp):
        # Exact log-probabilities over the full vocabulary, (batch, vocab):
        # head words take the head log-softmax directly; each tail word takes
        # its bucket's cluster log-prob (head column cutoffs[0] + i) plus the
        # within-bucket log-softmax.
        assert len(inp.size()) == 2
        head_out = self.head(inp)
        n = inp.size(0)
        prob = torch.zeros(n, self.cutoffs[-1]).to(textattack.shared.utils.device)
        lsm_head = log_softmax(head_out, dim=head_out.dim() - 1)
        prob.narrow(1, 0, self.output_size).add_(
            lsm_head.narrow(1, 0, self.output_size).data
        )
        for i in range(len(self.tail)):
            pos = self.cutoffs[i]
            i_size = self.cutoffs[i + 1] - pos
            # Cluster log-prob for bucket i, broadcast across the bucket.
            buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1)
            buff = buff.expand(n, i_size)
            temp = self.tail[i](inp)
            lsm_tail = log_softmax(temp, dim=temp.dim() - 1)
            prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data)
        return prob
class AdaptiveLoss(nn.Module):
    """Training criterion matching `AdaptiveSoftmax`: cross-entropy computed
    over the head and over each active tail cluster separately, summed, then
    averaged over the batch.

    Args:
        cutoffs: increasing vocabulary boundaries (same as AdaptiveSoftmax).
    """

    def __init__(self, cutoffs):
        super().__init__()
        self.cutoffs = cutoffs
        self.criterions = nn.ModuleList()
        for i in self.cutoffs:
            # `reduction="sum"` replaces the long-deprecated
            # `size_average=False` (same semantics: sum over samples; the sum
            # is divided by the batch size in `forward`).
            self.criterions.append(nn.CrossEntropyLoss(reduction="sum"))

    def reset(self):
        # Clear any gradients held by the per-cluster criteria.
        for criterion in self.criterions:
            criterion.zero_grad()

    def remap_target(self, target):
        """Split raw target ids into head-space ids plus per-cluster offsets.

        Returns a list: element 0 is `target` with each tail word replaced by
        its bucket's cluster id in the head; element i+1 holds within-bucket
        offsets for tail i, or None when no target falls in that bucket.
        """
        new_target = [target.clone()]
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                new_target[0][mask] = self.cutoffs[0] + i
                new_target.append(target[mask].add(-self.cutoffs[i]))
            else:
                new_target.append(None)
        return new_target

    def forward(self, inp, target):
        """Average cross-entropy of head + tail outputs w.r.t. `target`.

        `inp` is the output list produced by `AdaptiveSoftmax.forward` (None
        entries are skipped — they correspond to empty clusters).
        """
        n = inp[0].size(0)
        target = self.remap_target(target.data)
        loss = 0
        for i in range(len(inp)):
            if inp[i] is not None:
                assert target[i].min() >= 0 and target[i].max() <= inp[i].size(1)
                criterion = self.criterions[i]
                # `Variable(...)` wrapper dropped: it has been a no-op since
                # torch 0.4 (tensors are autograd-aware).
                loss += criterion(inp[i], target[i])
        loss /= n
        return loss

View File

@@ -0,0 +1,115 @@
import os
import numpy as np
import torch
import torchfile
from .rnn_model import RNNModel
class QueryHandler:
    """Wraps the L2W RNN language model for batched log-probability queries
    over candidate word swaps."""

    def __init__(self, model, word_to_idx, mapto, device):
        self.model = model
        self.word_to_idx = word_to_idx
        # `mapto` maps raw vocabulary ids to frequency-sorted model ids.
        self.mapto = mapto
        self.device = device

    def query(self, sentences, swapped_words, batch_size=32):
        """ Since we don't filter prefixes for OOV ahead of time, it's possible that
            some of them will have different lengths. When this is the case,
            we can't do RNN prediction in batch.
            This method _tries_ to do prediction in batch, and, when it fails,
            just does prediction sequentially and concatenates all of the results.
        """
        try:
            return self.try_query(sentences, swapped_words, batch_size=batch_size)
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt and
            # SystemExit; only genuine errors should trigger the slow path.
            probs = []
            for s, w in zip(sentences, swapped_words):
                probs.append(self.try_query([s], [w], batch_size=1)[0])
            return probs

    def try_query(self, sentences, swapped_words, batch_size=32):
        """Batched query: returns one log-probability per (sentence, word)
        pair. All sentences must have the same (word) length; raises
        ValueError otherwise. OOV swapped words get -inf.
        """
        # TODO use caching
        sentence_length = len(sentences[0])
        if any(len(s) != sentence_length for s in sentences):
            raise ValueError("Only same length batches are allowed")

        log_probs = []
        for start in range(0, len(sentences), batch_size):
            swapped_words_batch = swapped_words[
                start : min(len(sentences), start + batch_size)
            ]
            batch = sentences[start : min(len(sentences), start + batch_size)]
            # raw_idx_list[t] holds, for each sentence in the batch, the raw
            # vocab id of token t (after prepending the <S> start token).
            raw_idx_list = [[] for i in range(sentence_length + 1)]
            for i, s in enumerate(batch):
                # NOTE(review): OOV filtering shortens sentences; columns that
                # end up empty are dropped below via num_idxs_dropped.
                s = [word for word in s if word in self.word_to_idx]
                words = ["<S>"] + s
                word_idxs = [self.word_to_idx[w] for w in words]
                for t in range(sentence_length + 1):
                    if t < len(word_idxs):
                        raw_idx_list[t].append(word_idxs[t])
            orig_num_idxs = len(raw_idx_list)
            raw_idx_list = [x for x in raw_idx_list if len(x)]
            num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
            all_raw_idxs = torch.tensor(
                raw_idx_list, device=self.device, dtype=torch.long
            )
            # Remap raw ids to the model's frequency-sorted id space.
            word_idxs = self.mapto[all_raw_idxs]
            hidden = self.model.init_hidden(len(batch))
            source = word_idxs[:-1, :]
            target = word_idxs[1:, :]
            decode, hidden = self.model(source, hidden)
            decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
            for i in range(len(batch)):
                if swapped_words_batch[i] not in self.word_to_idx:
                    # OOV swap: the model cannot score it.
                    log_probs.append(float("-inf"))
                else:
                    # Sum of per-token log-probs along the sentence.
                    log_probs.append(
                        sum(
                            [
                                decode[t, i, target[t, i]].item()
                                for t in range(sentence_length - num_idxs_dropped)
                            ]
                        )
                    )
        return log_probs

    @staticmethod
    def load_model(lm_folder_path, device):
        """Load the torch7-exported L2W word maps and GRU weights from
        `lm_folder_path` and return a ready-to-query QueryHandler."""
        word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7"))
        word_map = [w.decode("utf-8") for w in word_map]
        word_to_idx = {w: i for i, w in enumerate(word_map)}
        # (Removed a redundant doubled os.path.join around the same path.)
        word_freq = torchfile.load(os.path.join(lm_folder_path, "word_freq.th7"))
        mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)

        model = RNNModel(
            "GRU",
            793471,
            256,
            2048,
            1,
            [4200, 35000, 180000, 793471],
            dropout=0.01,
            proj=True,
            lm1b=True,
        )
        # Context manager replaces the old manual open/close pair, so the
        # file handle is released even if loading fails.
        with open(os.path.join(lm_folder_path, "lm-state-dict.pt"), "rb") as model_file:
            model.load_state_dict(torch.load(model_file, map_location=device))
        model.full = True  # Use real softmax--important!
        model.to(device)
        model.eval()
        return QueryHandler(model, word_to_idx, mapto, device)
def util_reverse(item):
    """Return the inverse of the permutation `item` as an integer array.

    For each position idx, result[item[idx]] == idx. The previous
    implementation built the result with `np.zeros` and a Python loop,
    returning a *float* array for what is an index permutation; this
    version is vectorized and returns an integer (np.intp) array, which
    downstream `torch.from_numpy(...).long()` handles identically.
    """
    item = np.asarray(item, dtype=np.intp)
    new_item = np.empty(len(item), dtype=np.intp)
    new_item[item] = np.arange(len(item), dtype=np.intp)
    return new_item

View File

@@ -0,0 +1,56 @@
import torch
import textattack
from textattack.constraints.grammaticality.language_models import (
LanguageModelConstraint,
)
from .language_model_helpers import QueryHandler
class LearningToWriteLanguageModel(LanguageModelConstraint):
    """ A constraint based on the L2W language model.

        The RNN-based language model from "Learning to Write With Cooperative
        Discriminators" (Holtzman et al, 2018).

        https://arxiv.org/pdf/1805.06087.pdf
        https://github.com/windweller/l2w

        Reused by Jia et al., 2019, as a substitution for the Google 1-billion
        words language model (in a revised version the attack of Alzantot et
        al., 2018).
        https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689
    """

    CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"

    def __init__(self, window_size=5, **kwargs):
        # window_size: number of context words considered around the swapped
        # word when querying the language model.
        self.window_size = window_size
        lm_folder_path = textattack.shared.utils.download_if_needed(
            LearningToWriteLanguageModel.CACHE_PATH
        )
        self.query_handler = QueryHandler.load_model(
            lm_folder_path, textattack.shared.utils.device
        )
        super().__init__(**kwargs)

    def get_log_probs_at_index(self, text_list, word_index):
        """ Queries the language model for the log-probability of the word at
            `word_index` in each attacked text, conditioned on a window of
            surrounding words. Returns a tensor with one value per text.
        """
        queries = []
        query_words = []
        for attacked_text in text_list:
            target_word = attacked_text.words[word_index]
            window_text = attacked_text.text_window_around_index(
                word_index, self.window_size
            )
            queries.append(textattack.shared.utils.words_from_text(window_text))
            query_words.append(target_word)
        return torch.tensor(self.query_handler.query(queries, query_words))

View File

@@ -0,0 +1,101 @@
from torch import nn as nn
from torch.autograd import Variable
from .adaptive_softmax import AdaptiveSoftmax
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder. Based on official pytorch examples

    Word-level RNN language model: embedding -> dropout -> (GRU or vanilla
    RNN) -> optional projection -> AdaptiveSoftmax. When `self.full` is
    True, `forward` returns exact log-probabilities over the full
    vocabulary; otherwise it returns adaptive-softmax cluster outputs.
    """

    def __init__(
        self,
        rnn_type,
        ntoken,
        ninp,
        nhid,
        nlayers,
        cutoffs,
        proj=False,
        dropout=0.5,
        tie_weights=False,
        lm1b=False,
    ):
        # rnn_type: "GRU", "RNN_TANH", or "RNN_RELU".
        # ntoken: vocabulary size; ninp: embedding dim; nhid: hidden dim.
        # cutoffs: adaptive-softmax vocabulary boundaries.
        # proj: if True (and ninp != nhid), project hidden states back to
        #   ninp before the softmax.
        # tie_weights / lm1b: retained for checkpoint compatibility; the
        #   weight-tying decoder code below is commented out.
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.lm1b = lm1b

        if rnn_type == "GRU":
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    """An invalid option for `--model` was supplied,
                                 options are ['GRU', 'RNN_TANH' or 'RNN_RELU']"""
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )

        self.proj = proj

        if ninp != nhid and proj:
            self.proj_layer = nn.Linear(nhid, ninp)

        # if tie_weights:
        #     if nhid != ninp and not proj:
        #         raise ValueError('When using the tied flag, nhid must be equal to emsize')
        #     self.decoder = nn.Linear(ninp, ntoken)
        #     self.decoder.weight = self.encoder.weight
        # else:
        #     if nhid != ninp and not proj:
        #         if not lm1b:
        #             self.decoder = nn.Linear(nhid, ntoken)
        #         else:
        #             self.decoder = adapt_loss
        #     else:
        #         self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

        # AdaptiveSoftmax input dim depends on whether projection is applied.
        if proj:
            self.softmax = AdaptiveSoftmax(ninp, cutoffs)
        else:
            self.softmax = AdaptiveSoftmax(nhid, cutoffs)

        # When True, forward() returns full-vocabulary log-probs.
        self.full = False

    def init_weights(self):
        # Uniform init of the embedding; decoder init commented out with the
        # decoder itself above.
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.fill_(0)
        # self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        # input: (seq_len, batch) token ids; hidden: initial RNN state.
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # NOTE(review): `self.proj` is always set in __init__, so this vars()
        # check always passes here; presumably it guards state loaded from
        # older pickles that predate the attribute — confirm before removing.
        if "proj" in vars(self):
            if self.proj:
                output = self.proj_layer(output)
        # Flatten (seq_len, batch, feat) -> (seq_len * batch, feat).
        output = output.view(output.size(0) * output.size(1), output.size(2))
        if self.full:
            decode = self.softmax.log_prob(output)
        else:
            decode = self.softmax(output)
        return decode, hidden

    def init_hidden(self, bsz):
        # Zero hidden state shaped (nlayers, batch, nhid), matching the
        # model's parameter dtype/device.
        weight = next(self.parameters()).data
        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())

View File

@@ -48,9 +48,9 @@ class PartOfSpeech(Constraint):
)
if self.tagger_type == "flair":
word_list, pos_list = zip_flair_result(
self._flair_pos_tagger.predict(context_key)[0]
)
context_key_sentence = Sentence(context_key)
self._flair_pos_tagger.predict(context_key_sentence)
word_list, pos_list = zip_flair_result(context_key_sentence)
self._pos_tag_cache[context_key] = (word_list, pos_list)

View File

@@ -1,5 +1,7 @@
from .pre_transformation_constraint import PreTransformationConstraint
from .stopword_modification import StopwordModification
from .repeat_modification import RepeatModification
from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification
from .min_word_length import MinWordLength
from .repeat_modification import RepeatModification
from .stopword_modification import StopwordModification

View File

@@ -0,0 +1,43 @@
from textattack.constraints.pre_transformation import PreTransformationConstraint
class InputColumnModification(PreTransformationConstraint):
    """
    A constraint disallowing the modification of words within a specific input
    column.

    For example, can prevent modification of 'premise' during
    entailment.
    """

    def __init__(self, matching_column_labels, columns_to_ignore):
        # matching_column_labels: the exact column-label sequence this
        #   constraint applies to (e.g. ["premise", "hypothesis"]).
        # columns_to_ignore: labels whose words may never be modified.
        self.matching_column_labels = matching_column_labels
        self.columns_to_ignore = columns_to_ignore

    def _get_modifiable_indices(self, current_text):
        """ Returns the global word indices of ``current_text`` that are open
            to modification.

            If ``current_text.column_labels`` doesn't match
            ``self.matching_column_labels``, every word stays modifiable.
            Otherwise, words belonging to any column listed in
            ``self.columns_to_ignore`` are excluded.
        """
        if current_text.column_labels != self.matching_column_labels:
            return set(range(len(current_text.words)))

        modifiable_indices = set()
        offset = 0
        for label, column_words in zip(
            current_text.column_labels, current_text.words_per_input
        ):
            column_len = len(column_words)
            if label not in self.columns_to_ignore:
                modifiable_indices.update(range(offset, offset + column_len))
            offset += column_len
        return modifiable_indices

    def extra_repr_keys(self):
        return ["matching_column_labels", "columns_to_ignore"]

View File

@@ -1,5 +1,4 @@
from textattack.constraints.pre_transformation import PreTransformationConstraint
from textattack.shared.utils import default_class_repr
class MaxWordIndexModification(PreTransformationConstraint):
@@ -14,3 +13,5 @@ class MaxWordIndexModification(PreTransformationConstraint):
""" Returns the word indices in current_text which are able to be deleted """
return set(range(min(self.max_length, len(current_text.words))))
def extra_repr_keys(self):
return ["max_length"]

View File

@@ -2,7 +2,6 @@ from textattack.constraints.pre_transformation import PreTransformationConstrain
class MinWordLength(PreTransformationConstraint):
def __init__(self, min_length):
self.min_length = min_length

View File

@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from textattack.constraints import Constraint
from textattack.shared.utils import default_class_repr
class PreTransformationConstraint(Constraint):
class PreTransformationConstraint(Constraint, ABC):
"""
An abstract class that represents constraints which are applied before
the transformation. These restrict which words are allowed to be modified
@@ -23,6 +24,7 @@ class PreTransformationConstraint(Constraint):
return set(range(len(current_text.words)))
return self._get_modifiable_indices(current_text)
@abstractmethod
def _get_modifiable_indices(current_text):
"""
Returns the word indices in ``current_text`` which are able to be modified.
@@ -32,3 +34,8 @@ class PreTransformationConstraint(Constraint):
current_text: The ``AttackedText`` input to consider.
"""
raise NotImplementedError()
def _check_constraint(self):
    """Unconditionally raises.

    Pre-transformation constraints restrict which words may be modified
    *before* a transformation runs, so the post-transformation
    ``_check_constraint`` API does not apply to them.

    Raises:
        RuntimeError: always.
    """
    message = "PreTransformationConstraints do not support `_check_constraint()`."
    raise RuntimeError(message)

View File

@@ -1,5 +1,4 @@
from textattack.constraints.pre_transformation import PreTransformationConstraint
from textattack.shared.utils import default_class_repr
class RepeatModification(PreTransformationConstraint):

View File

@@ -8,12 +8,11 @@
"""
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
"""
import time
import numpy as np
import torch
import torch.nn as nn
from torch import nn as nn
class InferSentModel(nn.Module):

View File

@@ -21,7 +21,8 @@ class SentenceEncoder(Constraint):
compare_with_original (bool): Whether to compare `x_adv` to the previous `x_adv`
or the original `x`.
window_size (int): The number of words to use in the similarity
comparison.
comparison. `None` indicates no windowing (encoding is based on the
full input).
"""
def __init__(
@@ -38,6 +39,9 @@ class SentenceEncoder(Constraint):
self.window_size = window_size
self.skip_text_shorter_than_window = skip_text_shorter_than_window
if not self.window_size:
self.window_size = float("inf")
if metric == "cosine":
self.sim_metric = torch.nn.CosineSimilarity(dim=1)
elif metric == "angular":
@@ -68,7 +72,9 @@ class SentenceEncoder(Constraint):
The similarity between the starting and transformed text using the metric.
"""
try:
modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"]))
modified_index = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
except KeyError:
raise KeyError(
"Cannot apply sentence encoder constraint without `newly_modified_indices`"
@@ -107,7 +113,7 @@ class SentenceEncoder(Constraint):
``transformed_texts``. If ``transformed_texts`` is empty,
an empty tensor is returned
"""
# Return an empty tensor if x_adv_list is empty.
# Return an empty tensor if transformed_texts is empty.
# This prevents us from calling .repeat(x, 0), which throws an
# error on machines with multiple GPUs (pytorch 1.2).
if len(transformed_texts) == 0:
@@ -137,9 +143,9 @@ class SentenceEncoder(Constraint):
)
)
embeddings = self.encode(starting_text_windows + transformed_text_windows)
starting_embeddings = torch.tensor(embeddings[: len(transformed_texts)]).to(
utils.device
)
if not isinstance(embeddings, torch.Tensor):
embeddings = torch.tensor(embeddings)
starting_embeddings = embeddings[: len(transformed_texts)].to(utils.device)
transformed_embeddings = torch.tensor(
embeddings[len(transformed_texts) :]
).to(utils.device)
@@ -147,18 +153,12 @@ class SentenceEncoder(Constraint):
starting_raw_text = starting_text.text
transformed_raw_texts = [t.text for t in transformed_texts]
embeddings = self.encode([starting_raw_text] + transformed_raw_texts)
if isinstance(embeddings[0], torch.Tensor):
starting_embedding = embeddings[0].to(utils.device)
else:
# If the embedding is not yet a tensor, make it one.
starting_embedding = torch.tensor(embeddings[0]).to(utils.device)
if not isinstance(embeddings, torch.Tensor):
embeddings = torch.tensor(embeddings)
if isinstance(embeddings, list):
# If `encode` did not return a Tensor of all embeddings, combine
# into a tensor.
transformed_embeddings = torch.stack(embeddings[1:]).to(utils.device)
else:
transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.device)
starting_embedding = embeddings[0].to(utils.device)
transformed_embeddings = embeddings[1:].to(utils.device)
# Repeat original embedding to size of perturbed embedding.
starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat(
@@ -209,7 +209,7 @@ class SentenceEncoder(Constraint):
"Must provide original text when compare_with_original is true."
)
else:
scores = self._sim_score(current_text, transformed_texts)
scores = self._sim_score(current_text, transformed_text)
transformed_text.attack_attrs["similarity_score"] = score
return score >= self.threshold

Some files were not shown because too many files have changed in this diff Show More