Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

Commit: merge
.github/workflows/check-formatting.yml (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Formatting with black & isort

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools wheel
        pip install black flake8 isort # Testing packages
        python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
        pip install -e .
    - name: Check code format with black and isort
      run: |
        make lint
.github/workflows/make-docs.yml (vendored, new file, 37 lines)
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Build documentation with Sphinx

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        sudo sed -i 's/azure\.//' /etc/apt/sources.list # workaround for flaky pandoc install
        sudo apt-get update # from here https://github.com/actions/virtual-environments/issues/675
        sudo apt-get install pandoc -o Acquire::Retries=3 # install pandoc
        python -m pip install --upgrade pip setuptools wheel # update python
        pip install ipython --upgrade # needed for Github for whatever reason
        python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
        pip install -e . ".[dev]" # This should install all packages for development
    - name: Build docs with Sphinx and check for errors
      run: |
        sphinx-build -b html docs docs/_build/html -W
.github/workflows/publish-to-pypi.yml (vendored, new file, 31 lines)
@@ -0,0 +1,31 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package to PyPI

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools wheel
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
.github/workflows/run-pytest.yml (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Test with PyTest

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.7, 3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools wheel
        pip install pytest pytest-xdist # Testing packages
        python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
        pip install -e .
    - name: Test with pytest
      run: |
        pytest tests -vx --dist=loadfile -n auto
@@ -9,6 +9,10 @@ It also helps us if you spread the word: reference the library from blog posts
on the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply star the repo to say "thank you".

## Slack Channel

For help and realtime updates related to TextAttack, please [join the TextAttack Slack](https://join.slack.com/t/textattack/shared_invite/zt-ez3ts03b-Nr55tDiqgAvCkRbbz8zz9g)!

## Ways to contribute

There are lots of ways you can contribute to TextAttack:
@@ -113,7 +117,7 @@ Follow these steps to start contributing:

```bash
$ cd TextAttack
$ pip install -e .
$ pip install -e . ".[dev]"
$ pip install black isort pytest pytest-xdist
```

@@ -175,11 +179,25 @@ Follow these steps to start contributing:
   $ git push -u origin a-descriptive-name-for-my-changes
   ```

6. Once you are satisfied (**and the checklist below is happy too**), go to the
6. Add documentation.

   Our docs are in the `docs/` folder. Thanks to `sphinx-automodule`, this
   should just be two lines. Our docs will automatically generate from the
   comments you added to your code. If you're adding an attack recipe, add a
   reference in `attack_recipes.rst`. If you're adding a transformation, add
   a reference in `transformation.rst`, etc.

   You can build the docs and view the updates using `make docs`. If you're
   adding a tutorial or something where you want to update the docs multiple
   times, you can run `make docs-auto`. This will run a server using
   `sphinx-autobuild` that should automatically reload whenever you change
   a file.

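   For instance, registering a hypothetical new recipe module in `attack_recipes.rst` follows the same two-line `automodule` stanza used throughout these docs (the module path below is made up for illustration):

   ```rst
   My New Recipe (Paper Title)
   ############################

   .. automodule:: textattack.attack_recipes.my_new_recipe
      :members:
   ```
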
7. Once you are satisfied (**and the checklist below is happy too**), go to the
   webpage of your fork on GitHub. Click on 'Pull request' to send your changes
   to the project maintainers for review.

7. It's ok if maintainers ask you for changes. It happens to core contributors
8. It's ok if maintainers ask you for changes. It happens to core contributors
   too! So everyone can see the changes in the Pull request, work in your local
   branch and push the changes to your fork. They will automatically appear in
   the pull request.

Makefile (14 lines changed)
@@ -1,22 +1,26 @@
format: FORCE ## Run black and isort (rewriting files)
    black .
    isort --atomic --recursive tests textattack
    isort --atomic tests textattack

lint: FORCE ## Run black (in check mode)
lint: FORCE ## Run black, isort, flake8 (in check mode)
    black . --check
    isort --check-only --recursive tests textattack
    isort --check-only tests textattack
    flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8

test: FORCE ## Run tests using pytest
    python -m pytest -qx --dist=loadfile -n auto
    python -m pytest --dist=loadfile -n auto

docs: FORCE ## Build docs using Sphinx.
    sphinx-build -b html docs docs/_build/html

docs-check: FORCE ## Builds docs using Sphinx. If there is an error, exit with an error code (instead of warning & continuing).
    sphinx-build -b html docs docs/_build/html -W

docs-auto: FORCE ## Build docs using Sphinx and run hotreload server using Sphinx autobuild.
    sphinx-autobuild docs docs/_build/html -H 0.0.0.0 -p 8765

all: format lint test ## Format, lint, and test.
all: format lint docs-check test ## Format, lint, and test.

.PHONY: help
README.md (110 lines changed)
@@ -1,28 +1,41 @@
<h1 align="center">TextAttack 🐙</h1>

<p align="center">Generating adversarial examples for NLP models</p>

<p align="center">
  <a href="https://textattack.readthedocs.io/">Docs</a> •
  <a href="https://textattack.readthedocs.io/">[TextAttack Documentation on ReadTheDocs]</a>
  <br> <br>
  <a href="#about">About</a> •
  <a href="#setup">Setup</a> •
  <a href="#usage">Usage</a> •
  <a href="#design">Design</a>
  <br> <br>
  <a target="_blank" href="https://travis-ci.org/QData/TextAttack">
    <img src="https://travis-ci.org/QData/TextAttack.svg?branch=master" alt="Coverage Status">
  <a target="_blank">
    <img src="https://github.com/QData/TextAttack/workflows/Github%20PyTest/badge.svg" alt="Github Runner Coverage Status">
  </a>
  <a href="https://badge.fury.io/py/textattack">
    <img src="https://badge.fury.io/py/textattack.svg" alt="PyPI version" height="18">
  </a>

</p>

<img src="http://jackxmorris.com/files/textattack.gif" alt="TextAttack Demo GIF" style="display: block; margin: 0 auto;" />

## About

TextAttack is a Python framework for running adversarial attacks against NLP models. TextAttack builds attacks from four components: a search method, goal function, transformation, and set of constraints. TextAttack's modular design makes it easily extensible to new NLP tasks, models, and attack strategies. TextAttack currently supports attacks on models trained for classification, entailment, and translation.
TextAttack is a Python framework for adversarial attacks, data augmentation, and model training in NLP.

## Slack Channel

For help and realtime updates related to TextAttack, please [join the TextAttack Slack](https://join.slack.com/t/textattack/shared_invite/zt-ez3ts03b-Nr55tDiqgAvCkRbbz8zz9g)!

### *Why TextAttack?*

There are lots of reasons to use TextAttack:

1. **Understand NLP models better** by running different adversarial attacks on them and examining the output
2. **Research and develop different NLP adversarial attacks** using the TextAttack framework and library of components
3. **Augment your dataset** to increase model generalization and robustness downstream
3. **Train NLP models** using just a single command (all downloads included!)

## Setup

@@ -35,12 +48,11 @@ pip install textattack
```

Once TextAttack is installed, you can run it via command-line (`textattack ...`)
or via the python module (`python -m textattack ...`).
or via python module (`python -m textattack ...`).

### Configuration
TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
environment variable `TA_CACHE_DIR`.
> **Tip**: TextAttack downloads files to `~/.cache/textattack/` by default. This includes pretrained models,
> dataset samples, and the configuration file `config.yaml`. To change the cache path, set the
> environment variable `TA_CACHE_DIR`. (for example: `TA_CACHE_DIR=/tmp/ textattack attack ...`).

## Usage

@@ -49,11 +61,15 @@ common commands are `textattack attack <args>`, and `textattack augment <args>`.
information about all commands using `textattack --help`, or a specific command using, for example,
`textattack attack --help`.

The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The [documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint.

### Running Attacks

The [`examples/`](docs/examples/) folder contains notebooks explaining basic usage of TextAttack, including building a custom transformation and a custom constraint. These examples can also be viewed through the [documentation website](https://textattack.readthedocs.io/en/latest).
The easiest way to try out an attack is via the command-line interface, `textattack attack`.

The easiest way to try out an attack is via the command-line interface, `textattack attack`. Here are some concrete examples:
> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.

Here are some concrete examples:

*TextFooler on an LSTM trained on the MR sentiment classification dataset*:
```bash
@@ -73,15 +89,7 @@ textattack attack --model lstm-mr --num-examples 20 \
  --goal-function untargeted-classification
```

*Non-overlapping output attack using a greedy word swap and WordNet word substitutions on T5 English-to-German translation:*
```bash
textattack attack --attack-n --goal-function non-overlapping-output \
  --model t5-en2de --num-examples 10 --transformation word-swap-wordnet \
  --constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
  --search greedy
```

> **Tip:** If your machine has multiple GPUs, you can distribute the attack across them using the `--parallel` option. For some attacks, this can really help performance.
> **Tip:** Instead of specifying a dataset and number of examples, you can pass `--interactive` to attack samples inputted by the user.

### Attacks and Papers Implemented ("Attack Recipes")

@@ -89,16 +97,19 @@ We include attack recipes which implement attacks from the literature. You can l

To run an attack recipe: `textattack attack --recipe [recipe_name]`

The first are for classification tasks, like sentiment classification and entailment:
Attacks on classification tasks, like sentiment classification and entailment:
- **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
- **bae**: BERT masked language model transformation attack from (["BAE: BERT-based Adversarial Examples for Text Classification" (Garg & Ramakrishnan, 2019)](https://arxiv.org/abs/2004.01970)).
- **bert-attack**: BERT masked language model transformation attack with subword replacements (["BERT-ATTACK: Adversarial Attack Against BERT Using BERT" (Li et al., 2020)](https://arxiv.org/abs/2004.09984)).
- **deepwordbug**: Greedy replace-1 scoring and multi-transformation character-swap attack (["Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers" (Gao et al., 2018)](https://arxiv.org/abs/1801.04354)).
- **hotflip**: Beam search and gradient-based word swap (["HotFlip: White-Box Adversarial Examples for Text Classification" (Ebrahimi et al., 2017)](https://arxiv.org/abs/1712.06751)).
- **input-reduction**: Reducing the input while maintaining the prediction through word importance ranking (["Pathologies of Neural Models Make Interpretation Difficult" (Feng et al., 2018)](https://arxiv.org/pdf/1804.07781.pdf)).
- **kuleshov**: Greedy search and counterfitted embedding swap (["Adversarial Examples for Natural Language Classification Problems" (Kuleshov et al., 2018)](https://openreview.net/pdf?id=r1QZ3zbAZ)).
- **pwws**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)).
- **textbugger**: Greedy attack with word importance ranking and character-based swaps (["TextBugger: Generating Adversarial Text Against Real-world Applications" (Li et al., 2018)](https://arxiv.org/abs/1812.05271)).
- **textfooler**: Greedy attack with word importance ranking and counter-fitted embedding swap (["Is Bert Really Robust?" (Jin et al., 2019)](https://arxiv.org/abs/1907.11932)).
- **PWWS**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)).

The final is for sequence-to-sequence models:
Attacks on sequence-to-sequence models:
- **seq2sick**: Greedy attack with goal of changing every word in the output translation. Currently implemented as black-box with plans to change to white-box as done in paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples" (Cheng et al., 2018)](https://arxiv.org/abs/1803.01128)).

#### Recipe Usage Examples
@@ -112,7 +123,7 @@ textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examp

*seq2sick (black-box) against T5 fine-tuned for English-German translation:*
```bash
textattack attack --recipe seq2sick --model t5-en2de --num-examples 100
textattack attack --model t5-en-de --recipe seq2sick --num-examples 100
```

### Augmenting Text
@@ -175,6 +186,32 @@ of a string or a list of strings. Here's an example of how to use the `Embedding
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```

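The hunk elides the Python call that produced the augmented sentences above; as a minimal sketch (assuming the `textattack.augmentation` recipe API described in the surrounding README), it would look roughly like:

```python
from textattack.augmentation import EmbeddingAugmenter

# Swap words with nearest neighbors in the counter-fitted embedding space.
augmenter = EmbeddingAugmenter()
augmented = augmenter.augment("What I cannot create, I do not understand.")
print(augmented)  # a list of augmented strings, like the output shown above
```
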
### Training Models

Our model training code is available via `textattack train` to help you train LSTMs,
CNNs, and `transformers` models using TextAttack out-of-the-box. Datasets are
automatically loaded using the `nlp` package.

#### Training Examples
*Train our default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50 --learning-rate 1e-5
```

*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs*:
```bash
textattack train --model bert-base-uncased --dataset glue:cola --batch-size 32 --epochs 5
```

## `textattack peek-dataset`

To take a closer look at a dataset, use `textattack peek-dataset`. TextAttack will print some cursory statistics about the inputs and outputs from the dataset. For example, `textattack peek-dataset --dataset-from-nlp snli` will show information about the SNLI dataset from the NLP package.


## `textattack list`

There are lots of pieces in TextAttack, and it can be difficult to keep track of all of them. You can use `textattack list` to list components, for example, pretrained models (`textattack list models`) or available search methods (`textattack list search-methods`).

## Design

### AttackedText
@@ -190,10 +227,10 @@ TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that

TextAttack also comes built-in with models and datasets. Our command-line interface will automatically match the correct
dataset to the correct model. We include various pre-trained models for each of the nine [GLUE](https://gluebenchmark.com/)
tasks, as well as some common classification datasets, translation, and summarization. You can
tasks, as well as some common datasets for classification, translation, and summarization. You can
see the full list of provided models & datasets via `textattack attack --help`.

Here's an example of using one of the built-in models:
Here's an example of using one of the built-in models (the SST-2 dataset is automatically loaded):

```bash
textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10
@@ -206,11 +243,11 @@ and datasets from the [`nlp` package](https://github.com/huggingface/nlp)! Here'
and attacking a pre-trained model and dataset:

```bash
textattack attack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
textattack attack --model-from-huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset-from-nlp glue:sst2 --recipe deepwordbug --num-examples 10
```

You can explore other pre-trained models using the `--model_from_huggingface` argument, or other datasets by changing
`--dataset_from_nlp`.
You can explore other pre-trained models using the `--model-from-huggingface` argument, or other datasets by changing
`--dataset-from-nlp`.


#### Loading a model or dataset from a file
@@ -229,7 +266,7 @@ model = load_model()
tokenizer = load_tokenizer()
```

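The diff only shows the tail of the example file; a hypothetical, fuller `my_model.py` consistent with those two lines might look like the sketch below (the loading helpers, checkpoint path, and tokenizer name are assumptions for illustration, not part of the commit):

```python
# my_model.py -- TextAttack loads the module-level `model` and `tokenizer` from this file.
import torch
from transformers import AutoTokenizer

def load_model():
    # Assumption: a trained PyTorch classifier saved locally.
    return torch.load("outputs/my_classifier.pt")

def load_tokenizer():
    # Assumption: the tokenizer matching the saved model.
    return AutoTokenizer.from_pretrained("bert-base-uncased")

model = load_model()
tokenizer = load_tokenizer()
```
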
Then, run an attack with the argument `--model_from_file my_model.py`. The model and tokenizer will be loaded automatically.
Then, run an attack with the argument `--model-from-file my_model.py`. The model and tokenizer will be loaded automatically.

#### Dataset from a file

@@ -240,7 +277,7 @@ The following example would load a sentiment classification dataset from file `m
dataset = [('Today was....', 1), ('This movie is...', 0), ...]
```

You can then run attacks on samples from this dataset by adding the argument `--dataset_from_file my_dataset.py`.
You can then run attacks on samples from this dataset by adding the argument `--dataset-from-file my_dataset.py`.

### Attacks

@@ -248,7 +285,7 @@ The `attack_one` method in an `Attack` takes as input an `AttackedText`, and out

### Goal Functions

A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
A `GoalFunction` takes as input an `AttackedText` object, scores it, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.

### Constraints

@@ -262,10 +299,13 @@ A `Transformation` takes as input an `AttackedText` and returns a list of possib

A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a final `GoalFunctionResult`. The search is given access to the `get_transformations` function, which takes as input an `AttackedText` object and outputs a list of possible transformations filtered by meeting all of the attack’s constraints. A search consists of successive calls to `get_transformations` until the search succeeds (determined using `get_goal_results`) or is exhausted.

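To make the transformation/search interplay concrete, here is the toy word swap from the tutorial notebook touched by this commit, which extends `WordSwap` and overrides `_get_replacement_words` (a sketch for illustration, not part of the README diff):

```python
from textattack.transformations import WordSwap

class BananaWordSwap(WordSwap):
    """Toy transformation from the tutorial: propose 'banana' for every word."""

    def _get_replacement_words(self, word):
        # get_transformations applies this to each word and filters the resulting
        # AttackedText candidates through the attack's constraints.
        return ["banana"]
```
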
## Contributing to TextAttack

We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. TextAttack is currently in an "alpha" stage in which we are working to improve its capabilities and design.

See [CONTRIBUTING.md](https://github.com/QData/TextAttack/blob/master/CONTRIBUTING.md) for detailed information on contributing.

## Citing TextAttack

If you use TextAttack for your research, please cite [TextAttack: A Framework for Adversarial Attacks in Natural Language Processing](https://arxiv.org/abs/2005.05909).

@@ -2,6 +2,13 @@
Attack
========

TextAttack builds attacks from four components:

- `Goal Functions <../attacks/goal_function.html>`__ stipulate the goal of the attack, like to change the prediction score of a classification model, or to change all of the words in a translation output.
- `Constraints <../attacks/constraint.html>`__ determine if a potential perturbation is valid with respect to the original input.
- `Transformations <../attacks/transformation.html>`__ take a text input and transform it by inserting and deleting characters, words, and/or phrases.
- `Search Methods <../attacks/search_method.html>`__ explore the space of possible **transformations** within the defined **constraints** and attempt to find a successful perturbation which satisfies the **goal function**.

The ``Attack`` class represents an adversarial attack composed of a goal function, search method, transformation, and constraints.

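A minimal sketch of wiring these four components into an ``Attack``, adapted from the overview example removed later in this commit (the `model` placeholder is an assumption; any TextAttack-compatible classification model works):

```python
from textattack.shared import Attack
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedySearch
from textattack.transformations import WordSwapEmbedding

model = ...  # any TextAttack-compatible classification model
goal_function = UntargetedClassification(model)
transformation = WordSwapEmbedding()
constraints = []  # an empty list is allowed if no constraints are wanted
search = GreedySearch()

attack = Attack(goal_function, constraints, transformation, search)
```
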
.. automodule:: textattack.shared.attack

@@ -5,53 +5,82 @@ We provide a number of pre-built attack recipes. To run an attack recipe, run::

   textattack attack --recipe [recipe_name]

Alzantot
###########
Alzantot Genetic Algorithm (Generating Natural Language Adversarial Examples)
###################################################################################

.. automodule:: textattack.attack_recipes.alzantot_2018
.. automodule:: textattack.attack_recipes.genetic_algorithm_alzantot_2018
   :members:

Faster Alzantot Genetic Algorithm (Certified Robustness to Adversarial Word Substitutions)
##############################################################################################

.. automodule:: textattack.attack_recipes.faster_genetic_algorithm_jia_2019
   :members:

BAE (BAE: BERT-Based Adversarial Examples)
#############################################

.. automodule:: textattack.attack_recipes.bae_garg_2019
   :members:

DeepWordBug
############
BERT-Attack: (BERT-Attack: Adversarial Attack Against BERT Using BERT)
#########################################################################

.. automodule:: textattack.attack_recipes.bert_attack_li_2020
   :members:

DeepWordBug (Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers)
######################################################################################################

.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018
   :members:

HotFlip
###########
HotFlip (HotFlip: White-Box Adversarial Examples for Text Classification)
##############################################################################

.. automodule:: textattack.attack_recipes.hotflip_ebrahimi_2017
   :members:

Input Reduction
################

.. automodule:: textattack.attack_recipes.input_reduction_feng_2018
   :members:

Kuleshov
###########
Kuleshov (Adversarial Examples for Natural Language Classification Problems)
##############################################################################

.. automodule:: textattack.attack_recipes.kuleshov_2017
   :members:

Seq2Sick
###########
Particle Swarm Optimization (Word-level Textual Adversarial Attacking as Combinatorial Optimization)
#####################################################################################################

.. automodule:: textattack.attack_recipes.PSO_zang_2020
   :members:

PWWS (Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency)
###################################################################################################

.. automodule:: textattack.attack_recipes.pwws_ren_2019
   :members:

Seq2Sick (Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples)
#########################################################################################################

.. automodule:: textattack.attack_recipes.seq2sick_cheng_2018_blackbox
   :members:


TextFooler
###########
TextFooler (Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment)
########################################################################################################################

.. automodule:: textattack.attack_recipes.textfooler_jin_2019
   :members:


PWWS
###########

.. automodule:: textattack.attack_recipes.pwws_ren_2019
   :members:


TextBugger
###########
TextBugger (TextBugger: Generating Adversarial Text Against Real-world Applications)
########################################################################################

.. automodule:: textattack.attack_recipes.textbugger_li_2018
   :members:

@@ -83,6 +83,12 @@ GPT-2

.. automodule:: textattack.constraints.grammaticality.language_models.gpt2
   :members:

"Learning To Write" Language Model
************************************

.. automodule:: textattack.constraints.grammaticality.language_models.learning_to_write.learning_to_write
   :members:


Google 1-Billion Words Language Model
@@ -136,7 +142,7 @@ Maximum Words Perturbed
.. _pre_transformation:

Pre-Transformation
----------
-------------------------

Pre-transformation constraints determine if a transformation is valid based on
only the original input and the position of the replacement. These constraints
@@ -145,7 +151,7 @@ constraints can prevent search methods from swapping words at the same index
twice, or from replacing stopwords.

Pre-Transformation Constraint
########################
###############################
.. automodule:: textattack.constraints.pre_transformation.pre_transformation_constraint
   :special-members: __call__
   :private-members:
@@ -160,3 +166,13 @@ Repeat Modification
########################
.. automodule:: textattack.constraints.pre_transformation.repeat_modification
   :members:

Input Column Modification
#############################
.. automodule:: textattack.constraints.pre_transformation.input_column_modification
   :members:

Max Word Index Modification
###############################
.. automodule:: textattack.constraints.pre_transformation.max_word_index_modification
   :members:

@@ -69,7 +69,7 @@ Word Swap by Random Character Insertion
   :members:

Word Swap by Random Character Substitution
---------------------------------------
-------------------------------------------

.. automodule:: textattack.transformations.word_swap_random_character_substitution
   :members:

@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
author = "UVA QData Lab"

# The full version, including alpha/beta/rc tags
release = "0.0.3.1"
release = "0.1.5"

# Set master doc to `index.rst`.
master_doc = "index"
@@ -37,7 +37,10 @@ extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx_rtd_theme",
    # Enable .ipynb doc files
    "nbsphinx",
    # Enable .md doc files
    "recommonmark",
]

# Add any paths that contain templates here, relative to this directory.

@@ -6,40 +6,10 @@ Datasets
   :members:
   :private-members:

Classification
###############
.. automodule:: textattack.datasets.classification.classification_dataset
.. automodule:: textattack.datasets.huggingface_nlp_dataset
   :members:

.. automodule:: textattack.datasets.classification.ag_news
   :members:

.. automodule:: textattack.datasets.classification.imdb_sentiment
   :members:

.. automodule:: textattack.datasets.classification.kaggle_fake_news
   :members:

.. automodule:: textattack.datasets.classification.movie_review_sentiment
   :members:

.. automodule:: textattack.datasets.classification.yelp_sentiment
   :members:

Entailment
############
.. automodule:: textattack.datasets.entailment.entailment_dataset
   :members:

.. automodule:: textattack.datasets.entailment.mnli
   :members:

.. automodule:: textattack.datasets.entailment.snli
.. automodule:: textattack.datasets.translation.ted_multi
   :members:


Translation
#############
.. automodule:: textattack.datasets.translation.translation_datasets
   :members:

@@ -11,7 +11,7 @@ We split models up into two broad categories:

**Classification models:**

:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using transformers_.
:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using ``transformers``.

:ref:`LSTM`: a standard LSTM fine-tuned on various datasets.

@@ -20,85 +20,32 @@ We split models up into two broad categories:

**Text-to-text models:**

:ref:`T5`: ``T5`` fine-tuned on various datasets using transformers_.
:ref:`T5`: ``T5`` fine-tuned on various datasets using ``transformers``.


.. _BERT:

BERT
********
.. _BERT:

.. automodule:: textattack.models.helpers.bert_for_classification
   :members:


We provide pre-trained BERT models on the following datasets:

.. automodule:: textattack.models.classification.bert.bert_for_ag_news_classification
   :members:

.. automodule:: textattack.models.classification.bert.bert_for_imdb_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.bert.bert_for_mr_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.bert.bert_for_yelp_sentiment_classification
   :members:

.. automodule:: textattack.models.entailment.bert.bert_for_mnli
   :members:

.. automodule:: textattack.models.entailment.bert.bert_for_snli
   :members:

LSTM
*******
.. _LSTM:

LSTM
*******
.. automodule:: textattack.models.helpers.lstm_for_classification
   :members:


We provide pre-trained LSTM models on the following datasets:

.. automodule:: textattack.models.classification.lstm.lstm_for_ag_news_classification
   :members:

.. automodule:: textattack.models.classification.lstm.lstm_for_imdb_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.lstm.lstm_for_mr_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.lstm.lstm_for_yelp_sentiment_classification
   :members:


.. _CNN:

Word-CNN
************
.. _CNN:

.. automodule:: textattack.models.helpers.word_cnn_for_classification
   :members:


We provide pre-trained CNN models on the following datasets:

.. automodule:: textattack.models.classification.cnn.word_cnn_for_ag_news_classification
   :members:

.. automodule:: textattack.models.classification.cnn.word_cnn_for_imdb_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.cnn.word_cnn_for_mr_sentiment_classification
   :members:

.. automodule:: textattack.models.classification.cnn.word_cnn_for_yelp_sentiment_classification
   :members:


.. _T5:

T5
@@ -106,21 +53,3 @@ T5

.. automodule:: textattack.models.helpers.t5_for_text_to_text
   :members:


We provide pre-trained T5 models on the following tasks & datasets:

Translation
##############

.. automodule:: textattack.models.translation.t5.t5_models
   :members:

Summarization
##############

.. automodule:: textattack.models.summarization.t5_summarization
   :members:


.. _transformers: https://github.com/huggingface/transformers

@@ -2,20 +2,14 @@
Tokenizers
===========

.. automodule:: textattack.tokenizers.tokenizer
.. automodule:: textattack.models.tokenizers.auto_tokenizer
   :members:

.. automodule:: textattack.tokenizers.auto_tokenizer
.. automodule:: textattack.models.tokenizers.glove_tokenizer
   :members:

.. automodule:: textattack.tokenizers.spacy_tokenizer
.. automodule:: textattack.models.tokenizers.t5_tokenizer
   :members:

.. automodule:: textattack.tokenizers.t5_tokenizer
.. automodule:: textattack.models.tokenizers.bert_tokenizer
   :members:

.. automodule:: textattack.tokenizers.bert_tokenizer
   :members:

.. automodule:: textattack.tokenizers.bert_entailment_tokenizer
   :members:
docs/examples/0_End_to_End.ipynb (new file, 9987 lines)
File diff suppressed because it is too large.
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# The TextAttack🐙 ecosystem: search, transformations, and constraints\n",
"# The TextAttack ecosystem: search, transformations, and constraints\n",
"\n",
"An attack in TextAttack consists of four parts.\n",
"\n",
@@ -31,9 +31,9 @@
"This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
"\n",
"\n",
"### Banana word swap 🍌\n",
"### Banana word swap \n",
"\n",
"As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'."
"As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'. 🍌"
]
},
{
@@ -308,9 +308,9 @@
"collapsed": true
},
"source": [
"### Conclusion 🍌\n",
"### Conclusion \n",
"\n",
"We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a single \"banana\", while others needed up to 17 \"banana\" substitutions to change the class score. Wow!"
"We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a couple \"banana\"s, while others needed up to 17 \"banana\" substitutions to change the class score. Wow! 🍌"
]
}
],

@@ -16,7 +16,7 @@ TextAttack provides a framework for constructing and thinking about attacks via
TextAttack provides a set of `Attack Recipes <attacks/attack_recipes.html>`__ that assemble attacks from the literature from these four components.

Data Augmentation
-------------
--------------------
Data augmentation is easy and extremely common in computer vision but harder and less common in NLP. We provide a `Data Augmentation <augmentation/augmenter.html>`__ module using transformations and constraints.

Features
@@ -31,12 +31,14 @@ TextAttack has some other features that make it a pleasure to use:
.. toctree::
   :maxdepth: 1
   :hidden:
   :caption: Quickstart
   :caption: Getting Started

   quickstart/installation
   quickstart/overview
   Example 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>
   Example 2: Constraints <examples/2_Constraints.ipynb>

   Installation <quickstart/installation>
   Command-Line Usage <quickstart/command_line_usage>
   Tutorial 0: TextAttack End-To-End (Train, Eval, Attack) <examples/0_End_to_End.ipynb>
   Tutorial 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>
   Tutorial 2: Constraints <examples/2_Constraints.ipynb>

.. toctree::
   :maxdepth: 3
@@ -73,7 +75,7 @@ TextAttack has some other features that make it a pleasure to use:
   :hidden:
   :caption: Miscellaneous

   misc/attacked_text
   misc/checkpoints
   misc/loggers
   misc/validators
   misc/tokenized_text

docs/misc/attacked_text.rst (new file, 6 lines)
@@ -0,0 +1,6 @@
===================
Attacked Text
===================

.. automodule:: textattack.shared.attacked_text
   :members:
@@ -1,6 +0,0 @@
===================
Tokenized Text
===================

.. automodule:: textattack.shared.tokenized_text
   :members:
docs/quickstart/command_line_usage.md (new file, 135 lines)
@@ -0,0 +1,135 @@
Command-Line Usage
=======================================

The easiest way to use textattack is from the command-line. Installing textattack
will provide you with the handy `textattack` command which will allow you to do
just about anything TextAttack offers in a single bash command.

> *Tip*: If you are for some reason unable to use the `textattack` command, you
> can access all the same functionality by prepending `python -m` to the command
> (`python -m textattack ...`).

To see all available commands, type `textattack --help`. This page explains
some of the most important functionalities of textattack: NLP data augmentation,
adversarial attacks, and training and evaluating models.

## Data Augmentation with `textattack augment`

The easiest way to use our data augmentation tools is with `textattack augment <args>`. `textattack augment`
takes an input CSV file and text column to augment, along with the number of words to change per augmentation
and the number of augmentations per input example. It outputs a CSV in the same format with all the augmentation
examples corresponding to the proper columns.

For example, given the following as `examples.csv`:

```
"text",label
"the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
"the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1
"take care of my cat offers a refreshingly different slice of asian cinema .", 1
"a technically well-made suspenser . . . but its abrupt drop in iq points as it races to the finish line proves simply too discouraging to let slide .", 0
"it's a mystery how the movie could be released in this condition .", 0
```

The command:
```
textattack augment --csv examples.csv --input-column text --recipe embedding --num-words-to-swap 4 \
    --transformations-per-example 2 --exclude-original
```
will augment the `text` column with four swaps per augmentation, twice as many augmentations as original inputs, and exclude the original inputs from the
output CSV. (All of this will be saved to `augment.csv` by default.)

After augmentation, here are the contents of `augment.csv`:
```
text,label
"the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
"the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1
the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of expression significant adequately describe co-writer/director pedro jackson's expanded vision of j . rs . r . tolkien's middle-earth .,1
the gorgeously elaborate continuation of 'the lordy of the piercings' trilogy is so huge that a column of mots cannot adequately describe co-novelist/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .,1
take care of my cat offerings a pleasantly several slice of asia cinema .,1
taking care of my cat offers a pleasantly different slice of asiatic kino .,1
a technically good-made suspenser . . . but its abrupt drop in iq points as it races to the finish bloodline proves straightforward too disheartening to let slide .,0
a technically well-made suspenser . . . but its abrupt drop in iq dot as it races to the finish line demonstrates simply too disheartening to leave slide .,0
it's a enigma how the film wo be releases in this condition .,0
it's a enigma how the filmmaking wo be publicized in this condition .,0
```

The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data.

## Adversarial Attacks with `textattack attack`

The heart of textattack is running adversarial attacks on NLP models with
`textattack attack`. You can build an attack from the command-line in several ways:
1. Use an **attack recipe** to launch an attack from the literature: `textattack attack --recipe deepwordbug`
2. Build your attack from components:
```
textattack attack --model lstm-mr --num-examples 20 --search-method beam-search:beam_width=4 \
    --transformation word-swap-embedding \
    --constraints repeat stopword max-words-perturbed:max_num_words=2 embedding:min_cos_sim=0.8 part-of-speech \
    --goal-function untargeted-classification
```
3. Create a python file that builds your attack and load it: `textattack attack --attack-from-file my_file.py:my_attack_name`

## Training Models with `textattack train`

With textattack, you can train models on any classification or regression task
from [`nlp`](https://github.com/huggingface/nlp/) using a single line.

### Available Models
#### TextAttack Models
TextAttack has two built-in model types, a 1-layer bidirectional LSTM with a hidden
state size of 150 (`lstm`), and a WordCNN with 3 window sizes
(3, 4, 5) and 100 filters for the window size (`cnn`). Both models set dropout
to 0.3 and use a base of the 200-dimensional GloVe embeddings.

#### `transformers` Models
Along with the `lstm` and `cnn`, you can theoretically fine-tune any model based
in the huggingface [transformers](https://github.com/huggingface/transformers/)
repo. Just type the model name (like `bert-base-cased`) and it will be automatically
loaded.

Here are some models from transformers that have worked well for us:
- `bert-base-uncased` and `bert-base-cased`
- `distilbert-base-uncased` and `distilbert-base-cased`
- `albert-base-v2`
- `roberta-base`
- `xlnet-base-cased`

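For a hedged illustration, fine-tuning one of these `transformers` checkpoints mirrors the `examples/train/train_albert_snli_entailment.sh` script added later in this commit:

```bash
textattack train --model albert-base-v2 --dataset snli --batch-size 128 --epochs 5 \
    --max-length 128 --learning-rate 1e-5 --allowed-labels 0 1 2
```
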
## Evaluating Models with `textattack eval-model`


## Other Commands

### Checkpoints and `textattack attack-resume`

Some attacks can take a very long time. Sometimes this is because they're using
a very slow search method (like beam search with a high beam width) or sometimes
they're just attacking a large number of samples. In these cases, it can be
useful to save attack checkpoints throughout the course of the attack. Then,
if the attack crashes for some reason, you can resume without restarting from
scratch.

- To save checkpoints while running an attack, add the argument `--checkpoint-interval X`,
where X is the number of attacks you want to run between checkpoints (for example `textattack attack <args> --checkpoint-interval 5`).
- To load an attack from a checkpoint, use `textattack attack-resume --checkpoint-file <checkpoint-file>`.

### Listing features with `textattack list`

TextAttack has a lot of built-in features (models, search methods, constraints, etc.)
and it can get overwhelming to keep track of all the options. To list all of the
options within a given category, use `textattack list`.

For example:
- list all the built-in models: `textattack list models`
- list all constraints: `textattack list constraints`
- list all search methods: `textattack list search-methods`

### Examining datasets with `textattack peek-dataset`
It can be useful to take a cursory look at and compute some basic statistics of
whatever dataset you're working with. Whether you're loading a dataset of your
own from a file, or one from NLP, you can use `textattack peek-dataset` to
see some basic information about the dataset.

For example, use `textattack peek-dataset --dataset-from-nlp glue:mrpc` to see
information about the MRPC dataset (from the GLUE set of datasets). This will
print statistics like the number of labels, average number of words, etc.
@@ -8,7 +8,7 @@ To use TextAttack, you must be running Python 3.6+. A CUDA-compatible GPU is opt

You're now all set to use TextAttack! Try running an attack from the command line::

   textattack attack --recipe textfooler --model bert-mr --num-examples 10
   textattack attack --recipe textfooler --model bert-base-uncased-mr --num-examples 10

This will run an attack using the TextFooler_ recipe, attacking BERT fine-tuned on the MR dataset. It will attack the first 10 samples. Once everything downloads and starts running, you should see attack results print to ``stdout``.

@@ -1,52 +0,0 @@
===========
Overview
===========
TextAttack builds attacks from four components:

- `Goal Functions <../attacks/goal_function.html>`__ stipulate the goal of the attack, like to change the prediction score of a classification model, or to change all of the words in a translation output.
- `Constraints <../attacks/constraint.html>`__ determine if a potential perturbation is valid with respect to the original input.
- `Transformations <../attacks/transformation.html>`__ take a text input and transform it by inserting and deleting characters, words, and/or phrases.
- `Search Methods <../attacks/search_method.html>`__ explore the space of possible **transformations** within the defined **constraints** and attempt to find a successful perturbation which satisfies the **goal function**.

Any model that overrides ``__call__``, takes ``TokenizedText`` as input, and formats output correctly can be used with TextAttack. TextAttack also has built-in datasets and pre-trained models on these datasets. Below is an example of attacking a pre-trained model on the AGNews dataset::

    from tqdm import tqdm
    from textattack.loggers import FileLogger

    from textattack.datasets.classification import AGNews
    from textattack.models.classification.lstm import LSTMForAGNewsClassification
    from textattack.goal_functions import UntargetedClassification

    from textattack.shared import Attack
    from textattack.search_methods import GreedySearch
    from textattack.transformations import WordSwapEmbedding
    from textattack.constraints.grammaticality import PartOfSpeech
    from textattack.constraints.semantics import RepeatModification, StopwordModification

    # Create the model and goal function
    model = LSTMForAGNewsClassification()
    goal_function = UntargetedClassification(model)

    # Use the default WordSwapEmbedding transformation
    transformation = WordSwapEmbedding()

    # Add a constraint, note that an empty list can be used if no constraints are wanted
    constraints = [
        RepeatModification(),
        StopwordModification(),
        PartOfSpeech()
    ]

    # Choose a search method
    search = GreedySearch()

    # Make an attack with the above parameters
    attack = Attack(goal_function, constraints, transformation, search)

    # Run the attack on 5 examples and see the results using a logger to output to stdout
    results = attack.attack_dataset(AGNews(), num_examples=5, attack_n=True)

    logger = FileLogger(stdout=True)

    for result in tqdm(results, total=5):
        logger.log_attack_result(result)
@@ -1,2 +1,4 @@
recommonmark
nbsphinx
sphinx-autobuild
sphinx-rtd-theme

7
examples/attack/attack_from_components.sh
Executable file
7
examples/attack/attack_from_components.sh
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
# Shows how to build an attack from components and use it on a pre-trained
|
||||
# model on the Yelp dataset.
|
||||
textattack attack --attack-n --goal-function untargeted-classification \
|
||||
--model bert-base-uncased-yelp --num-examples 8 --transformation word-swap-wordnet \
|
||||
--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
|
||||
--search greedy
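For reference, the CLI flags above map onto TextAttack's Python components. The sketch below assembles roughly the same attack programmatically; the transformation and constraint class names (``WordSwapWordNet``, ``LevenshteinEditDistance``, ``MaxWordsPerturbed``), their module paths, and the constructor arguments shown are assumptions about what these CLI options correspond to, and the built-in LSTM model stands in for ``bert-base-uncased-yelp``, whose Python loading is release-dependent::

    from textattack.shared import Attack
    from textattack.goal_functions import UntargetedClassification
    from textattack.search_methods import GreedySearch
    from textattack.transformations import WordSwapWordNet
    from textattack.constraints.overlap import LevenshteinEditDistance, MaxWordsPerturbed
    from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
    from textattack.models.classification.lstm import LSTMForAGNewsClassification

    model = LSTMForAGNewsClassification()            # stand-in for bert-base-uncased-yelp

    goal_function = UntargetedClassification(model)  # --goal-function untargeted-classification
    transformation = WordSwapWordNet()               # --transformation word-swap-wordnet
    constraints = [
        LevenshteinEditDistance(12),                 # edit-distance:12 (assumed constructor)
        MaxWordsPerturbed(max_percent=0.75),         # max-words-perturbed:max_percent=0.75
        RepeatModification(),                        # repeat
        StopwordModification(),                      # stopword
    ]
    search_method = GreedySearch()                   # --search greedy

    attack = Attack(goal_function, constraints, transformation, search_method)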
|
||||
4
examples/attack/attack_huggingface_deepwordbug.sh
Executable file
4
examples/attack/attack_huggingface_deepwordbug.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
# Shows how to attack a DistilBERT model fine-tuned on the SST2 dataset *from the
|
||||
# huggingface model repository* using the DeepWordBug recipe and 10 examples.
|
||||
textattack attack --model-from-huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset-from-nlp glue:sst2 --recipe deepwordbug --num-examples 10
|
||||
4
examples/attack/attack_roberta_sst2_textfooler.sh
Executable file
4
examples/attack/attack_roberta_sst2_textfooler.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
# Shows how to attack our RoBERTa model fine-tuned on SST2 using the TextFooler
|
||||
# recipe and 10 examples.
|
||||
textattack attack --model roberta-base-sst2 --recipe textfooler --num-examples 10
|
||||
2
examples/augmentation/example.csv
Normal file
2
examples/augmentation/example.csv
Normal file
@@ -0,0 +1,2 @@
|
||||
"text",label
|
||||
"it's a mystery how the movie could be released in this condition .", 0
|
||||
|
5
examples/train/train_albert_snli_entailment.sh
Executable file
5
examples/train/train_albert_snli_entailment.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Trains `albert-base-v2` on the SNLI task for 5 epochs. This is a
|
||||
# demonstration of how our training script can handle different `transformers`
|
||||
# models and customize for different datasets.
|
||||
textattack train --model albert-base-v2 --dataset snli --batch-size 128 --epochs 5 --max-length 128 --learning-rate 1e-5 --allowed-labels 0 1 2
|
||||
4
examples/train/train_bert_stsb_similarity.sh
Executable file
4
examples/train/train_bert_stsb_similarity.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
# Trains `bert-base-cased` on the STS-B task for 3 epochs. This is a demonstration
|
||||
# of how our training script handles regression.
|
||||
textattack train --model bert-base-cased --dataset glue:stsb --batch-size 128 --epochs 3 --max-length 128 --learning-rate 1e-5
|
||||
4
examples/train/train_lstm_rotten_tomatoes_sentiment_classification.sh
Executable file
4
examples/train/train_lstm_rotten_tomatoes_sentiment_classification.sh
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
# Trains the built-in LSTM model on the Rotten Tomatoes sentiment task for 50 epochs. This is a basic
|
||||
# demonstration of our training script and `nlp` integration.
|
||||
textattack train --model lstm --dataset rotten_tomatoes --batch-size 64 --epochs 50 --learning-rate 1e-5
|
||||
@@ -1,24 +1,23 @@
|
||||
click
|
||||
bert-score
|
||||
editdistance
|
||||
flair==0.5.1
|
||||
filelock
|
||||
language_tool_python
|
||||
lru-dict
|
||||
nlp
|
||||
nltk
|
||||
numpy
|
||||
pandas
|
||||
pyyaml>=5.1
|
||||
pandas>=1.0.1
|
||||
scikit-learn
|
||||
scipy==1.4.1
|
||||
sentence_transformers
|
||||
spacy
|
||||
sentence_transformers==0.2.6.1
|
||||
torch
|
||||
transformers>=2.5.1
|
||||
transformers>=3
|
||||
tensorflow>=2
|
||||
tensorflow_hub
|
||||
tensorboardX
|
||||
terminaltables
|
||||
tokenizers==0.8.0-rc4
|
||||
tqdm
|
||||
visdom
|
||||
wandb
|
||||
flair
|
||||
bert-score
|
||||
|
||||
14
setup.cfg
14
setup.cfg
@@ -1,9 +1,3 @@
|
||||
[flake8]
|
||||
ignore = E203, E266, E501, W503
|
||||
max-line-length = 120
|
||||
per-file-ignores = __init__.py:F401
|
||||
mypy_config = mypy.ini
|
||||
|
||||
[isort]
|
||||
line_length = 88
|
||||
skip = __init__.py
|
||||
@@ -14,3 +8,11 @@ multi_line_output = 3
|
||||
include_trailing_comma = True
|
||||
use_parentheses = True
|
||||
force_grid_wrap = 0
|
||||
|
||||
[flake8]
|
||||
exclude = .git,__pycache__,wandb,build,dist
|
||||
ignore = E203, E266, E501, W503, D203
|
||||
max-complexity = 10
|
||||
max-line-length = 120
|
||||
mypy_config = mypy.ini
|
||||
per-file-ignores = __init__.py:F401
|
||||
|
||||
11
setup.py
11
setup.py
@@ -6,6 +6,14 @@ from docs import conf as docs_conf
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
extras = {}
|
||||
# Packages required for installing docs.
|
||||
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
|
||||
# Packages required for formatting code & running tests.
|
||||
extras["test"] = ["black", "isort==5.0.3", "flake8", "pytest", "pytest-xdist"]
|
||||
# For developers, install development tools along with all optional dependencies.
|
||||
extras["dev"] = extras["docs"] + extras["test"]
|
||||
|
||||
setuptools.setup(
|
||||
name="textattack",
|
||||
version=docs_conf.release,
|
||||
@@ -22,12 +30,13 @@ setuptools.setup(
|
||||
"build*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"examples*",
|
||||
"outputs*",
|
||||
"tests*",
|
||||
"local_test*",
|
||||
"wandb*",
|
||||
]
|
||||
),
|
||||
extras_require=extras,
|
||||
entry_points={
|
||||
"console_scripts": ["textattack=textattack.commands.textattack_cli:main"],
|
||||
},
|
||||
|
||||
@@ -28,6 +28,10 @@
|
||||
)
|
||||
(3): RepeatModification
|
||||
(4): StopwordModification
|
||||
(5): InputColumnModification(
|
||||
(matching_column_labels): ['premise', 'hypothesis']
|
||||
(columns_to_ignore): {'premise'}
|
||||
)
|
||||
(is_black_box): True
|
||||
)
|
||||
/.*/
|
||||
|
||||
57
tests/sample_outputs/kuleshov_cnn_sst_2.txt
Normal file
57
tests/sample_outputs/kuleshov_cnn_sst_2.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
/.*/Attack(
|
||||
(search_method): GreedySearch
|
||||
(goal_function): UntargetedClassification
|
||||
(transformation): WordSwapEmbedding(
|
||||
(max_candidates): 15
|
||||
(embedding_type): paragramcf
|
||||
)
|
||||
(constraints):
|
||||
(0): MaxWordsPerturbed(
|
||||
(max_percent): 0.5
|
||||
)
|
||||
(1): ThoughtVector(
|
||||
(embedding_type): paragramcf
|
||||
(metric): max_euclidean
|
||||
(threshold): -0.2
|
||||
(compare_with_original): False
|
||||
(window_size): inf
|
||||
(skip_text_shorter_than_window): False
|
||||
)
|
||||
(2): GPT2(
|
||||
(max_log_prob_diff): 2.0
|
||||
)
|
||||
(3): RepeatModification
|
||||
(4): StopwordModification
|
||||
(is_black_box): True
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (69%)[0m
|
||||
|
||||
it 's a [92mcharming[0m and [92moften[0m affecting journey .
|
||||
|
||||
it 's a [91mloveable[0m and [91mordinarily[0m affecting journey .
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[91mNegative (83%)[0m --> [92mPositive (90%)[0m
|
||||
|
||||
unflinchingly bleak and [91mdesperate[0m
|
||||
|
||||
unflinchingly bleak and [92mdesperation[0m
|
||||
|
||||
|
||||
|
||||
+-------------------------------+--------+
|
||||
| Attack Results | |
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 0 |
|
||||
| Number of skipped attacks: | 0 |
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 25.0% |
|
||||
| Average num. words per input: | 6.0 |
|
||||
| Avg num queries: | 48.5 |
|
||||
+-------------------------------+--------+
|
||||
@@ -14,19 +14,19 @@
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (73%)[0m
|
||||
[92mPositive (100%)[0m --> [91mNegative (88%)[0m
|
||||
|
||||
[92mlovingly[0m photographed in the manner of a golden book sprung to life , stuart little 2 manages [92msweetness[0m largely without stickiness .
|
||||
[92mlovingly[0m photographed in the [92mmanner[0m of a golden book sprung to [92mlife[0m , stuart little 2 manages sweetness largely without stickiness .
|
||||
|
||||
[91mcovingly[0m photographed in the manner of a golden book sprung to life , stuart little 2 manages [91mseetness[0m largely without stickiness .
|
||||
[91mlocingly[0m photographed in the [91mmanenr[0m of a golden book sprung to [91mlief[0m , stuart little 2 manages sweetness largely without stickiness .
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (62%)[0m
|
||||
[92mPositive (100%)[0m --> [91mNegative (61%)[0m
|
||||
|
||||
consistently [92mclever[0m and [92msuspenseful[0m .
|
||||
[92mconsistently[0m [92mclever[0m and [92msuspenseful[0m .
|
||||
|
||||
consistently [91mclevger[0m and [91msurpenseful[0m .
|
||||
[91mcnosistently[0m [91mMclever[0m and [91msuspensWful[0m .
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ consistently [91mclevger[0m and [91msurpenseful[0m .
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 30.26% |
|
||||
| Average num. words per input: | 11.5 |
|
||||
| Avg num queries: | 20.5 |
|
||||
| Average perturbed word %: | 45.0% |
|
||||
| Average num. words per input: | 12.0 |
|
||||
| Avg num queries: | 27.0 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
66
tests/sample_outputs/run_attack_faster_alzantot_recipe.txt
Normal file
66
tests/sample_outputs/run_attack_faster_alzantot_recipe.txt
Normal file
@@ -0,0 +1,66 @@
|
||||
/.*/Attack(
|
||||
(search_method): GeneticAlgorithm(
|
||||
(pop_size): 60
|
||||
(max_iters): 20
|
||||
(temp): 0.3
|
||||
(give_up_if_no_improvement): False
|
||||
)
|
||||
(goal_function): UntargetedClassification
|
||||
(transformation): WordSwapEmbedding(
|
||||
(max_candidates): 8
|
||||
(embedding_type): paragramcf
|
||||
)
|
||||
(constraints):
|
||||
(0): MaxWordsPerturbed(
|
||||
(max_percent): 0.2
|
||||
)
|
||||
(1): WordEmbeddingDistance(
|
||||
(embedding_type): paragramcf
|
||||
(max_mse_dist): 0.5
|
||||
(cased): False
|
||||
(include_unknown_words): True
|
||||
)
|
||||
(2): LearningToWriteLanguageModel(
|
||||
(max_log_prob_diff): 5.0
|
||||
)
|
||||
(3): RepeatModification
|
||||
(4): StopwordModification
|
||||
(is_black_box): True
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (73%)[0m
|
||||
|
||||
this kind of hands-on [92mstorytelling[0m is ultimately what makes shanghai ghetto move beyond a good , [92mdry[0m , [92mreliable[0m textbook and what allows it to rank with its worthy predecessors .
|
||||
|
||||
this kind of hands-on [91mtale[0m is ultimately what makes shanghai ghetto move beyond a good , [91msecs[0m , [91mcredible[0m textbook and what allows it to rank with its worthy predecessors .
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[92mPositive (80%)[0m --> [91mNegative (97%)[0m
|
||||
|
||||
making such a tragedy the backdrop to a love story risks trivializing it , [92mthough[0m chouraqui no doubt intended the film to affirm love's power to help people endure almost [92munimaginable[0m horror .
|
||||
|
||||
making such a tragedy the backdrop to a love story risks trivializing it , [91mnotwithstanding[0m chouraqui no doubt intended the film to affirm love's power to help people endure almost [91mincomprehensible[0m horror .
|
||||
|
||||
|
||||
--------------------------------------------- Result 3 ---------------------------------------------
|
||||
[92mPositive (92%)[0m --> [91m[FAILED][0m
|
||||
|
||||
grown-up quibbles are beside the point here . the little girls understand , and mccracken knows that's all that matters .
|
||||
|
||||
|
||||
|
||||
+-------------------------------+--------+
|
||||
| Attack Results | |
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 1 |
|
||||
| Number of skipped attacks: | 0 |
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 33.33% |
|
||||
| Attack success rate: | 66.67% |
|
||||
| Average perturbed word %: | 8.58% |
|
||||
| Average num. words per input: | 25.67 |
|
||||
| Avg num queries: |/.*/|
|
||||
+-------------------------------+--------+
|
||||
@@ -1,5 +1,4 @@
|
||||
/.*/
|
||||
Attack(
|
||||
/.*/Attack(
|
||||
(search_method): GreedyWordSwapWIR(
|
||||
(wir_method): unk
|
||||
)
|
||||
@@ -25,44 +24,48 @@ Attack(
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91m[FAILED][0m
|
||||
[92mPositive (100%)[0m --> [91mNegative (98%)[0m
|
||||
|
||||
this is a film well worth seeing , talking and singing heads and all .
|
||||
exposing the ways we fool ourselves is one [92mhour[0m photo's real [92mstrength[0m .
|
||||
|
||||
exposing the ways we fool ourselves is one [91mstopwatch[0m photo's real [91mkraft[0m .
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (57%)[0m
|
||||
[92mPositive (96%)[0m --> [91mNegative (99%)[0m
|
||||
|
||||
what really [92msurprises[0m about wisegirls is its low-[92mkey[0m quality and [92mgenuine[0m [92mtenderness[0m .
|
||||
it's up to you to decide whether to admire these people's dedication to their cause or be [92mrepelled[0m by their dogmatism , manipulativeness and narrow , [92mfearful[0m view of american life .
|
||||
|
||||
what really [91mdumbfounded[0m about wisegirls is its low-[91mvital[0m quality and [91mveritable[0m [91msensibility[0m .
|
||||
it's up to you to decide whether to admire these people's dedication to their cause or be [91mrescheduled[0m by their dogmatism , manipulativeness and narrow , [91mshitless[0m view of american life .
|
||||
|
||||
|
||||
--------------------------------------------- Result 3 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91mNegative (84%)[0m
|
||||
[92mPositive (100%)[0m --> [91mNegative (96%)[0m
|
||||
|
||||
( wendigo is ) why we go to the [92mcinema[0m : to be [92mfed[0m through the [92meye[0m , the [92mheart[0m , the [92mmind[0m .
|
||||
mostly , [goldbacher] just lets her complicated characters be [92munruly[0m , confusing and , through it all , [92mhuman[0m .
|
||||
|
||||
( wendigo is ) why we go to the [91mmovie[0m : to be [91mstoked[0m through the [91meyelids[0m , the [91mcoeur[0m , the [91mbother[0m .
|
||||
mostly , [goldbacher] just lets her complicated characters be [91mhaphazard[0m , confusing and , through it all , [91mhumanistic[0m .
|
||||
|
||||
|
||||
--------------------------------------------- Result 4 ---------------------------------------------
|
||||
[92mPositive (99%)[0m --> [91m[FAILED][0m
|
||||
[92mPositive (99%)[0m --> [91mNegative (90%)[0m
|
||||
|
||||
one of the greatest family-oriented , fantasy-adventure movies ever .
|
||||
. . . [92mquite[0m good at [92mproviding[0m some good old fashioned [92mspooks[0m .
|
||||
|
||||
. . . [91mrather[0m good at [91mprovision[0m some good old fashioned [91mbugging[0m .
|
||||
|
||||
|
||||
|
||||
+-------------------------------+--------+
|
||||
| Attack Results | |
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 2 |
|
||||
| Number of successful attacks: | 4 |
|
||||
| Number of failed attacks: | 0 |
|
||||
| Number of skipped attacks: | 0 |
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 50.0% |
|
||||
| Attack success rate: | 50.0% |
|
||||
| Average perturbed word %: | 29.27% |
|
||||
| Average num. words per input: | 13.5 |
|
||||
| Avg num queries: | 63.25 |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 17.56% |
|
||||
| Average num. words per input: | 16.25 |
|
||||
| Avg num queries: | 45.5 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
@@ -9,31 +9,19 @@
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (63%)[0m --> [37m[SKIPPED][0m
|
||||
[92mPositive (75%)[0m --> [91mNegative (71%)[0m
|
||||
|
||||
I was surprised how much I enjoyed this. Sure it is a bit slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
|
||||
I was surprised how much I enjoyed this. Sure it is a [92mbit[0m slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
|
||||
|
||||
I was surprised how much I enjoyed this. Sure it is a [91mbct[0m slow moving in parts, but what else would one expect from Rollin? Also there is plenty of nudity, nothing wrong with that, particularly as it includes lots of the gorgeous, Brigitte Lahaie. There are also some spectacularly eroticised female dead, bit more dodgey, perhaps, but most effective. There is also a sci-fi like storyline with a brief explanation at the end, but I wouldn't bother too much with that. No, here we have a most interesting exploration of memory and the effect of memory loss and to just what extent one is still 'alive' without memory. My DVD sleeve mentions David Cronenberg and whilst this is perhaps not quite as good as his best films, there is some similarity here, particularly with the great use of seemingly menacing architecture and the effective and creepy use of inside space. As I have tried to indicate this is by no means a rip roaring thriller, it is a captivating, nightmare like movie that makes the very most of its locations, including a stunning railway setting at the end.
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[92mPositive (87%)[0m --> [91mNegative (54%)[0m
|
||||
[92mPositive (69%)[0m --> [91mNegative (53%)[0m
|
||||
|
||||
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks [92mand[0m [92mgrim[0m memories, they eventually get to [92mthe[0m bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
|
||||
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks and [92mgrim[0m memories, they eventually get to the bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
|
||||
|
||||
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks [91macd[0m [91mgAim[0m memories, they eventually get to [91mthr[0m bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
|
||||
|
||||
|
||||
--------------------------------------------- Result 3 ---------------------------------------------
|
||||
[92mPositive (83%)[0m --> [37m[SKIPPED][0m
|
||||
|
||||
I have certainly not seen all of Jean Rollin's films, but they mostly seem to be bloody vampire naked women fests, which if you like that sort of thing is not bad, but this is a major departure and could almost be Cronenberg minus the bio-mechanical nightmarish stuff. Except it's in French with subtitles of course. A man driving on the road at night comes across a woman that is in her slippers and bathrobe and picks her up, while in the background yet another woman lingers, wearing nothing. As they drive along it's obvious that there is something not right about the woman, in that she forgets things almost as quickly as they happen. Still though, that doesn't prevent the man from having sex with her once they return to Paris & his apartment. The man leaves for work and some strangers show up at his place and take the woman away to this 'tower block', a huge apartment building referred to as the Black Tower, where others of her kind (for whom the 'no memory' things seems to be the least of their problems) are being held for some reason. Time and events march by in the movie, which involve mostly trying to find what's going on and get out of the building for this woman, and she does manage to call Robert, the guy that picked her up in the first place, to come rescue her. The revelation as to what's going on comes in the last few moments of the movie, which has a rather strange yet touching end to it. In avoiding what seemed to be his "typical" formula, Rollin created, in this, what I feel is his most fascinating and disturbing film. I like this one a lot, check it out. 8 out of 10.
|
||||
|
||||
|
||||
--------------------------------------------- Result 4 ---------------------------------------------
|
||||
[92mPositive (98%)[0m --> [91mNegative (51%)[0m
|
||||
|
||||
Since this cartoon was made in the old days, Felix talks using cartoon bubbles and the animation style is very crude when compared to today. However, compared to its contemporaries, it's a pretty [92mgood[0m cartoon and still holds up well. That's because despite its age, the cartoon is very creative and funny.<br /><br />Felix meets a guy whose shoe business is folding because he can't sell any shoes. Well, Felix needs money so he can go to Hollywood, so he tells the guy at the shop he'll get every shoe sold. Felix spreads chewing gum [92mall[0m over town and soon people are stuck and leave their shoes--rushing to buy new ones from the shoe store. In gratitude, the guy gives Felix $500! However, Felix's owner wants to take the money and go alone, so Felix figures out a way to sneak along.<br /><br />Once there, Felix barges into a studio and makes a bit of a nuisance of himself. Along the way, he meets cartoon versions of comics Ben Turpin and Charlie Chaplin. In the end, though, through luck, Felix is discovered and offered a movie contract. Hurray!
|
||||
|
||||
Since this cartoon was made in the old days, Felix talks using cartoon bubbles and the animation style is very crude when compared to today. However, compared to its contemporaries, it's a pretty [91mgogd[0m cartoon and still holds up well. That's because despite its age, the cartoon is very creative and funny.<br /><br />Felix meets a guy whose shoe business is folding because he can't sell any shoes. Well, Felix needs money so he can go to Hollywood, so he tells the guy at the shop he'll get every shoe sold. Felix spreads chewing gum [91mlll[0m over town and soon people are stuck and leave their shoes--rushing to buy new ones from the shoe store. In gratitude, the guy gives Felix $500! However, Felix's owner wants to take the money and go alone, so Felix figures out a way to sneak along.<br /><br />Once there, Felix barges into a studio and makes a bit of a nuisance of himself. Along the way, he meets cartoon versions of comics Ben Turpin and Charlie Chaplin. In the end, though, through luck, Felix is discovered and offered a movie contract. Hurray!
|
||||
I went into "Night of the Hunted" not knowing what to expect at all. I was really impressed.<br /><br />It is essentially a mystery/thriller where this girl who can't remember anything gets 'rescued' by a guy who happens to be driving past. The two become fast friends and lovers and together, they try to figure out what is going on with her. Through some vague flashbacks and [91mgrEm[0m memories, they eventually get to the bottom of it and the ending is pretty cool.<br /><br />I really liked the setting of this one: a desolate, post-modern Paris is the backdrop with lots of gray skies and tall buildings. Very metropolitan. Groovy soundtrack and lots of nudity.<br /><br />Surprising it was made in 1980; seems somewhat ahead of it's time.<br /><br />8 out of 10, kids.
|
||||
|
||||
|
||||
|
||||
@@ -42,11 +30,11 @@ Since this cartoon was made in the old days, Felix talks using cartoon bubbles a
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 0 |
|
||||
| Number of skipped attacks: | 2 |
|
||||
| Original accuracy: | 50.0% |
|
||||
| Number of skipped attacks: | 0 |
|
||||
| Original accuracy: | 100.0% |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 1.59% |
|
||||
| Average num. words per input: | 207.5 |
|
||||
| Avg num queries: | 172.0 |
|
||||
| Average perturbed word %: | 0.62% |
|
||||
| Average num. words per input: | 164.0 |
|
||||
| Avg num queries: | 166.0 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[92mPositive (100%)[0m --> [91m[FAILED][0m
|
||||
[92mPositive (97%)[0m --> [91m[FAILED][0m
|
||||
|
||||
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill .
|
||||
|
||||
@@ -63,6 +63,6 @@ throws in enough clever and unexpected twists to make the formula feel fresh .
|
||||
| Accuracy under attack: | 75.0% |
|
||||
| Attack success rate: | 25.0% |
|
||||
| Average perturbed word %: | 3.85% |
|
||||
| Average num. words per input: | 15.75 |
|
||||
| Average num. words per input: | 15.5 |
|
||||
| Avg num queries: | 1.25 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
@@ -16,40 +16,43 @@
|
||||
)
|
||||
/.*/
|
||||
--------------------------------------------- Result 1 ---------------------------------------------
|
||||
[91mContradiction (99%)[0m --> [37m[SKIPPED][0m
|
||||
[92mEntailment (99%)[0m --> [37m[SKIPPED][0m
|
||||
|
||||
[1m[4mPremise[0m[0m: The new rights are nice enough
|
||||
[1m[4mHypothesis[0m[0m: Everyone really likes the newest benefits
|
||||
|
||||
|
||||
--------------------------------------------- Result 2 ---------------------------------------------
|
||||
[92mEntailment (100%)[0m --> [91m[FAILED][0m
|
||||
[37mNeutral (100%)[0m --> [92mEntailment (56%)[0m
|
||||
|
||||
[1m[4mPremise[0m[0m: This site includes a list of all award winners and a searchable database of Government Executive articles.
|
||||
[1m[4mHypothesis[0m[0m: The Government Executive articles housed on the website are not able to be searched.
|
||||
[1m[4mHypothesis[0m[0m: The Government Executive articles housed on the website are not [37mable[0m to be searched.
|
||||
|
||||
[1m[4mPremise[0m[0m: This site includes a list of all award winners and a searchable database of Government Executive articles.
|
||||
[1m[4mHypothesis[0m[0m: The Government Executive articles housed on the website are not [92mable-bodied[0m to be searched.
|
||||
|
||||
|
||||
--------------------------------------------- Result 3 ---------------------------------------------
|
||||
[35mNeutral (99%)[0m --> [91mContradiction (100%)[0m
|
||||
[91mContradiction (99%)[0m --> [92mEntailment (100%)[0m
|
||||
|
||||
[1m[4mPremise[0m[0m: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
|
||||
[1m[4mHypothesis[0m[0m: I like him for the most [92mpart[0m, but would still enjoy seeing someone beat him.
|
||||
[1m[4mHypothesis[0m[0m: I like him for the most [91mpart[0m, but would still enjoy seeing someone beat him.
|
||||
|
||||
[1m[4mPremise[0m[0m: uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him
|
||||
[1m[4mHypothesis[0m[0m: I like him for the most [94moffice[0m, but would still enjoy seeing someone beat him.
|
||||
[1m[4mHypothesis[0m[0m: I like him for the most [92moffice[0m, but would still enjoy seeing someone beat him.
|
||||
|
||||
|
||||
|
||||
+-------------------------------+--------+
|
||||
| Attack Results | |
|
||||
+-------------------------------+--------+
|
||||
| Number of successful attacks: | 1 |
|
||||
| Number of failed attacks: | 1 |
|
||||
| Number of successful attacks: | 2 |
|
||||
| Number of failed attacks: | 0 |
|
||||
| Number of skipped attacks: | 1 |
|
||||
| Original accuracy: | 66.67% |
|
||||
| Accuracy under attack: | 33.33% |
|
||||
| Attack success rate: | 50.0% |
|
||||
| Average perturbed word %: | 2.27% |
|
||||
| Average num. words per input: | 29.0 |
|
||||
| Avg num queries: | 447.5 |
|
||||
| Accuracy under attack: | 0.0% |
|
||||
| Attack success rate: | 100.0% |
|
||||
| Average perturbed word %: | 2.78% |
|
||||
| Average num. words per input: | 28.67 |
|
||||
| Avg num queries: | 182.0 |
|
||||
+-------------------------------+--------+
|
||||
|
||||
@@ -12,12 +12,27 @@ def attacked_text():
|
||||
return textattack.shared.AttackedText(raw_text)
|
||||
|
||||
|
||||
raw_pokemon_text = "the threat implied in the title pokémon 4ever is terrifying like locusts in a horde these things will keep coming ."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pokemon_attacked_text():
|
||||
return textattack.shared.AttackedText(raw_pokemon_text)
|
||||
|
||||
|
||||
premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes."
|
||||
hypothesis = "The Patan Museum is down the street from the red brick Royal Palace."
|
||||
raw_text_pair = collections.OrderedDict(
|
||||
[("premise", premise), ("hypothesis", hypothesis)]
|
||||
)
|
||||
|
||||
raw_hyphenated_text = "It's a run-of-the-mill kind of farmer's tan."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hyphenated_text():
|
||||
return textattack.shared.AttackedText(raw_hyphenated_text)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attacked_text_pair():
|
||||
@@ -25,27 +40,13 @@ def attacked_text_pair():
|
||||
|
||||
|
||||
class TestAttackedText:
|
||||
def test_words(self, attacked_text):
|
||||
def test_words(self, attacked_text, pokemon_attacked_text):
|
||||
# fmt: off
|
||||
assert attacked_text.words == [
|
||||
"A",
|
||||
"person",
|
||||
"walks",
|
||||
"up",
|
||||
"stairs",
|
||||
"into",
|
||||
"a",
|
||||
"room",
|
||||
"and",
|
||||
"sees",
|
||||
"beer",
|
||||
"poured",
|
||||
"from",
|
||||
"a",
|
||||
"keg",
|
||||
"and",
|
||||
"people",
|
||||
"talking",
|
||||
"A", "person", "walks", "up", "stairs", "into", "a", "room", "and", "sees", "beer", "poured", "from", "a", "keg", "and", "people", "talking",
|
||||
]
|
||||
assert pokemon_attacked_text.words == ['the', 'threat', 'implied', 'in', 'the', 'title', 'pokémon', '4ever', 'is', 'terrifying', 'like', 'locusts', 'in', 'a', 'horde', 'these', 'things', 'will', 'keep', 'coming']
|
||||
# fmt: on
|
||||
|
||||
def test_window_around_index(self, attacked_text):
|
||||
assert attacked_text.text_window_around_index(5, 1) == "into"
|
||||
@@ -53,6 +54,10 @@ class TestAttackedText:
|
||||
assert attacked_text.text_window_around_index(5, 3) == "stairs into a"
|
||||
assert attacked_text.text_window_around_index(5, 4) == "up stairs into a"
|
||||
assert attacked_text.text_window_around_index(5, 5) == "up stairs into a room"
|
||||
assert (
|
||||
attacked_text.text_window_around_index(5, float("inf"))
|
||||
== "A person walks up stairs into a room and sees beer poured from a keg and people talking"
|
||||
)
|
||||
|
||||
def test_big_window_around_index(self, attacked_text):
|
||||
assert (
|
||||
@@ -65,8 +70,9 @@ class TestAttackedText:
|
||||
def test_window_around_index_end(self, attacked_text):
|
||||
assert attacked_text.text_window_around_index(17, 3) == "and people talking"
|
||||
|
||||
def test_text(self, attacked_text, attacked_text_pair):
|
||||
def test_text(self, attacked_text, pokemon_attacked_text, attacked_text_pair):
|
||||
assert attacked_text.text == raw_text
|
||||
assert pokemon_attacked_text.text == raw_pokemon_text
|
||||
assert attacked_text_pair.text == "\n".join(raw_text_pair.values())
|
||||
|
||||
def test_printable_text(self, attacked_text, attacked_text_pair):
|
||||
@@ -136,13 +142,13 @@ class TestAttackedText:
|
||||
+ "\n"
|
||||
+ "The Patan Museum is down the street from the red brick Royal Palace."
|
||||
)
|
||||
new_text = new_text.insert_text_after_word_index(38, "and shapes")
|
||||
new_text = new_text.insert_text_after_word_index(37, "and shapes")
|
||||
assert new_text.text == (
|
||||
"Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
|
||||
+ "\n"
|
||||
+ "The Patan Museum is down the street from the red brick Royal Palace."
|
||||
)
|
||||
new_text = new_text.insert_text_after_word_index(41, "The")
|
||||
new_text = new_text.insert_text_after_word_index(40, "The")
|
||||
assert new_text.text == (
|
||||
"Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
|
||||
+ "\n"
|
||||
@@ -159,7 +165,7 @@ class TestAttackedText:
|
||||
)
|
||||
for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]):
|
||||
assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or (
|
||||
new_i == -1
|
||||
new_idx == -1
|
||||
)
|
||||
new_text = (
|
||||
new_text.delete_word_at_index(0)
|
||||
@@ -176,3 +182,14 @@ class TestAttackedText:
|
||||
new_text.text
|
||||
== "person walks a very long way up stairs into a room and sees beer poured and people on the couch."
|
||||
)
|
||||
|
||||
def test_hyphen_apostrophe_words(self, hyphenated_text):
|
||||
assert hyphenated_text.words == [
|
||||
"It's",
|
||||
"a",
|
||||
"run-of-the-mill",
|
||||
"kind",
|
||||
"of",
|
||||
"farmer's",
|
||||
"tan",
|
||||
]
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from helpers import run_command_and_get_result
|
||||
import pytest
|
||||
|
||||
DEBUG = False
|
||||
|
||||
@@ -112,6 +111,24 @@ attack_test_params = [
|
||||
),
|
||||
# fmt: on
|
||||
#
|
||||
# test: run_attack on LSTM MR using word embedding transformation and genetic algorithm. Simulate alzantot recipe without using expensive LM
|
||||
(
|
||||
"run_attack_faster_alzantot_recipe",
|
||||
(
|
||||
"textattack attack --model lstm-mr --recipe faster-alzantot --num-examples 3 --num-examples-offset 20"
|
||||
),
|
||||
"tests/sample_outputs/run_attack_faster_alzantot_recipe.txt",
|
||||
),
|
||||
#
|
||||
# test: run_attack with kuleshov recipe and sst-2 cnn
|
||||
#
|
||||
(
|
||||
"run_attack_kuleshov_nn",
|
||||
(
|
||||
"textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst --attack-n --query-budget 200"
|
||||
),
|
||||
"tests/sample_outputs/kuleshov_cnn_sst_2.txt",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -144,3 +161,5 @@ def test_command_line_attack(name, command, sample_output_file):
|
||||
if DEBUG and not re.match(desired_re, stdout, flags=re.S):
|
||||
pdb.set_trace()
|
||||
assert re.match(desired_re, stdout, flags=re.S)
|
||||
|
||||
assert result.returncode == 0
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from helpers import run_command_and_get_result
|
||||
import pytest
|
||||
|
||||
augment_test_params = [
|
||||
(
|
||||
@@ -37,3 +36,5 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file):
|
||||
# Ensure CSV file exists, then delete it.
|
||||
assert os.path.exists(outfile)
|
||||
os.remove(outfile)
|
||||
|
||||
assert result.returncode == 0
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from helpers import run_command_and_get_result
|
||||
import pytest
|
||||
|
||||
list_test_params = [
|
||||
(
|
||||
@@ -27,3 +26,5 @@ def test_command_line_list(name, command, sample_output_file):
|
||||
print("stderr =>", stderr)
|
||||
|
||||
assert stdout == desired_text
|
||||
|
||||
assert result.returncode == 0
|
||||
|
||||
@@ -10,18 +10,28 @@ from test_augment import augment_test_params
|
||||
from test_list import list_test_params
|
||||
|
||||
|
||||
def update_test(command, outfile):
|
||||
def update_test(command, outfile, add_magic_str=False):
|
||||
if isinstance(command, str):
|
||||
command = (command,)
|
||||
command = command + (f"tee {outfile}",)
|
||||
print("\n".join(f"> {c}" for c in command))
|
||||
run_command_and_get_result(command)
|
||||
print(">", command)
|
||||
else:
|
||||
print("\n".join(f"> {c}" for c in command))
|
||||
result = run_command_and_get_result(command)
|
||||
stdout = result.stdout.decode().strip()
|
||||
if add_magic_str:
|
||||
# add magic string to beginning
|
||||
magic_str = "/.*/"
|
||||
stdout = magic_str + stdout
|
||||
# add magic string after attack
|
||||
mid_attack_str = "\n--------------------------------------------- Result 1"
|
||||
stdout = stdout.replace(mid_attack_str, magic_str + mid_attack_str)
|
||||
# write to file
|
||||
open(outfile, "w").write(stdout + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
#### `textattack attack` tests ####
|
||||
for _, command, outfile in attack_test_params:
|
||||
update_test(command, outfile)
|
||||
update_test(command, outfile, add_magic_str=True)
|
||||
#### `textattack augment` tests ####
|
||||
for _, command, outfile, __ in augment_test_params:
|
||||
update_test(command, outfile)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
def test_imports():
|
||||
import textattack
|
||||
import torch
|
||||
|
||||
import textattack
|
||||
|
||||
del textattack, torch
|
||||
|
||||
|
||||
|
||||
123
tests/test_tokenizers.py
Normal file
123
tests/test_tokenizers.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import collections
|
||||
|
||||
import pytest
|
||||
|
||||
import textattack
|
||||
|
||||
news_article = """The unemployment rate dropped to 8.2% last
|
||||
month, but the economy only added 120,000 jobs,
|
||||
when 203,000 new jobs had been predicted,
|
||||
according to today's jobs report. Reaction on the
|
||||
Wall Street Journal's MarketBeat Blog was swift:
|
||||
"Woah!!! Bad number." The unemployment rate,
|
||||
however, is better news; it had been expected to
|
||||
hold steady at 8.3%. But the AP notes that the dip
|
||||
is mostly due to more Americans giving up on
|
||||
seeking employment. """
|
||||
|
||||
question = """What phenomenon makes global winds blow northeast
|
||||
to southwest or the reverse in the northern
|
||||
hemisphere and northwest to southeast or the
|
||||
reverse in the southern hemisphere?"""
|
||||
|
||||
|
||||
premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes."
|
||||
hypothesis = "The Patan Museum is down the street from the red brick Royal Palace."
|
||||
raw_text_pair = collections.OrderedDict(
|
||||
[("premise", premise), ("hypothesis", hypothesis)]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bert_tokenizer():
|
||||
return textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def glove_tokenizer():
|
||||
lstm = textattack.models.helpers.LSTMForClassification()
|
||||
return lstm.tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lstm():
|
||||
lstm = textattack.models.helpers.LSTMForClassification()
|
||||
return lstm
|
||||
|
||||
|
||||
# Disable formatting so that the samples don't wrap over thousands of lines
|
||||
# fmt: off
|
||||
class TestTokenizer:
|
||||
def test_bert_encode(self, bert_tokenizer):
|
||||
assert bert_tokenizer.encode(news_article) == {
|
||||
"input_ids": [
|
||||
101, 1996, 12163, 3446, 3333, 2000, 1022, 1012, 1016, 1003, 2197, 3204, 1010, 2021, 1996, 4610, 2069, 2794, 6036, 1010, 2199, 5841, 1010, 2043, 18540, 1010, 2199, 2047, 5841, 2018, 2042, 10173, 1010, 2429, 2000, 2651, 1005, 1055, 5841, 3189, 1012, 4668, 2006, 1996, 2813, 2395, 3485, 1005, 1055, 3006, 19442, 9927, 2001, 9170, 1024, 1000, 24185, 4430, 999, 999, 999, 2919, 2193, 1012, 1000, 1996, 12163, 3446, 1010, 2174, 1010, 2003, 2488, 2739, 1025, 2009, 2018, 2042, 3517, 2000, 2907, 6706, 2012, 1022, 1012, 1017, 1003, 1012, 2021, 1996, 9706, 3964, 2008, 1996, 16510, 2003, 3262, 2349, 2000, 2062, 4841, 3228, 2039, 2006, 6224, 6107, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
|
||||
}
|
||||
assert bert_tokenizer.encode(question) == {
|
||||
"input_ids": [
|
||||
101, 2054, 9575, 3084, 3795, 7266, 6271, 4794, 2000, 4943, 2030, 1996, 7901, 1999, 1996, 2642, 14130, 1998, 4514, 2000, 4643, 2030, 1996, 7901, 1999, 1996, 2670, 14130, 1029, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
|
||||
}
|
||||
|
||||
def test_bert_multiseq(self, bert_tokenizer):
|
||||
assert bert_tokenizer.encode((premise, hypothesis)) == {
|
||||
"input_ids": [
|
||||
101, 2426, 2122, 2024, 1996, 2417, 5318, 2548, 4186, 1010, 2029, 2085, 3506, 1996, 6986, 2319, 2688, 1006, 8222, 1005, 1055, 10418, 1998, 2087, 2715, 2688, 1007, 1010, 1998, 1010, 5307, 1996, 4186, 2408, 1996, 4867, 5318, 8232, 1010, 2809, 8436, 1997, 2367, 6782, 1998, 10826, 1012, 102, 1996, 6986, 2319, 2688, 2003, 2091, 1996, 2395, 2013, 1996, 2417, 5318, 2548, 4186, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "token_type_ids": [
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], "attention_mask": [
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
|
||||
}
|
||||
|
||||
def test_bert_encode_tuple(self, bert_tokenizer):
|
||||
assert bert_tokenizer.encode(news_article) == bert_tokenizer.encode(
|
||||
(news_article,)
|
||||
)
|
||||
assert bert_tokenizer.encode(question) == bert_tokenizer.encode((question,))
|
||||
|
||||
def test_bert_encode_batch(self, bert_tokenizer):
|
||||
assert bert_tokenizer.batch_encode([premise, hypothesis]) == [
|
||||
{
|
||||
'input_ids': [101, 2426, 2122, 2024, 1996, 2417, 5318, 2548, 4186, 1010, 2029, 2085, 3506, 1996, 6986, 2319, 2688, 1006, 8222, 1005, 1055, 10418, 1998, 2087, 2715, 2688, 1007, 1010, 1998, 1010, 5307, 1996, 4186, 2408, 1996, 4867, 5318, 8232, 1010, 2809, 8436, 1997, 2367, 6782, 1998, 10826, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
},
|
||||
{
|
||||
'input_ids': [101, 1996, 6986, 2319, 2688, 2003, 2091, 1996, 2395, 2013, 1996, 2417, 5318, 2548, 4186, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
}
|
||||
]
|
||||
|
||||
def test_glove_encode(self, glove_tokenizer):
|
||||
assert glove_tokenizer.encode(news_article) == [
|
||||
399999, 2479, 570, 1199, 3, 399999, 75, 399999, 33, 399999, 426, 90, 294, 12277, 399999, 60, 100738, 49, 1051, 39, 50, 399999, 199, 3, 399999, 1051, 399999, 2613, 12, 399999, 1014, 490, 399999, 399999, 7640, 14, 399999, 399999, 977, 399999, 399999, 2479, 399999, 399999, 13, 438, 399999, 19, 39, 50, 286, 3, 801, 3798, 21, 399999, 33, 399999, 1581, 2141, 11, 399999, 10247, 13, 1245, 444, 3, 55, 826, 1226, 59, 12, 1308, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,
|
||||
]
|
||||
assert glove_tokenizer.encode(question) == [
|
||||
101, 6387, 906, 700, 3683, 3314, 2591, 3, 2735, 45, 399999, 4962, 5, 399999, 528, 8686, 4, 2233, 3, 1980, 45, 399999, 4962, 5, 399999, 481, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,
|
||||
]
|
||||
|
||||
def test_glove_multiseq(self, glove_tokenizer):
|
||||
# the GloVE tokenizer is expected to throw an error for multi-sequence inputs
|
||||
with pytest.raises(ValueError):
|
||||
glove_tokenizer.encode((premise, hypothesis))
|
||||
|
||||
def test_glove_encode_tuple(self, glove_tokenizer):
|
||||
assert glove_tokenizer.encode(news_article) == glove_tokenizer.encode(
|
||||
(news_article,)
|
||||
)
|
||||
assert glove_tokenizer.encode(question) == glove_tokenizer.encode((question,))
|
||||
|
||||
def test_glove_encode_batch(self, glove_tokenizer):
|
||||
assert glove_tokenizer.batch_encode([premise, hypothesis]) == [
|
||||
[243, 157, 31, 399999, 638, 6061, 1141, 399999, 41, 113, 1630, 399999, 82219, 1132, 399999, 9432, 4, 95, 1193, 399999, 399999, 2094, 399999, 2532, 530, 399999, 3756, 6061, 399999, 501, 9203, 2, 493, 6151, 4, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000,], [399999, 82219, 1132, 13, 134, 399999, 490, 24, 399999, 638, 6061, 1141, 399999, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, 400000, ],]
|
||||
|
||||
def test_glove_hotflip(self, lstm, glove_tokenizer):
|
||||
# these things are needed for hotflip: pad_id, oov_id, convert_id_to_word
|
||||
assert glove_tokenizer.pad_id == 400_000
|
||||
assert glove_tokenizer.oov_id == 399_999
|
||||
assert glove_tokenizer.convert_id_to_word(2_179) == "jack"
|
||||
assert 2_179 == lstm.word2id["jack"]
|
||||
|
||||
# fmt: on
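The expected ids above follow two conventions: unknown words map to the OOV id (399,999) and sequences are padded to a fixed length with the pad id (400,000). A minimal sketch of that behavior, using a hypothetical helper and a plain dict vocabulary rather than the project's actual tokenizer:

def glove_style_encode(text, word2id, max_length=128, oov_id=399_999, pad_id=400_000):
    """Map words to ids, sending unknown words to oov_id and padding with pad_id."""
    ids = [word2id.get(word, oov_id) for word in text.lower().split()][:max_length]
    return ids + [pad_id] * (max_length - len(ids))

# For example, glove_style_encode("jack went home", {"jack": 2_179}) begins with
# [2_179, 399_999, 399_999] and is padded out to length 128 with 400_000.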
|
||||
@@ -1,15 +1,17 @@
name = "textattack"

from . import attack_recipes
from . import attack_results
from . import augmentation
from . import commands
from . import constraints
from . import datasets
from . import goal_functions
from . import goal_function_results
from . import loggers
from . import models
from . import search_methods
from . import shared
from . import transformations
from . import (
    attack_recipes,
    attack_results,
    augmentation,
    commands,
    constraints,
    datasets,
    goal_function_results,
    goal_functions,
    loggers,
    models,
    search_methods,
    shared,
    transformations,
)
56
textattack/attack_recipes/PSO_zang_2020.py
Normal file
@@ -0,0 +1,56 @@
from textattack.constraints.pre_transformation import (
    InputColumnModification,
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import PSOAlgorithm
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding, WordSwapHowNet


def PSOZang2020(model):
    """
    Zang, Y., Yang, C., Qi, F., Liu, Z., Zhang, M., Liu, Q., & Sun, M. (2019).

    Word-level Textual Adversarial Attacking as Combinatorial Optimization.

    https://www.aclweb.org/anthology/2020.acl-main.540.pdf

    Methodology description quoted from the paper:

    "We propose a novel word substitution-based textual attack model, which reforms
    both the aforementioned two steps. In the first step, we adopt a sememe-based word
    substitution strategy, which can generate more candidate adversarial examples with
    better semantic preservation. In the second step, we utilize particle swarm optimization
    (Eberhart and Kennedy, 1995) as the adversarial example searching algorithm."

    And "Following the settings in Alzantot et al. (2018), we set the max iteration time G to 20."
    """
    #
    # Swap words with their synonyms extracted based on the HowNet.
    #
    transformation = WordSwapHowNet()
    #
    # Don't modify the same word twice or stopwords
    #
    constraints = [RepeatModification(), StopwordModification()]
    #
    # During entailment, we should only edit the hypothesis - keep the premise
    # the same.
    #
    input_column_modification = InputColumnModification(
        ["premise", "hypothesis"], {"premise"}
    )
    constraints.append(input_column_modification)
    #
    # Use untargeted classification for demo, can be switched to targeted one
    #
    goal_function = UntargetedClassification(model)
    #
    # Perform word substitution with a Particle Swarm Optimization (PSO) algorithm.
    #
    search_method = PSOAlgorithm(pop_size=60, max_iters=20)

    return Attack(goal_function, constraints, transformation, search_method)
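Every recipe added in this commit composes a goal function, constraints, a transformation, and a search method into an Attack. A minimal usage sketch, assuming `model` is a classification model already wrapped the way TextAttack's goal functions expect (the variable is a placeholder, not defined in this diff):

from textattack.attack_recipes import PSOZang2020

attack = PSOZang2020(model)  # HowNet word swaps + PSO search, wrapped into an Attack
print(attack)                # printing an Attack lists its composed components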
@@ -1,9 +1,14 @@
from .alzantot_2018 import Alzantot2018
from .bae_garg_2019 import BAEGarg2019
from .bert_attack_li_2020 import BERTAttackLi2020
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
from .deepwordbug_gao_2018 import DeepWordBugGao2018
from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
from .input_reduction_feng_2018 import InputReductionFeng2018
from .kuleshov_2017 import Kuleshov2017
from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
from .textbugger_li_2018 import TextBuggerLi2018
from .textfooler_jin_2019 import TextFoolerJin2019
from .pwws_ren_2019 import PWWSRen2019
from .pruthi_2019 import Pruthi2019
from .PSO_zang_2020 import PSOZang2020
109
textattack/attack_recipes/bae_garg_2019.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from textattack.constraints.grammaticality import PartOfSpeech
|
||||
from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import GreedyWordSwapWIR
|
||||
from textattack.shared.attack import Attack
|
||||
from textattack.transformations import WordSwapMaskedLM
|
||||
|
||||
|
||||
def BAEGarg2019(model):
|
||||
"""
|
||||
Siddhant Garg and Goutham Ramakrishnan, 2019.
|
||||
|
||||
BAE: BERT-based Adversarial Examples for Text Classification.
|
||||
|
||||
https://arxiv.org/pdf/2004.01970
|
||||
|
||||
This is "attack mode" 1 from the paper, BAE-R, word replacement.
|
||||
|
||||
We present 4 attack modes for BAE based on the
|
||||
R and I operations, where for each token t in S:
|
||||
• BAE-R: Replace token t (See Algorithm 1)
|
||||
• BAE-I: Insert a token to the left or right of t
|
||||
• BAE-R/I: Either replace token t or insert a
|
||||
token to the left or right of t
|
||||
• BAE-R+I: First replace token t, then insert a
|
||||
token to the left or right of t
|
||||
"""
|
||||
# "In this paper, we present a simple yet novel technique: BAE (BERT-based
|
||||
# Adversarial Examples), which uses a language model (LM) for token
|
||||
# replacement to best fit the overall context. We perturb an input sentence
|
||||
# by either replacing a token or inserting a new token in the sentence, by
|
||||
# means of masking a part of the input and using a LM to fill in the mask."
|
||||
#
|
||||
# We only consider the top K=50 synonyms from the MLM predictions.
|
||||
#
|
||||
# [from email correspondence with the author]
|
||||
# "When choosing the top-K candidates from the BERT masked LM, we filter out
|
||||
# the sub-words and only retain the whole words (by checking if they are
|
||||
# present in the GloVE vocabulary)"
|
||||
#
|
||||
transformation = WordSwapMaskedLM(method="bae", max_candidates=50)
|
||||
#
|
||||
# Don't modify the same word twice or stopwords.
|
||||
#
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
|
||||
# For the R operations we add an additional check for
|
||||
# grammatical correctness of the generated adversarial example by filtering
|
||||
# out predicted tokens that do not form the same part of speech (POS) as the
|
||||
# original token t_i in the sentence.
|
||||
constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
|
||||
|
||||
# "To ensure semantic similarity on introducing perturbations in the input
|
||||
# text, we filter the set of top-K masked tokens (K is a pre-defined
|
||||
# constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE)
|
||||
# (Cer et al., 2018)-based sentence similarity scorer."
|
||||
#
|
||||
# "[We] set a threshold of 0.8 for the cosine similarity between USE-based
|
||||
# embeddings of the adversarial and input text."
|
||||
#
|
||||
# [from email correspondence with the author]
|
||||
# "For a fair comparison of the benefits of using a BERT-MLM in our paper,
|
||||
# we retained the majority of TextFooler's specifications. Thus we:
|
||||
# 1. Use the USE for comparison within a window of size 15 around the word
|
||||
# being replaced/inserted.
|
||||
# 2. Set the similarity score threshold to 0.1 for inputs shorter than the
|
||||
# window size (this translates roughly to almost always accepting the new text).
|
||||
# 3. Perform the USE similarity thresholding of 0.8 with respect to the text
|
||||
# just before the replacement/insertion and not the original text (For
|
||||
# example: at the 3rd R/I operation, we compute the USE score on a window
|
||||
# of size 15 of the text obtained after the first 2 R/I operations and not
|
||||
# the original text).
|
||||
# ...
|
||||
# To address point (3) from above, compare the USE with the original text
|
||||
# at each iteration instead of the current one (While doing this change
|
||||
# for the R-operation is trivial, doing it for the I-operation with the
|
||||
# window based USE comparison might be more involved)."
|
||||
|
||||
use_constraint = UniversalSentenceEncoder(
|
||||
threshold=0.8,
|
||||
metric="cosine",
|
||||
compare_with_original=True,
|
||||
window_size=15,
|
||||
skip_text_shorter_than_window=True,
|
||||
)
|
||||
constraints.append(use_constraint)
|
||||
#
|
||||
# Goal is untargeted classification.
|
||||
#
|
||||
goal_function = UntargetedClassification(model)
|
||||
#
|
||||
# "We estimate the token importance Ii of each token
|
||||
# t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
|
||||
# decrease in probability of predicting the correct label y, similar
|
||||
# to (Jin et al., 2019).
|
||||
#
|
||||
# • "If there are multiple tokens can cause C to misclassify S when they
|
||||
# replace the mask, we choose the token which makes Sadv most similar to
|
||||
# the original S based on the USE score."
|
||||
# • "If no token causes misclassification, we choose the perturbation that
|
||||
# decreases the prediction probability P(C(Sadv)=y) the most."
|
||||
#
|
||||
search_method = GreedyWordSwapWIR(wir_method="delete")
|
||||
|
||||
return Attack(goal_function, constraints, transformation, search_method)
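The USE filtering quoted above reduces to a thresholded cosine-similarity check over sentence embeddings. A rough sketch of that decision rule (not the actual UniversalSentenceEncoder constraint; `embed` stands in for a real sentence encoder such as USE):

import numpy as np

def similar_enough(embed, text_before, text_after, threshold=0.8):
    """Accept a perturbation only if the encoded texts stay close in cosine similarity."""
    u, v = embed(text_before), embed(text_after)
    cosine = float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))
    return cosine >= threshold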
|
||||
76
textattack/attack_recipes/bert_attack_li_2020.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from textattack.constraints.overlap import MaxWordsPerturbed
|
||||
from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import GreedyWordSwapWIR
|
||||
from textattack.shared.attack import Attack
|
||||
from textattack.transformations import WordSwapMaskedLM
|
||||
|
||||
|
||||
def BERTAttackLi2020(model):
|
||||
"""
|
||||
Li, L., Ma, R., Guo, Q., Xue, X., & Qiu, X. (2020).
|
||||
|
||||
BERT-ATTACK: Adversarial Attack Against BERT Using BERT
|
||||
|
||||
https://arxiv.org/abs/2004.09984
|
||||
|
||||
This is "attack mode" 1 from the paper, BAE-R, word replacement.
|
||||
"""
|
||||
# [from correspondence with the author]
|
||||
# Candidate size K is set to 48 for all data-sets.
|
||||
transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
|
||||
#
|
||||
# Don't modify the same word twice or stopwords.
|
||||
#
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
|
||||
# "We only take ε percent of the most important words since we tend to keep
|
||||
# perturbations minimum."
|
||||
#
|
||||
# [from correspondence with the author]
|
||||
# "Word percentage allowed to change is set to 0.4 for most data-sets, this
|
||||
# parameter is trivial since most attacks only need a few changes. This
|
||||
# epsilon is only used to avoid too much queries on those very hard samples."
|
||||
constraints.append(MaxWordsPerturbed(max_percent=0.4))
|
||||
|
||||
# "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
|
||||
# Encoder (Cer et al., 2018) to measure the semantic consistency between the
|
||||
# adversarial sample and the original sequence. To balance between semantic
|
||||
# preservation and attack success rate, we set up a threshold of semantic
|
||||
# similarity score to filter the less similar examples."
|
||||
#
|
||||
# [from correspondence with author]
|
||||
# "Over the full texts, after generating all the adversarial samples, we filter
|
||||
# out low USE score samples. Thus the success rate is lower but the USE score
|
||||
# can be higher. (actually USE score is not a golden metric, so we simply
|
||||
# measure the USE score over the final texts for a comparison with TextFooler).
|
||||
# For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
|
||||
# datasets like MNLI, we set threshold between 0-0.2."
|
||||
#
|
||||
# Since the threshold in the real world can't be determined from the training
|
||||
# data, the TextAttack implementation uses a fixed threshold - determined to
|
||||
# be 0.2 to be most fair.
|
||||
use_constraint = UniversalSentenceEncoder(
|
||||
threshold=0.2, metric="cosine", compare_with_original=True, window_size=None,
|
||||
)
|
||||
constraints.append(use_constraint)
|
||||
#
|
||||
# Goal is untargeted classification.
|
||||
#
|
||||
goal_function = UntargetedClassification(model)
|
||||
#
|
||||
# "We first select the words in the sequence which have a high significance
|
||||
# influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
|
||||
# the input sentence, and oy(S) denote the logit output by the target model
|
||||
# for correct label y, the importance score Iwi is defined as
|
||||
# Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
|
||||
# is the sentence after replacing wi with [MASK]. Then we rank all the words
|
||||
# according to the ranking score Iwi in descending order to create word list
|
||||
# L."
|
||||
search_method = GreedyWordSwapWIR(wir_method="unk")
|
||||
|
||||
return Attack(goal_function, constraints, transformation, search_method)
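The importance score described in the comment (Iwi = oy(S) − oy(S\wi)) can be sketched as a masking loop; `score_fn` stands in for the target model's correct-label score and is not defined in this diff:

def rank_words_by_importance(words, score_fn, mask_token="[MASK]"):
    """Rank word indices by how much masking each word lowers the correct-label score."""
    base_score = score_fn(" ".join(words))
    drops = []
    for i in range(len(words)):
        masked = words[:i] + [mask_token] + words[i + 1:]
        drops.append((base_score - score_fn(" ".join(masked)), i))
    return [i for _, i in sorted(drops, reverse=True)]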
|
||||
124
textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from textattack.constraints.grammaticality.language_models import (
|
||||
LearningToWriteLanguageModel,
|
||||
)
|
||||
from textattack.constraints.overlap import MaxWordsPerturbed
|
||||
from textattack.constraints.pre_transformation import (
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.constraints.semantics import WordEmbeddingDistance
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import GeneticAlgorithm
|
||||
from textattack.shared.attack import Attack
|
||||
from textattack.transformations import WordSwapEmbedding
|
||||
|
||||
|
||||
def FasterGeneticAlgorithmJia2019(model):
|
||||
"""
|
||||
Certified Robustness to Adversarial Word Substitutions.
|
||||
|
||||
Robin Jia, Aditi Raghunathan, Kerem Göksel, Percy Liang (2019).
|
||||
|
||||
https://arxiv.org/pdf/1909.00986.pdf
|
||||
|
||||
Section 5: Experiments
|
||||
|
||||
We base our sets of allowed word substitutions S(x, i) on the
|
||||
substitutions allowed by Alzantot et al. (2018). They demonstrated that
|
||||
their substitutions lead to adversarial examples that are qualitatively
|
||||
similar to the original input and retain the original label, as judged
|
||||
by humans. Alzantot et al. (2018) define the neighbors N(w) of a word w
|
||||
as the n = 8 nearest neighbors of w in a “counter-fitted” word vector
|
||||
space where antonyms are far apart (Mrkšić et al., 2016). The
|
||||
neighbors must also lie within some Euclidean distance threshold. They
|
||||
also use a language model constraint to avoid nonsensical perturbations:
|
||||
they allow substituting xi with x˜i ∈ N(xi) if and only if it does not
|
||||
decrease the log-likelihood of the text under a pre-trained language
|
||||
model by more than some threshold.
|
||||
|
||||
We make three modifications to this approach:
|
||||
|
||||
First, in Alzantot et al. (2018), the adversary
applies substitutions one at a time, and the
neighborhoods and language model scores are computed
relative to the current altered version of the input.
(Note that the model itself classifies using a different
set of pre-trained word vectors; the counter-fitted vectors
are only used to define the set of allowed substitution words.
Equation (4) must be applied before the model
can combine information from multiple words, but it can
be delayed until after processing each word independently.)
This results in a hard-to-define attack surface, as
changing one word can allow or disallow changes
to other words. It also requires recomputing
language model scores at each iteration of the genetic
attack, which is inefficient. Moreover, the same
word can be substituted multiple times, leading
to semantic drift. We define allowed substitutions
relative to the original sentence x, and disallow
repeated substitutions.
|
||||
|
||||
Second, we use a faster language model that allows us to query
|
||||
longer contexts; Alzantot et al. (2018) use a slower language
|
||||
model and could only query it with short contexts.
|
||||
|
||||
Finally, we use the language model constraint only
|
||||
at test time; the model is trained against all perturbations in N(w). This encourages the model to be
|
||||
robust to a larger space of perturbations, instead of
|
||||
specializing for the particular choice of language
|
||||
model. See Appendix A.3 for further details. [This is a model-specific
|
||||
adjustment, so does not affect the attack recipe.]
|
||||
|
||||
Appendix A.3:
|
||||
|
||||
In Alzantot et al. (2018), the adversary applies replacements one at a
|
||||
time, and the neighborhoods and language model scores are computed
|
||||
relative to the current altered version of the input. This results in a
|
||||
hard-to-define attack surface, as the same word can be replaced many
|
||||
times, leading to semantic drift. We instead pre-compute the allowed
|
||||
substitutions S(x, i) at index i based on the original x. We define
|
||||
S(x, i) as the set of x_i ∈ N(x_i) such that where probabilities are
|
||||
assigned by a pre-trained language model, and the window radius W and
|
||||
threshold δ are hyperparameters. We use W = 6 and δ = 5.
|
||||
"""
|
||||
# # @TODO update all this stuff
|
||||
# Swap words with their embedding nearest-neighbors.
|
||||
#
|
||||
# Embedding: Counter-fitted Paragram Embeddings.
|
||||
#
|
||||
# "[We] fix the hyperparameter values to S = 60, N = 8, K = 4, and δ = 0.5"
|
||||
#
|
||||
transformation = WordSwapEmbedding(max_candidates=8)
|
||||
#
|
||||
# Don't modify the same word twice or stopwords
|
||||
#
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
#
|
||||
# Maximum words perturbed percentage of 20%
|
||||
#
|
||||
constraints.append(MaxWordsPerturbed(max_percent=0.2))
|
||||
#
|
||||
# Maximum word embedding euclidean distance of 0.5.
|
||||
#
|
||||
constraints.append(WordEmbeddingDistance(max_mse_dist=0.5))
|
||||
#
|
||||
# Language Model
|
||||
#
|
||||
#
|
||||
#
|
||||
constraints.append(
|
||||
LearningToWriteLanguageModel(
|
||||
window_size=6, max_log_prob_diff=5.0, compare_against_original=True
|
||||
)
|
||||
)
|
||||
# constraints.append(LearningToWriteLanguageModel(window_size=5))
|
||||
#
|
||||
# Goal is untargeted classification
|
||||
#
|
||||
goal_function = UntargetedClassification(model)
|
||||
#
|
||||
# Perform word substitution with a genetic algorithm.
|
||||
#
|
||||
search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)
|
||||
|
||||
return Attack(goal_function, constraints, transformation, search_method)
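The windowed language-model constraint described in the docstring (W = 6, δ = 5) can be sketched as a single comparison; `window_log_prob` is a stand-in for a real language-model scorer and is not part of this diff:

def substitution_allowed(window_log_prob, original_words, perturbed_words, i, window_radius=6, delta=5.0):
    """Allow a swap at index i only if the windowed log-probability drops by at most delta."""
    lo, hi = max(0, i - window_radius), min(len(original_words), i + window_radius + 1)
    return window_log_prob(original_words[lo:hi]) - window_log_prob(perturbed_words[lo:hi]) <= delta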
|
||||
@@ -3,6 +3,7 @@ from textattack.constraints.grammaticality.language_models import (
|
||||
)
|
||||
from textattack.constraints.overlap import MaxWordsPerturbed
|
||||
from textattack.constraints.pre_transformation import (
|
||||
InputColumnModification,
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
@@ -13,7 +14,7 @@ from textattack.shared.attack import Attack
|
||||
from textattack.transformations import WordSwapEmbedding
|
||||
|
||||
|
||||
def Alzantot2018(model):
|
||||
def GeneticAlgorithmAlzantot2018(model):
|
||||
"""
|
||||
Alzantot, M., Sharma, Y., Elgohary, A., Ho, B., Srivastava, M.B., & Chang, K. (2018).
|
||||
|
||||
@@ -34,6 +35,14 @@ def Alzantot2018(model):
|
||||
#
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
#
|
||||
# During entailment, we should only edit the hypothesis - keep the premise
|
||||
# the same.
|
||||
#
|
||||
input_column_modification = InputColumnModification(
|
||||
["premise", "hypothesis"], {"premise"}
|
||||
)
|
||||
constraints.append(input_column_modification)
|
||||
#
|
||||
# Maximum words perturbed percentage of 20%
|
||||
#
|
||||
constraints.append(MaxWordsPerturbed(max_percent=0.2))
|
||||
@@ -52,6 +61,6 @@ def Alzantot2018(model):
|
||||
#
|
||||
# Perform word substitution with a genetic algorithm.
|
||||
#
|
||||
search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
|
||||
search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)
|
||||
|
||||
return Attack(goal_function, constraints, transformation, search_method)
|
||||
40
textattack/attack_recipes/input_reduction_feng_2018.py
Normal file
@@ -0,0 +1,40 @@
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import InputReduction
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordDeletion


def InputReductionFeng2018(model):
    """
    Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018).

    Pathologies of Neural Models Make Interpretations Difficult.

    ArXiv, abs/1804.07781.
    """
    # At each step, we remove the word with the lowest importance value until
    # the model changes its prediction.
    transformation = WordDeletion()

    constraints = [RepeatModification(), StopwordModification()]
    #
    # Goal is untargeted classification
    #
    goal_function = InputReduction(model, maximizable=True)
    #
    # "For each word in an input sentence, we measure its importance by the
    # change in the confidence of the original prediction when we remove
    # that word from the sentence."
    #
    # "Instead of looking at the words with high importance values—what
    # interpretation methods commonly do—we take a complementary approach
    # and study how the model behaves when the supposedly unimportant words are
    # removed."
    #
    search_method = GreedyWordSwapWIR(wir_method="delete")

    return Attack(goal_function, constraints, transformation, search_method)
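The input-reduction procedure the comments describe is a simple greedy loop. A rough sketch, with `predict` and `importance` standing in for the real model calls:

def reduce_input(words, predict, importance):
    """Greedily delete the least-important word while the prediction is unchanged."""
    original_label = predict(words)
    while len(words) > 1:
        least_important = min(range(len(words)), key=lambda i: importance(words, i))
        candidate = words[:least_important] + words[least_important + 1:]
        if predict(candidate) != original_label:
            break
        words = candidate
    return words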
@@ -1,30 +1,42 @@
|
||||
|
||||
|
||||
from textattack.transformations import (
|
||||
WordSwapNeighboringCharacterSwap,
|
||||
WordSwapRandomCharacterDeletion,
|
||||
WordSwapRandomCharacterInsertion,
|
||||
WordSwapQWERTY,
|
||||
CompositeTransformation,
|
||||
)
|
||||
from textattack.constraints.pre_transformation import StopwordModification, MinWordLength, RepeatModification
|
||||
from textattack.constraints.overlap import MaxWordsPerturbed
|
||||
from textattack.search_methods import GreedySearch
|
||||
from textattack.constraints.pre_transformation import (
|
||||
MinWordLength,
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
from textattack.goal_functions import UntargetedClassification
|
||||
from textattack.search_methods import GreedySearch
|
||||
from textattack.shared.attack import Attack
|
||||
from textattack.transformations import (
|
||||
CompositeTransformation,
|
||||
WordSwapNeighboringCharacterSwap,
|
||||
WordSwapQWERTY,
|
||||
WordSwapRandomCharacterDeletion,
|
||||
WordSwapRandomCharacterInsertion,
|
||||
)
|
||||
|
||||
|
||||
def Pruthi2019(model, max_num_word_swaps=1):
|
||||
transformation = CompositeTransformation([
|
||||
WordSwapNeighboringCharacterSwap(random_one=False, skip_first_char=True, skip_last_char=True),
|
||||
WordSwapRandomCharacterDeletion(random_one=False, skip_first_char=True, skip_last_char=True),
|
||||
WordSwapRandomCharacterInsertion(random_one=False, skip_first_char=True, skip_last_char=True),
|
||||
WordSwapQWERTY(random_one=False, skip_first_char=True, skip_last_char=True),
|
||||
])
|
||||
constraints = [MinWordLength(min_length=4), StopwordModification(), MaxWordsPerturbed(max_num_words=max_num_word_swaps), RepeatModification()]
|
||||
transformation = CompositeTransformation(
|
||||
[
|
||||
WordSwapNeighboringCharacterSwap(
|
||||
random_one=False, skip_first_char=True, skip_last_char=True
|
||||
),
|
||||
WordSwapRandomCharacterDeletion(
|
||||
random_one=False, skip_first_char=True, skip_last_char=True
|
||||
),
|
||||
WordSwapRandomCharacterInsertion(
|
||||
random_one=False, skip_first_char=True, skip_last_char=True
|
||||
),
|
||||
WordSwapQWERTY(random_one=False, skip_first_char=True, skip_last_char=True),
|
||||
]
|
||||
)
|
||||
constraints = [
|
||||
MinWordLength(min_length=4),
|
||||
StopwordModification(),
|
||||
MaxWordsPerturbed(max_num_words=max_num_word_swaps),
|
||||
RepeatModification(),
|
||||
]
|
||||
goal_function = UntargetedClassification(model)
|
||||
search_method = GreedySearch()
|
||||
return Attack(goal_function, constraints, transformation, search_method)
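Each transformation composed above perturbs characters while leaving the first and last character untouched (mirroring skip_first_char=True, skip_last_char=True). A rough sketch of the neighboring-character-swap case, as an illustration rather than the library's implementation:

def neighboring_character_swaps(word):
    """Return every variant made by swapping two adjacent inner characters."""
    variants = []
    for i in range(1, len(word) - 2):  # never touch the first or last character
        characters = list(word)
        characters[i], characters[i + 1] = characters[i + 1], characters[i]
        variants.append("".join(characters))
    return variants

# e.g. neighboring_character_swaps("password")[0] == "psasword"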
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -26,5 +26,5 @@ def PWWSRen2019(model):
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
goal_function = UntargetedClassification(model)
|
||||
# search over words based on a combination of their saliency score, and how efficient the WordSwap transform is
|
||||
search_method = GreedyWordSwapWIR("pwws", ascending=False)
|
||||
search_method = GreedyWordSwapWIR("pwws")
|
||||
return Attack(goal_function, constraints, transformation, search_method)
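The comment above summarizes PWWS's ranking: word saliency combined with how effective the best swap for that word is. A rough sketch of that ordering, with `saliency` and `best_swap_gain` as stand-ins for the real model queries:

import numpy as np

def pwws_style_order(num_words, saliency, best_swap_gain):
    """Order word indices by softmaxed saliency times the gain of each word's best swap."""
    sal = np.array([saliency(i) for i in range(num_words)], dtype=float)
    sal = np.exp(sal - sal.max())
    sal /= sal.sum()
    gains = np.array([best_swap_gain(i) for i in range(num_words)], dtype=float)
    return list(np.argsort(-(sal * gains)))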
|
||||
|
||||
@@ -27,7 +27,7 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
|
||||
# Goal is non-overlapping output.
|
||||
#
|
||||
goal_function = NonOverlappingOutput(model)
|
||||
# @TODO implement transformation / search method just like they do in
|
||||
# TODO implement transformation / search method just like they do in
|
||||
# seq2sick.
|
||||
transformation = WordSwapEmbedding(max_candidates=50)
|
||||
#
|
||||
@@ -42,6 +42,6 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
|
||||
#
|
||||
# Greedily swap words with "Word Importance Ranking".
|
||||
#
|
||||
search_method = GreedyWordSwapWIR()
|
||||
search_method = GreedyWordSwapWIR(wir_method="unk")
|
||||
|
||||
return Attack(goal_function, constraints, transformation, search_method)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from textattack.constraints.grammaticality import PartOfSpeech
|
||||
from textattack.constraints.pre_transformation import (
|
||||
InputColumnModification,
|
||||
RepeatModification,
|
||||
StopwordModification,
|
||||
)
|
||||
@@ -28,278 +29,20 @@ def TextFoolerJin2019(model):
|
||||
# Don't modify the same word twice or the stopwords defined
|
||||
# in the TextFooler public implementation.
|
||||
#
|
||||
# fmt: off
|
||||
stopwords = set(
|
||||
[
|
||||
"a",
|
||||
"about",
|
||||
"above",
|
||||
"across",
|
||||
"after",
|
||||
"afterwards",
|
||||
"again",
|
||||
"against",
|
||||
"ain",
|
||||
"all",
|
||||
"almost",
|
||||
"alone",
|
||||
"along",
|
||||
"already",
|
||||
"also",
|
||||
"although",
|
||||
"am",
|
||||
"among",
|
||||
"amongst",
|
||||
"an",
|
||||
"and",
|
||||
"another",
|
||||
"any",
|
||||
"anyhow",
|
||||
"anyone",
|
||||
"anything",
|
||||
"anyway",
|
||||
"anywhere",
|
||||
"are",
|
||||
"aren",
|
||||
"aren't",
|
||||
"around",
|
||||
"as",
|
||||
"at",
|
||||
"back",
|
||||
"been",
|
||||
"before",
|
||||
"beforehand",
|
||||
"behind",
|
||||
"being",
|
||||
"below",
|
||||
"beside",
|
||||
"besides",
|
||||
"between",
|
||||
"beyond",
|
||||
"both",
|
||||
"but",
|
||||
"by",
|
||||
"can",
|
||||
"cannot",
|
||||
"could",
|
||||
"couldn",
|
||||
"couldn't",
|
||||
"d",
|
||||
"didn",
|
||||
"didn't",
|
||||
"doesn",
|
||||
"doesn't",
|
||||
"don",
|
||||
"don't",
|
||||
"down",
|
||||
"due",
|
||||
"during",
|
||||
"either",
|
||||
"else",
|
||||
"elsewhere",
|
||||
"empty",
|
||||
"enough",
|
||||
"even",
|
||||
"ever",
|
||||
"everyone",
|
||||
"everything",
|
||||
"everywhere",
|
||||
"except",
|
||||
"first",
|
||||
"for",
|
||||
"former",
|
||||
"formerly",
|
||||
"from",
|
||||
"hadn",
|
||||
"hadn't",
|
||||
"hasn",
|
||||
"hasn't",
|
||||
"haven",
|
||||
"haven't",
|
||||
"he",
|
||||
"hence",
|
||||
"her",
|
||||
"here",
|
||||
"hereafter",
|
||||
"hereby",
|
||||
"herein",
|
||||
"hereupon",
|
||||
"hers",
|
||||
"herself",
|
||||
"him",
|
||||
"himself",
|
||||
"his",
|
||||
"how",
|
||||
"however",
|
||||
"hundred",
|
||||
"i",
|
||||
"if",
|
||||
"in",
|
||||
"indeed",
|
||||
"into",
|
||||
"is",
|
||||
"isn",
|
||||
"isn't",
|
||||
"it",
|
||||
"it's",
|
||||
"its",
|
||||
"itself",
|
||||
"just",
|
||||
"latter",
|
||||
"latterly",
|
||||
"least",
|
||||
"ll",
|
||||
"may",
|
||||
"me",
|
||||
"meanwhile",
|
||||
"mightn",
|
||||
"mightn't",
|
||||
"mine",
|
||||
"more",
|
||||
"moreover",
|
||||
"most",
|
||||
"mostly",
|
||||
"must",
|
||||
"mustn",
|
||||
"mustn't",
|
||||
"my",
|
||||
"myself",
|
||||
"namely",
|
||||
"needn",
|
||||
"needn't",
|
||||
"neither",
|
||||
"never",
|
||||
"nevertheless",
|
||||
"next",
|
||||
"no",
|
||||
"nobody",
|
||||
"none",
|
||||
"noone",
|
||||
"nor",
|
||||
"not",
|
||||
"nothing",
|
||||
"now",
|
||||
"nowhere",
|
||||
"o",
|
||||
"of",
|
||||
"off",
|
||||
"on",
|
||||
"once",
|
||||
"one",
|
||||
"only",
|
||||
"onto",
|
||||
"or",
|
||||
"other",
|
||||
"others",
|
||||
"otherwise",
|
||||
"our",
|
||||
"ours",
|
||||
"ourselves",
|
||||
"out",
|
||||
"over",
|
||||
"per",
|
||||
"please",
|
||||
"s",
|
||||
"same",
|
||||
"shan",
|
||||
"shan't",
|
||||
"she",
|
||||
"she's",
|
||||
"should've",
|
||||
"shouldn",
|
||||
"shouldn't",
|
||||
"somehow",
|
||||
"something",
|
||||
"sometime",
|
||||
"somewhere",
|
||||
"such",
|
||||
"t",
|
||||
"than",
|
||||
"that",
|
||||
"that'll",
|
||||
"the",
|
||||
"their",
|
||||
"theirs",
|
||||
"them",
|
||||
"themselves",
|
||||
"then",
|
||||
"thence",
|
||||
"there",
|
||||
"thereafter",
|
||||
"thereby",
|
||||
"therefore",
|
||||
"therein",
|
||||
"thereupon",
|
||||
"these",
|
||||
"they",
|
||||
"this",
|
||||
"those",
|
||||
"through",
|
||||
"throughout",
|
||||
"thru",
|
||||
"thus",
|
||||
"to",
|
||||
"too",
|
||||
"toward",
|
||||
"towards",
|
||||
"under",
|
||||
"unless",
|
||||
"until",
|
||||
"up",
|
||||
"upon",
|
||||
"used",
|
||||
"ve",
|
||||
"was",
|
||||
"wasn",
|
||||
"wasn't",
|
||||
"we",
|
||||
"were",
|
||||
"weren",
|
||||
"weren't",
|
||||
"what",
|
||||
"whatever",
|
||||
"when",
|
||||
"whence",
|
||||
"whenever",
|
||||
"where",
|
||||
"whereafter",
|
||||
"whereas",
|
||||
"whereby",
|
||||
"wherein",
|
||||
"whereupon",
|
||||
"wherever",
|
||||
"whether",
|
||||
"which",
|
||||
"while",
|
||||
"whither",
|
||||
"who",
|
||||
"whoever",
|
||||
"whole",
|
||||
"whom",
|
||||
"whose",
|
||||
"why",
|
||||
"with",
|
||||
"within",
|
||||
"without",
|
||||
"won",
|
||||
"won't",
|
||||
"would",
|
||||
"wouldn",
|
||||
"wouldn't",
|
||||
"y",
|
||||
"yet",
|
||||
"you",
|
||||
"you'd",
|
||||
"you'll",
|
||||
"you're",
|
||||
"you've",
|
||||
"your",
|
||||
"yours",
|
||||
"yourself",
|
||||
"yourselves",
|
||||
]
|
||||
["a", "about", "above", "across", "after", "afterwards", "again", "against", "ain", "all", "almost", "alone", "along", "already", "also", "although", "am", "among", "amongst", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", "aren", "aren't", "around", "as", "at", "back", "been", "before", "beforehand", "behind", "being", "below", "beside", "besides", "between", "beyond", "both", "but", "by", "can", "cannot", "could", "couldn", "couldn't", "d", "didn", "didn't", "doesn", "doesn't", "don", "don't", "down", "due", "during", "either", "else", "elsewhere", "empty", "enough", "even", "ever", "everyone", "everything", "everywhere", "except", "first", "for", "former", "formerly", "from", "hadn", "hadn't", "hasn", "hasn't", "haven", "haven't", "he", "hence", "her", "here", "hereafter", "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", "how", "however", "hundred", "i", "if", "in", "indeed", "into", "is", "isn", "isn't", "it", "it's", "its", "itself", "just", "latter", "latterly", "least", "ll", "may", "me", "meanwhile", "mightn", "mightn't", "mine", "more", "moreover", "most", "mostly", "must", "mustn", "mustn't", "my", "myself", "namely", "needn", "needn't", "neither", "never", "nevertheless", "next", "no", "nobody", "none", "noone", "nor", "not", "nothing", "now", "nowhere", "o", "of", "off", "on", "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", "ours", "ourselves", "out", "over", "per", "please", "s", "same", "shan", "shan't", "she", "she's", "should've", "shouldn", "shouldn't", "somehow", "something", "sometime", "somewhere", "such", "t", "than", "that", "that'll", "the", "their", "theirs", "them", "themselves", "then", "thence", "there", "thereafter", "thereby", "therefore", "therein", "thereupon", "these", "they", "this", "those", "through", "throughout", "thru", "thus", "to", "too", "toward", "towards", "under", "unless", "until", "up", "upon", "used", "ve", "was", "wasn", "wasn't", "we", "were", "weren", "weren't", "what", "whatever", "when", "whence", "whenever", "where", "whereafter", "whereas", "whereby", "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", "who", "whoever", "whole", "whom", "whose", "why", "with", "within", "without", "won", "won't", "would", "wouldn", "wouldn't", "y", "yet", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves"]
|
||||
)
|
||||
# fmt: on
|
||||
constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]
|
||||
#
|
||||
# During entailment, we should only edit the hypothesis - keep the premise
|
||||
# the same.
|
||||
#
|
||||
input_column_modification = InputColumnModification(
|
||||
["premise", "hypothesis"], {"premise"}
|
||||
)
|
||||
constraints.append(input_column_modification)
|
||||
# Minimum word embedding cosine similarity of 0.5.
|
||||
# (The paper claims 0.7, but analysis of the released code and some empirical
|
||||
# results show that it's 0.5.)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .maximized_attack_result import MaximizedAttackResult
|
||||
from .failed_attack_result import FailedAttackResult
|
||||
from .skipped_attack_result import SkippedAttackResult
|
||||
from .successful_attack_result import SuccessfulAttackResult
|
||||
|
||||
@@ -13,11 +13,11 @@ class AttackResult:
|
||||
perturbed text. May or may not have been successful.
|
||||
"""
|
||||
|
||||
def __init__(self, original_result, perturbed_result, num_queries=0):
|
||||
def __init__(self, original_result, perturbed_result):
|
||||
if original_result is None:
|
||||
raise ValueError("Attack original result cannot be None")
|
||||
elif not isinstance(original_result, GoalFunctionResult):
|
||||
raise TypeError(f"Invalid original goal function result: {original_text}")
|
||||
raise TypeError(f"Invalid original goal function result: {original_result}")
|
||||
if perturbed_result is None:
|
||||
raise ValueError("Attack perturbed result cannot be None")
|
||||
elif not isinstance(perturbed_result, GoalFunctionResult):
|
||||
@@ -27,7 +27,7 @@ class AttackResult:
|
||||
|
||||
self.original_result = original_result
|
||||
self.perturbed_result = perturbed_result
|
||||
self.num_queries = num_queries
|
||||
self.num_queries = perturbed_result.num_queries
|
||||
|
||||
# We don't want the AttackedText attributes sticking around clogging up
|
||||
# space on our devices. Delete them here, if they're still present,
|
||||
@@ -89,27 +89,34 @@ class AttackResult:
|
||||
i1 = 0
|
||||
i2 = 0
|
||||
|
||||
while i1 < len(t1.words) and i2 < len(t2.words):
|
||||
while i1 < t1.num_words or i2 < t2.num_words:
|
||||
# show deletions
|
||||
while t2.attack_attrs["original_index_map"][i1] == -1:
|
||||
while (
|
||||
i1 < len(t2.attack_attrs["original_index_map"])
|
||||
and t2.attack_attrs["original_index_map"][i1] == -1
|
||||
):
|
||||
words_1.append(utils.color_text(t1.words[i1], color_1, color_method))
|
||||
words_1_idxs.append(i1)
|
||||
i1 += 1
|
||||
# show insertions
|
||||
while i2 < t2.attack_attrs["original_index_map"][i1]:
|
||||
while (
|
||||
i1 < len(t2.attack_attrs["original_index_map"])
|
||||
and i2 < t2.attack_attrs["original_index_map"][i1]
|
||||
):
|
||||
words_2.append(utils.color_text(t1.words[i2], color_2, color_method))
|
||||
words_2_idxs.append(i2)
|
||||
i2 += 1
|
||||
# show swaps
|
||||
word_1 = t1.words[i1]
|
||||
word_2 = t2.words[i2]
|
||||
if word_1 != word_2:
|
||||
words_1.append(utils.color_text(word_1, color_1, color_method))
|
||||
words_2.append(utils.color_text(word_2, color_2, color_method))
|
||||
words_1_idxs.append(i1)
|
||||
words_2_idxs.append(i2)
|
||||
i1 += 1
|
||||
i2 += 1
|
||||
if i1 < t1.num_words and i2 < t2.num_words:
|
||||
word_1 = t1.words[i1]
|
||||
word_2 = t2.words[i2]
|
||||
if word_1 != word_2:
|
||||
words_1.append(utils.color_text(word_1, color_1, color_method))
|
||||
words_2.append(utils.color_text(word_2, color_2, color_method))
|
||||
words_1_idxs.append(i1)
|
||||
words_2_idxs.append(i2)
|
||||
i1 += 1
|
||||
i2 += 1
|
||||
|
||||
t1 = self.original_result.attacked_text.replace_words_at_indices(
|
||||
words_1_idxs, words_1
|
||||
|
||||
@@ -6,9 +6,9 @@ from .attack_result import AttackResult
|
||||
class FailedAttackResult(AttackResult):
|
||||
"""The result of a failed attack."""
|
||||
|
||||
def __init__(self, original_result, perturbed_result=None, num_queries=0):
|
||||
def __init__(self, original_result, perturbed_result=None):
|
||||
perturbed_result = perturbed_result or original_result
|
||||
super().__init__(original_result, perturbed_result, num_queries)
|
||||
super().__init__(original_result, perturbed_result)
|
||||
|
||||
def str_lines(self, color_method=None):
|
||||
lines = (
|
||||
|
||||
5
textattack/attack_results/maximized_attack_result.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .attack_result import AttackResult
|
||||
|
||||
|
||||
class MaximizedAttackResult(AttackResult):
|
||||
""" The result of a successful attack. """
|
||||
@@ -122,13 +122,12 @@ class CharSwapAugmenter(Augmenter):
|
||||
""" Augments words by swapping characters out for other characters. """
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
from textattack.transformations import CompositeTransformation
|
||||
from textattack.transformations import (
|
||||
CompositeTransformation,
|
||||
WordSwapNeighboringCharacterSwap,
|
||||
WordSwapRandomCharacterDeletion,
|
||||
WordSwapRandomCharacterInsertion,
|
||||
WordSwapRandomCharacterSubstitution,
|
||||
WordSwapNeighboringCharacterSwap,
|
||||
)
|
||||
|
||||
transformation = CompositeTransformation(
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
from .attack_command import AttackCommand
|
||||
from .attack_resume_command import AttackResumeCommand
|
||||
|
||||
from .run_attack_single_threaded import run as run_attack_single_threaded
|
||||
from .run_attack_parallel import run as run_attack_parallel
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
import textattack
|
||||
|
||||
ATTACK_RECIPE_NAMES = {
|
||||
"alzantot": "textattack.attack_recipes.Alzantot2018",
|
||||
"alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
|
||||
"bae": "textattack.attack_recipes.BAEGarg2019",
|
||||
"bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
|
||||
"faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
|
||||
"deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
|
||||
"hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
|
||||
"input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
|
||||
"kuleshov": "textattack.attack_recipes.Kuleshov2017",
|
||||
"seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
|
||||
"textbugger": "textattack.attack_recipes.TextBuggerLi2018",
|
||||
"textfooler": "textattack.attack_recipes.TextFoolerJin2019",
|
||||
"pwws": "textattack.attack_recipes.PWWSRen2019",
|
||||
"pruthi": "textattack.attack_recipes.Pruthi2019",
|
||||
"pso": "textattack.attack_recipes.PSOZang2020",
|
||||
}
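The table above maps short CLI names to fully qualified recipe functions. The repository's own commands resolve these strings with eval, but an importlib-based lookup illustrates the same idea (the helper below is hypothetical and assumes ATTACK_RECIPE_NAMES is in scope):

import importlib

def load_attack_recipe(name):
    """Resolve a short recipe name (e.g. "pso") to the recipe function it points at."""
    module_path, attr = ATTACK_RECIPE_NAMES[name].rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr)

# load_attack_recipe("pso") would return textattack.attack_recipes.PSOZang2020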
|
||||
|
||||
#
|
||||
@@ -55,6 +60,14 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
"textattack/bert-base-uncased-WNLI",
|
||||
("glue", "wnli", "validation"),
|
||||
),
|
||||
"bert-base-uncased-mr": (
|
||||
"textattack/bert-base-uncased-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
"bert-base-uncased-snli": (
|
||||
"textattack/bert-base-uncased-snli",
|
||||
("snli", None, "test", [1, 2, 0]),
|
||||
),
|
||||
#
|
||||
# distilbert-base-cased
|
||||
#
|
||||
@@ -141,6 +154,24 @@ HUGGINGFACE_DATASET_BY_MODEL = {
|
||||
"textattack/roberta-base-WNLI",
|
||||
("glue", "wnli", "validation"),
|
||||
),
|
||||
"roberta-base-mr": (
|
||||
"textattack/roberta-base-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
#
|
||||
# albert-base-v2 (ALBERT is cased by default)
|
||||
#
|
||||
"albert-base-v2-mr": (
|
||||
"textattack/albert-base-v2-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
#
|
||||
# xlnet-base-cased
|
||||
#
|
||||
"xlnet-base-cased-mr": (
|
||||
"textattack/xlnet-base-cased-rotten-tomatoes",
|
||||
("rotten_tomatoes", None, "test"),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -172,10 +203,6 @@ TEXTATTACK_DATASET_BY_MODEL = {
|
||||
#
|
||||
# Text classification models
|
||||
#
|
||||
"bert-base-uncased-mr": (
|
||||
("models/classification/bert/mr-uncased", 2),
|
||||
("rotten_tomatoes", None, "train"),
|
||||
),
|
||||
"bert-base-cased-imdb": (
|
||||
("models/classification/bert/imdb-cased", 2),
|
||||
("imdb", None, "test"),
|
||||
@@ -194,11 +221,22 @@ TEXTATTACK_DATASET_BY_MODEL = {
|
||||
),
|
||||
#
|
||||
# Translation models
|
||||
# TODO add proper `nlp` datasets for translation & summarization
|
||||
"t5-en-de": (
|
||||
"english_to_german",
|
||||
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
|
||||
),
|
||||
"t5-en-fr": (
|
||||
"english_to_french",
|
||||
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "fr"),
|
||||
),
|
||||
"t5-en-ro": (
|
||||
"english_to_romanian",
|
||||
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
|
||||
),
|
||||
#
|
||||
# Summarization models
|
||||
#
|
||||
#'t5-summ': 'textattack.models.summarization.T5Summarization',
|
||||
"t5-summarization": ("summarization", ("gigaword", None, "test")),
|
||||
}
|
||||
|
||||
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
|
||||
@@ -233,6 +271,7 @@ CONSTRAINT_CLASS_NAMES = {
|
||||
"part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
|
||||
"goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
|
||||
"gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
|
||||
"learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
|
||||
#
|
||||
# Overlap constraints
|
||||
#
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import copy
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
@@ -30,7 +31,7 @@ def add_model_args(parser):
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
help='The pre-trained model to attack. Usage: "--model {model}:{arg_1}={value_1},{arg_3}={value_3},...". Choices: '
|
||||
help='Name of or path to a pre-trained model to attack. Usage: "--model {model}:{arg_1}={value_1},{arg_3}={value_3},...". Choices: '
|
||||
+ str(model_names),
|
||||
)
|
||||
model_group.add_argument(
|
||||
@@ -153,6 +154,8 @@ def parse_goal_function_from_args(args, model):
|
||||
else:
|
||||
raise ValueError(f"Error: unsupported goal_function {goal_function}")
|
||||
goal_function.query_budget = args.query_budget
|
||||
goal_function.model_batch_size = args.model_batch_size
|
||||
goal_function.model_cache_size = args.model_cache_size
|
||||
return goal_function
|
||||
|
||||
|
||||
@@ -191,6 +194,9 @@ def parse_attack_from_args(args):
|
||||
else:
|
||||
raise ValueError(f"Invalid recipe {args.recipe}")
|
||||
recipe.goal_function.query_budget = args.query_budget
|
||||
recipe.goal_function.model_batch_size = args.model_batch_size
|
||||
recipe.goal_function.model_cache_size = args.model_cache_size
|
||||
recipe.constraint_cache_size = args.constraint_cache_size
|
||||
return recipe
|
||||
elif args.attack_from_file:
|
||||
if ":" in args.attack_from_file:
|
||||
@@ -218,7 +224,11 @@ def parse_attack_from_args(args):
|
||||
else:
|
||||
raise ValueError(f"Error: unsupported attack {args.search}")
|
||||
return textattack.shared.Attack(
|
||||
goal_function, constraints, transformation, search_method
|
||||
goal_function,
|
||||
constraints,
|
||||
transformation,
|
||||
search_method,
|
||||
constraint_cache_size=args.constraint_cache_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -295,6 +305,22 @@ def parse_model_from_args(args):
|
||||
model = textattack.shared.utils.load_textattack_model_from_path(
|
||||
args.model, model_path
|
||||
)
|
||||
elif args.model and os.path.exists(args.model):
|
||||
# If `args.model` is a path/directory, let's assume it was a model
|
||||
# trained with textattack, and try and load it.
|
||||
model_args_json_path = os.path.join(args.model, "train_args.json")
|
||||
if not os.path.exists(model_args_json_path):
|
||||
raise FileNotFoundError(
|
||||
f"Tried to load model from path {args.model} - could not find train_args.json."
|
||||
)
|
||||
model_train_args = json.loads(open(model_args_json_path).read())
|
||||
model_train_args["model"] = args.model
|
||||
num_labels = model_train_args["num_labels"]
|
||||
from textattack.commands.train_model.train_args_helpers import (
|
||||
model_from_args,
|
||||
)
|
||||
|
||||
model = model_from_args(argparse.Namespace(**model_train_args), num_labels)
|
||||
else:
|
||||
raise ValueError(f"Error: unsupported TextAttack model {args.model}")
|
||||
return model
|
||||
@@ -306,7 +332,32 @@ def parse_dataset_from_args(args):
|
||||
if args.model in HUGGINGFACE_DATASET_BY_MODEL:
|
||||
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
|
||||
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
|
||||
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
|
||||
_, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model]
|
||||
if dataset[0].startswith("textattack"):
|
||||
# unsavory way to pass custom dataset classes
|
||||
# ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de')
|
||||
dataset = eval(f"{dataset[0]}")(*dataset[1:])
|
||||
return dataset
|
||||
else:
|
||||
args.dataset_from_nlp = dataset
|
||||
# Automatically detect dataset for models trained with textattack.
|
||||
elif args.model and os.path.exists(args.model):
|
||||
model_args_json_path = os.path.join(args.model, "train_args.json")
|
||||
if not os.path.exists(model_args_json_path):
|
||||
raise FileNotFoundError(
|
||||
f"Tried to load model from path {args.model} - could not find train_args.json."
|
||||
)
|
||||
model_train_args = json.loads(open(model_args_json_path).read())
|
||||
try:
|
||||
args.dataset_from_nlp = (
|
||||
model_train_args["dataset"],
|
||||
None,
|
||||
model_train_args["dataset_dev_split"],
|
||||
)
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"Tried to load model from path {args.model} but can't initialize dataset from train_args.json."
|
||||
)
|
||||
|
||||
# Get dataset from args.
|
||||
if args.dataset_from_file:
|
||||
@@ -331,8 +382,11 @@ def parse_dataset_from_args(args):
|
||||
)
|
||||
elif args.dataset_from_nlp:
|
||||
dataset_args = args.dataset_from_nlp
|
||||
if ":" in dataset_args:
|
||||
dataset_args = dataset_args.split(":")
|
||||
if isinstance(dataset_args, str):
|
||||
if ":" in dataset_args:
|
||||
dataset_args = dataset_args.split(":")
|
||||
else:
|
||||
dataset_args = (dataset_args,)
|
||||
dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, shuffle=args.shuffle
|
||||
)
|
||||
@@ -350,8 +404,10 @@ def parse_logger_from_args(args):
|
||||
if not args.out_dir:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(
|
||||
current_dir, os.pardir, os.pardir, os.pardir, "outputs"
|
||||
current_dir, os.pardir, os.pardir, os.pardir, "outputs", "attacks"
|
||||
)
|
||||
if not os.path.exists(outputs_dir):
|
||||
os.makedirs(outputs_dir)
|
||||
args.out_dir = os.path.normpath(outputs_dir)
|
||||
|
||||
# if "--log-to-file" specified in terminal command, then save it to a txt file
|
||||
|
||||
@@ -157,6 +157,24 @@ class AttackCommand(TextAttackCommand):
|
||||
default=float("inf"),
|
||||
help="The maximum number of model queries allowed per example attacked.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The batch size for making calls to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
help="The maximum number of items to keep in the model results cache at once.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--constraint-cache-size",
|
||||
type=int,
|
||||
default=2 ** 18,
|
||||
help="The maximum number of items to keep in the constraints cache at once.",
|
||||
)
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
|
||||
|
||||
@@ -20,15 +20,15 @@ def set_env_variables(gpu_id):
|
||||
# Set sharing strategy to file_system to avoid file descriptor leaks
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
# Only use one GPU, if we have one.
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
# For Tensorflow
|
||||
# TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
# For PyTorch
|
||||
torch.cuda.set_device(gpu_id)
|
||||
|
||||
# Disable tensorflow logs, except in the case of an error.
|
||||
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
# Cache TensorFlow Hub models here, if not otherwise specified.
|
||||
if "TFHUB_CACHE_DIR" not in os.environ:
|
||||
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
|
||||
|
||||
|
||||
def attack_from_queue(args, in_queue, out_queue):
|
||||
@@ -125,7 +125,10 @@ def run(args, checkpoint=None):
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
||||
if (
|
||||
type(result) == textattack.attack_results.SuccessfulAttackResult
|
||||
or type(result) == textattack.attack_results.MaximizedAttackResult
|
||||
):
|
||||
num_successes += 1
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
@@ -170,6 +173,8 @@ def run(args, checkpoint=None):
|
||||
finish_time = time.time()
|
||||
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
|
||||
|
||||
return attack_log_manager.results
|
||||
|
||||
|
||||
def pytorch_multiprocessing_workaround():
|
||||
# This is a fix for a known bug
|
||||
|
||||
@@ -18,14 +18,12 @@ logger = textattack.shared.logger
|
||||
|
||||
def run(args, checkpoint=None):
|
||||
# Only use one GPU, if we have one.
|
||||
# TODO: Running Universal Sentence Encoder uses multiple GPUs
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
# Disable tensorflow logs, except in the case of an error.
|
||||
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
# Cache TensorFlow Hub models here, if not otherwise specified.
|
||||
if "TFHUB_CACHE_DIR" not in os.environ:
|
||||
os.environ["TFHUB_CACHE_DIR"] = os.path.expanduser("~/.cache/tensorflow-hub")
|
||||
|
||||
if args.checkpoint_resume:
|
||||
num_remaining_attacks = checkpoint.num_remaining_attacks
|
||||
@@ -110,7 +108,10 @@ def run(args, checkpoint=None):
|
||||
|
||||
num_results += 1
|
||||
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
||||
if (
|
||||
type(result) == textattack.attack_results.SuccessfulAttackResult
|
||||
or type(result) == textattack.attack_results.MaximizedAttackResult
|
||||
):
|
||||
num_successes += 1
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
@@ -141,6 +142,8 @@ def run(args, checkpoint=None):
|
||||
finish_time = time.time()
|
||||
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
|
||||
|
||||
return attack_log_manager.results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run(get_args())
|
||||
|
||||
1
textattack/commands/eval_model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .eval_model_command import EvalModelCommand
|
||||
@@ -12,11 +12,11 @@ def _cb(s):
|
||||
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
|
||||
|
||||
|
||||
class BenchmarkModelCommand(TextAttackCommand):
|
||||
class EvalModelCommand(TextAttackCommand):
|
||||
"""
|
||||
The TextAttack model benchmarking module:
|
||||
|
||||
A command line parser to benchmark a model from user specifications.
|
||||
A command line parser to evaluate a model from user specifications.
|
||||
"""
|
||||
|
||||
def get_num_successes(self, model, ids, true_labels):
|
||||
@@ -83,7 +83,7 @@ class BenchmarkModelCommand(TextAttackCommand):
|
||||
@staticmethod
|
||||
def register_subcommand(main_parser: ArgumentParser):
|
||||
parser = main_parser.add_parser(
|
||||
"benchmark-model",
|
||||
"eval",
|
||||
help="evaluate a model with TextAttack",
|
||||
formatter_class=ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
@@ -97,4 +97,4 @@ class BenchmarkModelCommand(TextAttackCommand):
|
||||
default=256,
|
||||
help="Batch size for model inference.",
|
||||
)
|
||||
parser.set_defaults(func=BenchmarkModelCommand())
|
||||
parser.set_defaults(func=EvalModelCommand())
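With the rename above, the former benchmark-model subcommand is now registered as eval. An illustrative invocation (the model name and the --batch-size flag spelling are assumptions; only the 256 default and help text appear in this hunk): textattack eval --model bert-base-uncased-sst2 --batch-size 256.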
|
||||
82
textattack/commands/peek_dataset.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
import collections
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
import textattack
|
||||
from textattack.commands import TextAttackCommand
|
||||
from textattack.commands.attack.attack_args_helpers import (
|
||||
add_dataset_args,
|
||||
parse_dataset_from_args,
|
||||
)
|
||||
|
||||
|
||||
def _cb(s):
|
||||
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")
|
||||
|
||||
|
||||
logger = textattack.shared.logger
|
||||
|
||||
|
||||
class PeekDatasetCommand(TextAttackCommand):
|
||||
"""
|
||||
The peek dataset module:
|
||||
|
||||
Takes a peek into a dataset in textattack.
|
||||
"""
|
||||
|
||||
def run(self, args):
|
||||
UPPERCASE_LETTERS_REGEX = re.compile("[A-Z]")
|
||||
|
||||
args.model = None # set model to None for parse_dataset_from_args to work
|
||||
dataset = parse_dataset_from_args(args)
|
||||
|
||||
num_words = []
|
||||
attacked_texts = []
|
||||
data_all_lowercased = True
|
||||
outputs = []
|
||||
for inputs, output in dataset:
|
||||
at = textattack.shared.AttackedText(inputs)
|
||||
if data_all_lowercased:
|
||||
# Test if any of the letters in the string are lowercase.
|
||||
if re.search(UPPERCASE_LETTERS_REGEX, at.text):
|
||||
data_all_lowercased = False
|
||||
attacked_texts.append(at)
|
||||
num_words.append(len(at.words))
|
||||
outputs.append(output)
|
||||
|
||||
logger.info(f"Number of samples: {_cb(len(attacked_texts))}")
|
||||
logger.info(f"Number of words per input:")
|
||||
num_words = np.array(num_words)
|
||||
logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
|
||||
mean_words = f"{num_words.mean():.2f}"
|
||||
logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
|
||||
std_words = f"{num_words.std():.2f}"
|
||||
logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
|
||||
logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
|
||||
logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
|
||||
logger.info(f"Dataset lowercased: {_cb(data_all_lowercased)}")
|
||||
|
||||
logger.info("First sample:")
|
||||
print(attacked_texts[0].printable_text(), "\n")
|
||||
logger.info("Last sample:")
|
||||
print(attacked_texts[-1].printable_text(), "\n")
|
||||
|
||||
logger.info(f"Found {len(set(outputs))} distinct outputs.")
|
||||
if len(set(outputs)) < 20:
|
||||
print(sorted(set(outputs)))
|
||||
|
||||
logger.info(f"Most common outputs:")
|
||||
for i, (key, value) in enumerate(collections.Counter(outputs).most_common(20)):
|
||||
print("\t", str(key)[:5].ljust(5), f" ({value})")
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(main_parser: ArgumentParser):
|
||||
parser = main_parser.add_parser(
|
||||
"peek-dataset",
|
||||
help="show main statistics about a dataset",
|
||||
formatter_class=ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
add_dataset_args(parser)
|
||||
parser.set_defaults(func=PeekDatasetCommand())
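Because the new subcommand reuses add_dataset_args, it accepts the same dataset flags as the attack command. An illustrative invocation (the flag spelling is assumed from the args.dataset_from_nlp attribute used above): textattack peek-dataset --dataset-from-nlp glue:sst2.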
|
||||
@@ -5,9 +5,10 @@ import sys
|
||||
|
||||
from textattack.commands.attack import AttackCommand, AttackResumeCommand
|
||||
from textattack.commands.augment import AugmentCommand
|
||||
from textattack.commands.benchmark_model import BenchmarkModelCommand
|
||||
from textattack.commands.benchmark_recipe import BenchmarkRecipeCommand
|
||||
from textattack.commands.eval_model import EvalModelCommand
|
||||
from textattack.commands.list_things import ListThingsCommand
|
||||
from textattack.commands.peek_dataset import PeekDatasetCommand
|
||||
from textattack.commands.train_model import TrainModelCommand
|
||||
|
||||
|
||||
@@ -23,10 +24,11 @@ def main():
|
||||
AttackCommand.register_subcommand(subparsers)
|
||||
AttackResumeCommand.register_subcommand(subparsers)
|
||||
AugmentCommand.register_subcommand(subparsers)
|
||||
BenchmarkModelCommand.register_subcommand(subparsers)
|
||||
BenchmarkRecipeCommand.register_subcommand(subparsers)
|
||||
EvalModelCommand.register_subcommand(subparsers)
|
||||
ListThingsCommand.register_subcommand(subparsers)
|
||||
TrainModelCommand.register_subcommand(subparsers)
|
||||
PeekDatasetCommand.register_subcommand(subparsers)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from textattack.commands import TextAttackCommand
|
||||
|
||||
|
||||
class TrainModelCommand(TextAttackCommand):
|
||||
"""
|
||||
The TextAttack train module:
|
||||
|
||||
A command line parser to train a model from user specifications.
|
||||
"""
|
||||
|
||||
def run(self, args):
|
||||
raise NotImplementedError("Cannot train models yet - stay tuned!!")
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(main_parser: ArgumentParser):
|
||||
parser = main_parser.add_parser("train", help="train a model")
|
||||
1
textattack/commands/train_model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .train_model_command import TrainModelCommand
|
||||
382
textattack/commands/train_model/run_training.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
import tqdm
|
||||
import transformers
|
||||
|
||||
import textattack
|
||||
|
||||
from .train_args_helpers import dataset_from_args, model_from_args, write_readme
|
||||
|
||||
device = textattack.shared.utils.device
|
||||
logger = textattack.shared.logger
|
||||
|
||||
|
||||
def make_directories(output_dir):
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
|
||||
def batch_encode(tokenizer, text_list):
|
||||
if hasattr(tokenizer, "batch_encode"):
|
||||
return tokenizer.batch_encode(text_list)
|
||||
else:
|
||||
return [tokenizer.encode(text_input) for text_input in text_list]
|
||||
|
||||
|
||||
def train_model(args):
|
||||
logger.warn(
|
||||
"WARNING: TextAttack's model training feature is in beta. Please report any issues on our Github page, https://github.com/QData/TextAttack/issues."
|
||||
)
|
||||
start_time = time.time()
|
||||
make_directories(args.output_dir)
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
|
||||
# Save logger writes to file
|
||||
log_txt_path = os.path.join(args.output_dir, "log.txt")
|
||||
fh = logging.FileHandler(log_txt_path)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
logger.addHandler(fh)
|
||||
logger.info(f"Writing logs to {log_txt_path}.")
|
||||
|
||||
# Use Weights & Biases, if enabled.
|
||||
if args.enable_wandb:
|
||||
global wandb
|
||||
import wandb
|
||||
|
||||
wandb.init(sync_tensorboard=True)
|
||||
|
||||
# Get list of text and list of label (integers) from disk.
|
||||
train_text, train_labels, eval_text, eval_labels = dataset_from_args(args)
|
||||
|
||||
# Filter labels
|
||||
if args.allowed_labels:
|
||||
logger.info(f"Filtering samples with labels outside of {args.allowed_labels}.")
|
||||
final_train_text, final_train_labels = [], []
|
||||
for text, label in zip(train_text, train_labels):
|
||||
if label in args.allowed_labels:
|
||||
final_train_text.append(text)
|
||||
final_train_labels.append(label)
|
||||
logger.info(
|
||||
f"Filtered {len(train_text)} train samples to {len(final_train_text)} points."
|
||||
)
|
||||
train_text, train_labels = final_train_text, final_train_labels
|
||||
final_eval_text, final_eval_labels = [], []
|
||||
for text, label in zip(eval_text, eval_labels):
|
||||
if label in args.allowed_labels:
|
||||
final_eval_text.append(text)
|
||||
final_eval_labels.append(label)
|
||||
logger.info(
|
||||
f"Filtered {len(eval_text)} dev samples to {len(final_eval_text)} points."
|
||||
)
|
||||
eval_text, eval_labels = final_eval_text, final_eval_labels
|
||||
|
||||
label_id_len = len(train_labels)
|
||||
label_set = set(train_labels)
|
||||
args.num_labels = len(label_set)
|
||||
logger.info(
|
||||
f"Loaded dataset. Found: {args.num_labels} labels: ({sorted(label_set)})"
|
||||
)
|
||||
|
||||
if isinstance(train_labels[0], float):
|
||||
# TODO come up with a more sophisticated scheme for when to do regression
|
||||
logger.warn(f"Detected float labels. Doing regression.")
|
||||
args.num_labels = 1
|
||||
args.do_regression = True
|
||||
else:
|
||||
args.do_regression = False
|
||||
|
||||
train_examples_len = len(train_text)
|
||||
|
||||
if len(train_labels) != train_examples_len:
|
||||
raise ValueError(
|
||||
f"Number of train examples ({train_examples_len}) does not match number of labels ({len(train_labels)})"
|
||||
)
|
||||
if len(eval_labels) != len(eval_text):
|
||||
raise ValueError(
|
||||
f"Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_labels)})"
|
||||
)
|
||||
|
||||
model = model_from_args(args, args.num_labels)
|
||||
tokenizer = model.tokenizer
|
||||
|
||||
logger.info(f"Tokenizing training data. (len: {train_examples_len})")
|
||||
train_text_ids = batch_encode(tokenizer, train_text)
|
||||
logger.info(f"Tokenizing eval data (len: {len(eval_labels)})")
|
||||
eval_text_ids = batch_encode(tokenizer, eval_text)
|
||||
load_time = time.time()
|
||||
logger.info(f"Loaded data and tokenized in {load_time-start_time}s")
|
||||
|
||||
# multi-gpu training
|
||||
if num_gpus > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
logger.info(f"Training model across {num_gpus} GPUs")
|
||||
|
||||
num_train_optimization_steps = (
|
||||
int(train_examples_len / args.batch_size / args.grad_accum_steps)
|
||||
* args.num_train_epochs
|
||||
)
|
||||
|
||||
if args.model == "lstm" or args.model == "cnn":
|
||||
need_grad = lambda x: x.requires_grad
|
||||
optimizer = torch.optim.Adam(
|
||||
filter(need_grad, model.parameters()), lr=args.learning_rate
|
||||
)
|
||||
scheduler = None
|
||||
else:
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": 0.01,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = transformers.optimization.AdamW(
|
||||
optimizer_grouped_parameters, lr=args.learning_rate
|
||||
)
|
||||
|
||||
scheduler = transformers.optimization.get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=args.warmup_proportion,
|
||||
num_training_steps=num_train_optimization_steps,
|
||||
)
|
||||
|
||||
# Start Tensorboard and log hyperparams.
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
tb_writer = SummaryWriter(args.output_dir)
|
||||
|
||||
def is_writable_type(obj):
|
||||
for ok_type in [bool, int, str, float]:
|
||||
if isinstance(obj, ok_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
|
||||
|
||||
tb_writer.add_hparams(args_dict, {})
|
||||
|
||||
# Start training
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f"\tNum examples = {train_examples_len}")
|
||||
logger.info(f"\tBatch size = {args.batch_size}")
|
||||
logger.info(f"\tMax sequence length = {args.max_length}")
|
||||
logger.info(f"\tNum steps = {num_train_optimization_steps}")
|
||||
logger.info(f"\tNum epochs = {args.num_train_epochs}")
|
||||
logger.info(f"\tLearning rate = {args.learning_rate}")
|
||||
|
||||
train_input_ids = np.array(train_text_ids)
|
||||
train_labels = np.array(train_labels)
|
||||
train_data = list((ids, label) for ids, label in zip(train_input_ids, train_labels))
|
||||
train_sampler = RandomSampler(train_data)
|
||||
train_dataloader = DataLoader(
|
||||
train_data, sampler=train_sampler, batch_size=args.batch_size
|
||||
)
|
||||
|
||||
eval_input_ids = np.array(eval_text_ids)
|
||||
eval_labels = np.array(eval_labels)
|
||||
eval_data = list((ids, label) for ids, label in zip(eval_input_ids, eval_labels))
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_data, sampler=eval_sampler, batch_size=args.batch_size
|
||||
)
|
||||
|
||||
def get_eval_score():
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
logits = []
|
||||
labels = []
|
||||
for input_ids, batch_labels in eval_dataloader:
|
||||
if isinstance(input_ids, dict):
|
||||
## HACK: dataloader collates dict backwards. This is a temporary
|
||||
# workaround to get ids in the right shape
|
||||
input_ids = {
|
||||
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
|
||||
}
|
||||
batch_labels = batch_labels.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
batch_logits = textattack.shared.utils.model_predict(model, input_ids)
|
||||
|
||||
logits.extend(batch_logits.cpu().squeeze().tolist())
|
||||
labels.extend(batch_labels)
|
||||
|
||||
model.train()
|
||||
logits = torch.tensor(logits)
|
||||
labels = torch.tensor(labels)
|
||||
|
||||
if args.do_regression:
|
||||
pearson_correlation, pearson_p_value = scipy.stats.pearsonr(logits, labels)
|
||||
return pearson_correlation
|
||||
else:
|
||||
preds = logits.argmax(dim=1)
|
||||
correct = (preds == labels).sum()
|
||||
return float(correct) / len(labels)
|
||||
|
||||
def save_model():
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Only save the model itself
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, args.weights_name)
|
||||
output_config_file = os.path.join(args.output_dir, args.config_name)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
try:
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
except AttributeError:
|
||||
# no config
|
||||
pass
|
||||
|
||||
global_step = 0
|
||||
tr_loss = 0
|
||||
|
||||
def save_model_checkpoint():
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
# Take care of distributed/parallel training
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info(f"Checkpoint saved to {output_dir}.")
|
||||
|
||||
model.train()
|
||||
args.best_eval_score = 0
|
||||
args.best_eval_score_epoch = 0
|
||||
args.epochs_since_best_eval_score = 0
|
||||
|
||||
def loss_backward(loss):
|
||||
if num_gpus > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.grad_accum_steps > 1:
|
||||
loss = loss / args.grad_accum_steps
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
if args.do_regression:
|
||||
# TODO integrate with textattack `metrics` package
|
||||
loss_fct = torch.nn.MSELoss()
|
||||
else:
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
|
||||
for epoch in tqdm.trange(
|
||||
int(args.num_train_epochs), desc="Epoch", position=0, leave=False
|
||||
):
|
||||
prog_bar = tqdm.tqdm(
|
||||
train_dataloader, desc="Iteration", position=1, leave=False
|
||||
)
|
||||
for step, batch in enumerate(prog_bar):
|
||||
input_ids, labels = batch
|
||||
labels = labels.to(device)
|
||||
if isinstance(input_ids, dict):
|
||||
## HACK: dataloader collates dict backwards. This is a temporary
|
||||
# workaround to get ids in the right shape
|
||||
input_ids = {
|
||||
k: torch.stack(v).T.to(device) for k, v in input_ids.items()
|
||||
}
|
||||
logits = textattack.shared.utils.model_predict(model, input_ids)
|
||||
|
||||
if args.do_regression:
|
||||
# TODO integrate with textattack `metrics` package
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
loss = loss_backward(loss)
|
||||
tr_loss += loss.item()
|
||||
|
||||
if global_step % args.tb_writer_step == 0:
|
||||
tb_writer.add_scalar("loss", loss.item(), global_step)
|
||||
if scheduler is not None:
|
||||
tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
|
||||
else:
|
||||
tb_writer.add_scalar("lr", args.learning_rate, global_step)
|
||||
if global_step > 0:
|
||||
prog_bar.set_description(f"Loss {tr_loss/global_step}")
|
||||
if (step + 1) % args.grad_accum_steps == 0:
|
||||
optimizer.step()
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
# Save model checkpoint to file.
|
||||
if (
|
||||
global_step > 0
|
||||
and (args.checkpoint_steps > 0)
|
||||
and (global_step % args.checkpoint_steps) == 0
|
||||
):
|
||||
save_model_checkpoint()
|
||||
|
||||
# Inc step counter.
|
||||
global_step += 1
|
||||
|
||||
# Check accuracy after each epoch.
|
||||
eval_score = get_eval_score()
|
||||
tb_writer.add_scalar("epoch_eval_score", eval_score, global_step)
|
||||
|
||||
if args.checkpoint_every_epoch:
|
||||
save_model_checkpoint()
|
||||
|
||||
logger.info(
|
||||
f"Eval {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
|
||||
)
|
||||
if eval_score > args.best_eval_score:
|
||||
args.best_eval_score = eval_score
|
||||
args.best_eval_score_epoch = epoch
|
||||
args.epochs_since_best_eval_score = 0
|
||||
save_model()
|
||||
logger.info(f"Best acc found. Saved model to {args.output_dir}.")
|
||||
else:
|
||||
args.epochs_since_best_eval_score += 1
|
||||
if (args.early_stopping_epochs > 0) and (
|
||||
args.epochs_since_best_eval_score > args.early_stopping_epochs
|
||||
):
|
||||
logger.info(
|
||||
f"Stopping early since it's been {args.early_stopping_epochs} epochs since validation score last improved"
|
||||
)
|
||||
break
|
||||
|
||||
# read the saved model and report its eval performance
|
||||
model.load_state_dict(torch.load(os.path.join(args.output_dir, args.weights_name)))
|
||||
eval_score = get_eval_score()
|
||||
logger.info(
|
||||
f"Eval of saved model {'pearson correlation' if args.do_regression else 'accuracy'}: {eval_score*100}%"
|
||||
)
|
||||
|
||||
# end of training, save tokenizer
|
||||
try:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
logger.info(f"Saved tokenizer {tokenizer} to {args.output_dir}.")
|
||||
except AttributeError:
|
||||
logger.warn(
|
||||
f"Error: could not save tokenizer {tokenizer} to {args.output_dir}."
|
||||
)
|
||||
|
||||
# Save a little readme with model info
|
||||
write_readme(args, args.best_eval_score, args.best_eval_score_epoch)
|
||||
|
||||
# Save args to file
|
||||
args_save_path = os.path.join(args.output_dir, "train_args.json")
|
||||
final_args_dict = {k: v for k, v in vars(args).items() if is_writable_type(v)}
|
||||
with open(args_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(final_args_dict, indent=2) + "\n")
|
||||
logger.info(f"Wrote training args to {args_save_path}.")
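A hedged worked example of the step arithmetic used above (the numbers are illustrative): with 25,000 training examples, batch_size=128, grad_accum_steps=2 and num_train_epochs=3, the scheduler horizon is int(25000 / 128 / 2) * 3 = 97 * 3 = 291 optimizer steps, and each optimizer step sees an effective batch of 128 * 2 = 256 examples, since loss_backward divides the loss by grad_accum_steps and optimizer.step() only runs every grad_accum_steps batches.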
|
||||
153
textattack/commands/train_model/train_args_helpers.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import os
|
||||
|
||||
import textattack
|
||||
|
||||
logger = textattack.shared.logger
|
||||
|
||||
|
||||
def prepare_dataset_for_training(nlp_dataset):
|
||||
""" Changes an `nlp` dataset into the proper format for tokenization. """
|
||||
|
||||
def prepare_example_dict(ex):
|
||||
""" Returns the values in order corresponding to the data.
|
||||
|
||||
ex:
|
||||
'Some text input'
|
||||
or in the case of multi-sequence inputs:
|
||||
('The premise', 'the hypothesis',)
|
||||
etc.
|
||||
"""
|
||||
values = list(ex.values())
|
||||
if len(values) == 1:
|
||||
return values[0]
|
||||
return tuple(values)
|
||||
|
||||
text, outputs = zip(*((prepare_example_dict(x[0]), x[1]) for x in nlp_dataset))
|
||||
return list(text), list(outputs)
|
||||
|
||||
|
||||
def dataset_from_args(args):
|
||||
""" Returns a tuple of ``HuggingFaceNLPDataset`` for the train and test
|
||||
datasets for ``args.dataset``.
|
||||
"""
|
||||
dataset_args = args.dataset.split(":")
|
||||
# TODO `HuggingFaceNLPDataset` -> `HuggingFaceDataset`
|
||||
if args.dataset_train_split:
|
||||
train_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split=args.dataset_train_split
|
||||
)
|
||||
else:
|
||||
try:
|
||||
train_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split="train"
|
||||
)
|
||||
args.dataset_train_split = "train"
|
||||
except KeyError:
|
||||
raise KeyError(f"Error: no `train` split found in `{args.dataset}` dataset")
|
||||
train_text, train_labels = prepare_dataset_for_training(train_dataset)
|
||||
|
||||
if args.dataset_dev_split:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split=args.dataset_dev_split
|
||||
)
|
||||
else:
|
||||
# try common dev split names
|
||||
try:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split="dev"
|
||||
)
|
||||
args.dataset_dev_split = "dev"
|
||||
except KeyError:
|
||||
try:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split="eval"
|
||||
)
|
||||
args.dataset_dev_split = "eval"
|
||||
except KeyError:
|
||||
try:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split="validation"
|
||||
)
|
||||
args.dataset_dev_split = "validation"
|
||||
except KeyError:
|
||||
try:
|
||||
eval_dataset = textattack.datasets.HuggingFaceNLPDataset(
|
||||
*dataset_args, split="test"
|
||||
)
|
||||
args.dataset_dev_split = "test"
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"Could not find `dev`, `eval`, `validation`, or `test` split in dataset {args.dataset}."
|
||||
)
|
||||
eval_text, eval_labels = prepare_dataset_for_training(eval_dataset)
|
||||
|
||||
return train_text, train_labels, eval_text, eval_labels
|
||||
|
||||
|
||||
def model_from_args(args, num_labels):
|
||||
if args.model == "lstm":
|
||||
textattack.shared.logger.info("Loading textattack model: LSTMForClassification")
|
||||
model = textattack.models.helpers.LSTMForClassification(
|
||||
max_seq_length=args.max_length,
|
||||
num_labels=num_labels,
|
||||
emb_layer_trainable=False,
|
||||
)
|
||||
elif args.model == "cnn":
|
||||
textattack.shared.logger.info(
|
||||
"Loading textattack model: WordCNNForClassification"
|
||||
)
|
||||
model = textattack.models.helpers.WordCNNForClassification(
|
||||
max_seq_length=args.max_length,
|
||||
num_labels=num_labels,
|
||||
emb_layer_trainable=False,
|
||||
)
|
||||
else:
|
||||
import transformers
|
||||
|
||||
textattack.shared.logger.info(
|
||||
f"Loading transformers AutoModelForSequenceClassification: {args.model}"
|
||||
)
|
||||
config = transformers.AutoConfig.from_pretrained(
|
||||
args.model, num_labels=num_labels, finetuning_task=args.dataset
|
||||
)
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model, config=config,
|
||||
)
|
||||
tokenizer = textattack.models.tokenizers.AutoTokenizer(
|
||||
args.model, use_fast=True, max_length=args.max_length
|
||||
)
|
||||
setattr(model, "tokenizer", tokenizer)
|
||||
|
||||
model = model.to(textattack.shared.utils.device)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def write_readme(args, best_eval_score, best_eval_score_epoch):
|
||||
# Save args to file
|
||||
readme_save_path = os.path.join(args.output_dir, "README.md")
|
||||
dataset_name = args.dataset.split(":")[0] if ":" in args.dataset else args.dataset
|
||||
task_name = "regression" if args.do_regression else "classification"
|
||||
loss_func = "mean squared error" if args.do_regression else "cross-entropy"
|
||||
metric_name = "pearson correlation" if args.do_regression else "accuracy"
|
||||
epoch_info = f"{best_eval_score_epoch} epoch" + (
|
||||
"s" if best_eval_score_epoch > 1 else ""
|
||||
)
|
||||
readme_text = f"""
|
||||
## {args.model} fine-tuned with TextAttack on the {dataset_name} dataset
|
||||
|
||||
This `{args.model}` model was fine-tuned for sequence classification using TextAttack
|
||||
and the {dataset_name} dataset loaded using the `nlp` library. The model was fine-tuned
|
||||
for {args.num_train_epochs} epochs with a batch size of {args.batch_size}, a learning
|
||||
rate of {args.learning_rate}, and a maximum sequence length of {args.max_length}.
|
||||
Since this was a {task_name} task, the model was trained with a {loss_func} loss function.
|
||||
The best score the model achieved on this task was {best_eval_score}, as measured by the
|
||||
eval set {metric_name}, found after {epoch_info}.
|
||||
|
||||
For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
|
||||
|
||||
"""
|
||||
|
||||
with open(readme_save_path, "w", encoding="utf-8") as f:
|
||||
f.write(readme_text.strip() + "\n")
|
||||
logger.info(f"Wrote README to {readme_save_path}.")
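A hedged sketch of what prepare_example_dict (defined at the top of this file) returns for typical `nlp` examples; the sample texts are illustrative:

prepare_example_dict({"text": "a fine film"})
# -> "a fine film"
prepare_example_dict({"premise": "A man eats.", "hypothesis": "Someone is eating."})
# -> ("A man eats.", "Someone is eating.")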
|
||||
155
textattack/commands/train_model/train_model_command.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from textattack.commands import TextAttackCommand
|
||||
|
||||
|
||||
class TrainModelCommand(TextAttackCommand):
|
||||
"""
|
||||
The TextAttack train module:
|
||||
|
||||
A command line parser to train a model from user specifications.
|
||||
"""
|
||||
|
||||
def run(self, args):
|
||||
|
||||
date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(
|
||||
current_dir, os.pardir, os.pardir, os.pardir, "outputs", "training"
|
||||
)
|
||||
outputs_dir = os.path.normpath(outputs_dir)
|
||||
|
||||
args.output_dir = os.path.join(
|
||||
outputs_dir, f"{args.model}-{args.dataset}-{date_now}/"
|
||||
)
|
||||
|
||||
from .run_training import train_model
|
||||
|
||||
train_model(args)
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(main_parser: ArgumentParser):
|
||||
parser = main_parser.add_parser(
|
||||
"train",
|
||||
help="train a model for sequence classification",
|
||||
formatter_class=ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, required=True, help="directory of model to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
default="yelp",
|
||||
help="dataset for training; will be loaded from "
|
||||
"the `nlp` library. If the dataset has a subset, separate with a colon. "
|
||||
" ex: `glue:sst2` or `rotten_tomatoes`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-train-split",
|
||||
"--train-split",
|
||||
type=str,
|
||||
default="",
|
||||
help="train dataset split, if non-standard "
|
||||
"(can automatically detect 'train')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-dev-split",
|
||||
"--dataset-eval-split",
|
||||
"--dev-split",
|
||||
type=str,
|
||||
default="",
|
||||
help="val dataset split, if non-standard "
|
||||
"(can automatically detect 'dev', 'validation', 'eval')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tb-writer-step",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of steps before writing to tensorboard",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-steps",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="save model after this many steps (-1 for no checkpointing)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-every_epoch",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="save model checkpoint after each epoch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-train-epochs",
|
||||
"--epochs",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Total number of epochs to train for",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allowed-labels",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="Labels allowed for training (examples with other labels will be discarded)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early-stopping-epochs",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of epochs to wait for validation score to improve"
|
||||
" before stopping early (-1 for no early stopping)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=128, help="Batch size for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum length of a sequence (anything beyond this will "
|
||||
"be truncated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate",
|
||||
"--lr",
|
||||
type=float,
|
||||
default=2e-5,
|
||||
help="Learning rate for Adam Optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-accum-steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of steps to accumulate gradients before optimizing, "
|
||||
"advancing scheduler, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-proportion",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Warmup proportion for linear scheduling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-name",
|
||||
type=str,
|
||||
default="config.json",
|
||||
help="Filename to save BERT config as",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weights-name",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
help="Filename to save model weights as",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-wandb",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="log metrics to Weights & Biases",
|
||||
)
|
||||
parser.set_defaults(func=TrainModelCommand())
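An illustrative invocation using the flags registered above (the model, dataset, and hyperparameter values are examples, not defaults from this diff): textattack train --model lstm --dataset rotten_tomatoes --epochs 25 --batch-size 128 --learning-rate 1e-3 --allowed-labels 0 1. Per the run() method above, results are then written under outputs/training/lstm-rotten_tomatoes-<timestamp>/.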
|
||||
@@ -1,7 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import textattack
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class Constraint:
|
||||
class Constraint(ABC):
|
||||
"""
|
||||
An abstract class that represents constraints on adversarial text examples.
|
||||
Constraints evaluate whether transformations from a ``AttackedText`` to another
|
||||
@@ -68,9 +71,9 @@ class Constraint:
|
||||
current_text: The current ``AttackedText``.
|
||||
original_text: The original ``AttackedText`` from which the attack began.
|
||||
"""
|
||||
if not isinstance(transformed_text, AttackedText):
|
||||
if not isinstance(transformed_text, textattack.shared.AttackedText):
|
||||
raise TypeError("transformed_text must be of type AttackedText")
|
||||
if not isinstance(current_text, AttackedText):
|
||||
if not isinstance(current_text, textattack.shared.AttackedText):
|
||||
raise TypeError("current_text must be of type AttackedText")
|
||||
|
||||
try:
|
||||
@@ -86,6 +89,7 @@ class Constraint:
|
||||
transformed_text, current_text, original_text=original_text
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _check_constraint(self, transformed_text, current_text, original_text=None):
|
||||
"""
|
||||
Returns True if the constraint is fulfilled, False otherwise. Must be overridden by
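Making Constraint an ABC with _check_constraint marked @abstractmethod means every concrete constraint must now override that method. A minimal hedged sketch of such a subclass (the class itself is hypothetical, not part of TextAttack):

from textattack.constraints import Constraint

class NoLongerThan(Constraint):
    """ Hypothetical constraint: reject transformations longer than `max_words` words. """

    def __init__(self, max_words):
        self.max_words = max_words

    def _check_constraint(self, transformed_text, current_text, original_text=None):
        return len(transformed_text.words) <= self.max_words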
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from .language_model_constraint import LanguageModelConstraint
|
||||
|
||||
from .google_language_model import Google1BillionWordsLanguageModel
|
||||
from .gpt2 import GPT2
|
||||
from .language_model_constraint import LanguageModelConstraint
|
||||
from .learning_to_write import LearningToWriteLanguageModel
|
||||
|
||||
@@ -51,12 +51,15 @@ class GoogleLanguageModel(Constraint):
|
||||
|
||||
def get_probs(current_text, transformed_texts):
|
||||
word_swap_index = current_text.first_word_diff_index(transformed_texts[0])
|
||||
if word_swap_index is None:
|
||||
return []
|
||||
|
||||
prefix = current_text.words[word_swap_index - 1]
|
||||
swapped_words = np.array(
|
||||
[t.words[word_swap_index] for t in transformed_texts]
|
||||
)
|
||||
if self.print_step:
|
||||
print(prefix, swapped_words, suffix)
|
||||
print(prefix, swapped_words)
|
||||
probs = self.lm.get_words_probs(prefix, swapped_words)
|
||||
return probs
|
||||
|
||||
@@ -104,6 +107,11 @@ class GoogleLanguageModel(Constraint):
|
||||
|
||||
return [transformed_texts[i] for i in max_el_indices]
|
||||
|
||||
def _check_constraint(self, transformed_text, current_text, original_text=None):
|
||||
return self._check_constraint_many(
|
||||
[transformed_text], current_text, original_text=original_text
|
||||
)
|
||||
|
||||
def __call__(self, x, x_adv):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class GPT2(LanguageModelConstraint):
|
||||
|
||||
probs = []
|
||||
for attacked_text in text_list:
|
||||
nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
|
||||
next_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
|
||||
next_word_prob = predictions[0, -1, next_word_ids[0]]
|
||||
probs.append(next_word_prob)
|
||||
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
|
||||
|
||||
class LanguageModelConstraint(Constraint):
|
||||
class LanguageModelConstraint(Constraint, ABC):
|
||||
"""
|
||||
Determines if two sentences have a swapped word that has a similar
|
||||
probability according to a language model.
|
||||
|
||||
Args:
|
||||
max_log_prob_diff (float): the maximum difference in log-probability
|
||||
between x and x_adv
|
||||
max_log_prob_diff (float): the maximum decrease in log-probability
|
||||
in swapped words from x to x_adv
|
||||
compare_against_original (bool): whether to compare against the original
|
||||
text or the most recent
|
||||
"""
|
||||
|
||||
def __init__(self, max_log_prob_diff=None):
|
||||
def __init__(self, max_log_prob_diff=None, compare_against_original=True):
|
||||
if max_log_prob_diff is None:
|
||||
raise ValueError("Must set max_log_prob_diff")
|
||||
self.max_log_prob_diff = max_log_prob_diff
|
||||
self.compare_against_original = compare_against_original
|
||||
|
||||
@abstractmethod
|
||||
def get_log_probs_at_index(self, text_list, word_index):
|
||||
""" Gets the log-probability of items in `text_list` at index
|
||||
`word_index` according to a language model.
|
||||
@@ -27,6 +29,9 @@ class LanguageModelConstraint(Constraint):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _check_constraint(self, transformed_text, current_text, original_text=None):
|
||||
if self.compare_against_original:
|
||||
current_text = original_text
|
||||
|
||||
try:
|
||||
indices = transformed_text.attack_attrs["newly_modified_indices"]
|
||||
except KeyError:
|
||||
@@ -41,9 +46,7 @@ class LanguageModelConstraint(Constraint):
|
||||
f"Error: get_log_probs_at_index returned {len(probs)} values for 2 inputs"
|
||||
)
|
||||
cur_prob, transformed_prob = probs
|
||||
if self.max_log_prob_diff is None:
|
||||
cur_prob, transformed_prob = math.log(p1), math.log(p2)
|
||||
if abs(cur_prob - transformed_prob) > self.max_log_prob_diff:
|
||||
if transformed_prob <= cur_prob - self.max_log_prob_diff:
|
||||
return False
|
||||
|
||||
return True
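A hedged numeric illustration of the rewritten check above: with max_log_prob_diff = 2.0, a swap whose log-probability drops from -4.0 to -5.5 passes, because -5.5 > -4.0 - 2.0, while a drop from -4.0 to -6.5 is rejected, because -6.5 <= -4.0 - 2.0. The constraint now only penalizes decreases in log-probability beyond the threshold; an increase always passes.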
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .learning_to_write import LearningToWriteLanguageModel
|
||||
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.functional import log_softmax
|
||||
|
||||
import textattack
|
||||
|
||||
|
||||
class AdaptiveSoftmax(nn.Module):
|
||||
def __init__(self, input_size, cutoffs, scale_down=4):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.cutoffs = cutoffs
|
||||
self.output_size = cutoffs[0] + len(cutoffs) - 1
|
||||
self.head = nn.Linear(input_size, self.output_size)
|
||||
self.tail = nn.ModuleList()
|
||||
for i in range(len(cutoffs) - 1):
|
||||
seq = nn.Sequential(
|
||||
nn.Linear(input_size, input_size // scale_down, False),
|
||||
nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False),
|
||||
)
|
||||
self.tail.append(seq)
|
||||
|
||||
def reset(self, init=0.1):
|
||||
self.head.weight.data.uniform_(-init, init)
|
||||
for tail in self.tail:
|
||||
for layer in tail:
|
||||
layer.weight.data.uniform_(-init, init)
|
||||
|
||||
def set_target(self, target):
|
||||
self.id = []
|
||||
for i in range(len(self.cutoffs) - 1):
|
||||
mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
|
||||
if mask.sum() > 0:
|
||||
self.id.append(Variable(mask.float().nonzero().squeeze(1)))
|
||||
else:
|
||||
self.id.append(None)
|
||||
|
||||
def forward(self, inp):
|
||||
assert len(inp.size()) == 2
|
||||
output = [self.head(inp)]
|
||||
for i in range(len(self.id)):
|
||||
if self.id[i] is not None:
|
||||
output.append(self.tail[i](inp.index_select(0, self.id[i])))
|
||||
else:
|
||||
output.append(None)
|
||||
return output
|
||||
|
||||
def log_prob(self, inp):
|
||||
assert len(inp.size()) == 2
|
||||
head_out = self.head(inp)
|
||||
n = inp.size(0)
|
||||
prob = torch.zeros(n, self.cutoffs[-1]).to(textattack.shared.utils.device)
|
||||
lsm_head = log_softmax(head_out, dim=head_out.dim() - 1)
|
||||
prob.narrow(1, 0, self.output_size).add_(
|
||||
lsm_head.narrow(1, 0, self.output_size).data
|
||||
)
|
||||
for i in range(len(self.tail)):
|
||||
pos = self.cutoffs[i]
|
||||
i_size = self.cutoffs[i + 1] - pos
|
||||
buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1)
|
||||
buff = buff.expand(n, i_size)
|
||||
temp = self.tail[i](inp)
|
||||
lsm_tail = log_softmax(temp, dim=temp.dim() - 1)
|
||||
prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data)
|
||||
return prob
|
||||
|
||||
|
||||
class AdaptiveLoss(nn.Module):
|
||||
def __init__(self, cutoffs):
|
||||
super().__init__()
|
||||
self.cutoffs = cutoffs
|
||||
self.criterions = nn.ModuleList()
|
||||
for i in self.cutoffs:
|
||||
self.criterions.append(nn.CrossEntropyLoss(size_average=False))
|
||||
|
||||
def reset(self):
|
||||
for criterion in self.criterions:
|
||||
criterion.zero_grad()
|
||||
|
||||
def remap_target(self, target):
|
||||
new_target = [target.clone()]
|
||||
for i in range(len(self.cutoffs) - 1):
|
||||
mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
|
||||
|
||||
if mask.sum() > 0:
|
||||
new_target[0][mask] = self.cutoffs[0] + i
|
||||
new_target.append(target[mask].add(-self.cutoffs[i]))
|
||||
else:
|
||||
new_target.append(None)
|
||||
return new_target
|
||||
|
||||
def forward(self, inp, target):
|
||||
n = inp[0].size(0)
|
||||
target = self.remap_target(target.data)
|
||||
loss = 0
|
||||
for i in range(len(inp)):
|
||||
if inp[i] is not None:
|
||||
assert target[i].min() >= 0 and target[i].max() <= inp[i].size(1)
|
||||
criterion = self.criterions[i]
|
||||
loss += criterion(inp[i], Variable(target[i]))
|
||||
loss /= n
|
||||
return loss
|
||||
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchfile
|
||||
|
||||
from .rnn_model import RNNModel
|
||||
|
||||
|
||||
class QueryHandler:
|
||||
def __init__(self, model, word_to_idx, mapto, device):
|
||||
self.model = model
|
||||
self.word_to_idx = word_to_idx
|
||||
self.mapto = mapto
|
||||
self.device = device
|
||||
|
||||
def query(self, sentences, swapped_words, batch_size=32):
|
||||
""" Since we don't filter prefixes for OOV ahead of time, it's possible that
|
||||
some of them will have different lengths. When this is the case,
|
||||
we can't do RNN prediction in batch.
|
||||
|
||||
This method _tries_ to do prediction in batch, and, when it fails,
|
||||
just does prediction sequentially and concatenates all of the results.
|
||||
"""
|
||||
try:
|
||||
return self.try_query(sentences, swapped_words, batch_size=batch_size)
|
||||
except Exception:
|
||||
probs = []
|
||||
for s, w in zip(sentences, swapped_words):
|
||||
probs.append(self.try_query([s], [w], batch_size=1)[0])
|
||||
return probs
|
||||
|
||||
def try_query(self, sentences, swapped_words, batch_size=32):
|
||||
# TODO use caching
|
||||
sentence_length = len(sentences[0])
|
||||
if any(len(s) != sentence_length for s in sentences):
|
||||
raise ValueError("Only same length batches are allowed")
|
||||
|
||||
log_probs = []
|
||||
for start in range(0, len(sentences), batch_size):
|
||||
swapped_words_batch = swapped_words[
|
||||
start : min(len(sentences), start + batch_size)
|
||||
]
|
||||
batch = sentences[start : min(len(sentences), start + batch_size)]
|
||||
raw_idx_list = [[] for i in range(sentence_length + 1)]
|
||||
for i, s in enumerate(batch):
|
||||
s = [word for word in s if word in self.word_to_idx]
|
||||
words = ["<S>"] + s
|
||||
word_idxs = [self.word_to_idx[w] for w in words]
|
||||
for t in range(sentence_length + 1):
|
||||
if t < len(word_idxs):
|
||||
raw_idx_list[t].append(word_idxs[t])
|
||||
orig_num_idxs = len(raw_idx_list)
|
||||
raw_idx_list = [x for x in raw_idx_list if len(x)]
|
||||
num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
|
||||
all_raw_idxs = torch.tensor(
|
||||
raw_idx_list, device=self.device, dtype=torch.long
|
||||
)
|
||||
word_idxs = self.mapto[all_raw_idxs]
|
||||
hidden = self.model.init_hidden(len(batch))
|
||||
source = word_idxs[:-1, :]
|
||||
target = word_idxs[1:, :]
|
||||
decode, hidden = self.model(source, hidden)
|
||||
decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
|
||||
for i in range(len(batch)):
|
||||
if swapped_words_batch[i] not in self.word_to_idx:
|
||||
log_probs.append(float("-inf"))
|
||||
else:
|
||||
log_probs.append(
|
||||
sum(
|
||||
[
|
||||
decode[t, i, target[t, i]].item()
|
||||
for t in range(sentence_length - num_idxs_dropped)
|
||||
]
|
||||
)
|
||||
)
|
||||
return log_probs
|
||||
|
||||
@staticmethod
|
||||
def load_model(lm_folder_path, device):
|
||||
word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7"))
|
||||
word_map = [w.decode("utf-8") for w in word_map]
|
||||
word_to_idx = {w: i for i, w in enumerate(word_map)}
|
||||
word_freq = torchfile.load(
|
||||
os.path.join(os.path.join(lm_folder_path, "word_freq.th7"))
|
||||
)
|
||||
mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)
|
||||
|
||||
model_file = open(os.path.join(lm_folder_path, "lm-state-dict.pt"), "rb")
|
||||
|
||||
model = RNNModel(
|
||||
"GRU",
|
||||
793471,
|
||||
256,
|
||||
2048,
|
||||
1,
|
||||
[4200, 35000, 180000, 793471],
|
||||
dropout=0.01,
|
||||
proj=True,
|
||||
lm1b=True,
|
||||
)
|
||||
|
||||
model.load_state_dict(torch.load(model_file, map_location=device))
|
||||
model.full = True # Use real softmax--important!
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model_file.close()
|
||||
return QueryHandler(model, word_to_idx, mapto, device)
|
||||
|
||||
|
||||
def util_reverse(item):
|
||||
new_item = np.zeros(len(item))
|
||||
for idx, val in enumerate(item):
|
||||
new_item[val] = idx
|
||||
return new_item
|
||||
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
|
||||
import textattack
|
||||
from textattack.constraints.grammaticality.language_models import (
|
||||
LanguageModelConstraint,
|
||||
)
|
||||
|
||||
from .language_model_helpers import QueryHandler
|
||||
|
||||
|
||||
class LearningToWriteLanguageModel(LanguageModelConstraint):
|
||||
""" A constraint based on the L2W language model.
|
||||
|
||||
The RNN-based language model from "Learning to Write With Cooperative
|
||||
Discriminators" (Holtzman et al, 2018).
|
||||
|
||||
https://arxiv.org/pdf/1805.06087.pdf
|
||||
|
||||
https://github.com/windweller/l2w
|
||||
|
||||
|
||||
Reused by Jia et al., 2019, as a substitution for the Google 1-billion
|
||||
words language model (in a revised version of the attack of Alzantot et
|
||||
al., 2018).
|
||||
|
||||
https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689
|
||||
"""
|
||||
|
||||
CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"
|
||||
|
||||
def __init__(self, window_size=5, **kwargs):
|
||||
self.window_size = window_size
|
||||
lm_folder_path = textattack.shared.utils.download_if_needed(
|
||||
LearningToWriteLanguageModel.CACHE_PATH
|
||||
)
|
||||
self.query_handler = QueryHandler.load_model(
|
||||
lm_folder_path, textattack.shared.utils.device
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_log_probs_at_index(self, text_list, word_index):
|
||||
""" Gets the log probability of the word at index `word_index` according
|
||||
to the language model.
|
||||
"""
|
||||
queries = []
|
||||
query_words = []
|
||||
for attacked_text in text_list:
|
||||
word = attacked_text.words[word_index]
|
||||
window_text = attacked_text.text_window_around_index(
|
||||
word_index, self.window_size
|
||||
)
|
||||
query = textattack.shared.utils.words_from_text(window_text)
|
||||
queries.append(query)
|
||||
query_words.append(word)
|
||||
log_probs = self.query_handler.query(queries, query_words)
|
||||
return torch.tensor(log_probs)
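A hedged usage sketch of the constraint defined above; the max_log_prob_diff value is an example, and constructing the constraint downloads the cached L2W model on first use:

from textattack.constraints.grammaticality.language_models import LearningToWriteLanguageModel

constraint = LearningToWriteLanguageModel(window_size=6, max_log_prob_diff=5.0)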
|
||||
@@ -0,0 +1,101 @@
|
||||
from torch import nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .adaptive_softmax import AdaptiveSoftmax
|
||||
|
||||
|
||||
class RNNModel(nn.Module):
|
||||
"""Container module with an encoder, a recurrent module, and a decoder. Based on official pytorch examples"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rnn_type,
|
||||
ntoken,
|
||||
ninp,
|
||||
nhid,
|
||||
nlayers,
|
||||
cutoffs,
|
||||
proj=False,
|
||||
dropout=0.5,
|
||||
tie_weights=False,
|
||||
lm1b=False,
|
||||
):
|
||||
super(RNNModel, self).__init__()
|
||||
self.drop = nn.Dropout(dropout)
|
||||
self.encoder = nn.Embedding(ntoken, ninp)
|
||||
|
||||
self.lm1b = lm1b
|
||||
|
||||
if rnn_type == "GRU":
|
||||
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
|
||||
else:
|
||||
try:
|
||||
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"""An invalid option for `--model` was supplied,
|
||||
options are ['GRU', 'RNN_TANH' or 'RNN_RELU']"""
|
||||
)
|
||||
self.rnn = nn.RNN(
|
||||
ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
|
||||
)
|
||||
|
||||
self.proj = proj
|
||||
|
||||
if ninp != nhid and proj:
|
||||
self.proj_layer = nn.Linear(nhid, ninp)
|
||||
|
||||
# if tie_weights:
|
||||
# if nhid != ninp and not proj:
|
||||
# raise ValueError('When using the tied flag, nhid must be equal to emsize')
|
||||
# self.decoder = nn.Linear(ninp, ntoken)
|
||||
# self.decoder.weight = self.encoder.weight
|
||||
# else:
|
||||
# if nhid != ninp and not proj:
|
||||
# if not lm1b:
|
||||
# self.decoder = nn.Linear(nhid, ntoken)
|
||||
# else:
|
||||
# self.decoder = adapt_loss
|
||||
# else:
|
||||
# self.decoder = nn.Linear(ninp, ntoken)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
self.rnn_type = rnn_type
|
||||
self.nhid = nhid
|
||||
self.nlayers = nlayers
|
||||
|
||||
if proj:
|
||||
self.softmax = AdaptiveSoftmax(ninp, cutoffs)
|
||||
else:
|
||||
self.softmax = AdaptiveSoftmax(nhid, cutoffs)
|
||||
|
||||
self.full = False
|
||||
|
||||
def init_weights(self):
|
||||
initrange = 0.1
|
||||
self.encoder.weight.data.uniform_(-initrange, initrange)
|
||||
# self.decoder.bias.data.fill_(0)
|
||||
# self.decoder.weight.data.uniform_(-initrange, initrange)
|
||||
|
||||
def forward(self, input, hidden):
|
||||
emb = self.drop(self.encoder(input))
|
||||
output, hidden = self.rnn(emb, hidden)
|
||||
output = self.drop(output)
|
||||
|
||||
if "proj" in vars(self):
|
||||
if self.proj:
|
||||
output = self.proj_layer(output)
|
||||
|
||||
output = output.view(output.size(0) * output.size(1), output.size(2))
|
||||
|
||||
if self.full:
|
||||
decode = self.softmax.log_prob(output)
|
||||
else:
|
||||
decode = self.softmax(output)
|
||||
|
||||
return decode, hidden
|
||||
|
||||
def init_hidden(self, bsz):
|
||||
weight = next(self.parameters()).data
|
||||
return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
|
||||
@@ -48,9 +48,9 @@ class PartOfSpeech(Constraint):
|
||||
)
|
||||
|
||||
if self.tagger_type == "flair":
|
||||
word_list, pos_list = zip_flair_result(
|
||||
self._flair_pos_tagger.predict(context_key)[0]
|
||||
)
|
||||
context_key_sentence = Sentence(context_key)
|
||||
self._flair_pos_tagger.predict(context_key_sentence)
|
||||
word_list, pos_list = zip_flair_result(context_key_sentence)
|
||||
|
||||
self._pos_tag_cache[context_key] = (word_list, pos_list)
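This hunk adapts to the flair API in which SequenceTagger.predict() tags a Sentence in place rather than returning a list of results. A hedged sketch of that usage (the model name and tag-type string are assumptions, not taken from this diff):

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("pos-fast")
sentence = Sentence("TextAttack perturbs words")
tagger.predict(sentence)  # mutates `sentence` in place
word_list = [token.text for token in sentence]
pos_list = [token.get_tag("pos").value for token in sentence]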
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from .pre_transformation_constraint import PreTransformationConstraint
|
||||
from .stopword_modification import StopwordModification
|
||||
from .repeat_modification import RepeatModification
|
||||
|
||||
from .input_column_modification import InputColumnModification
|
||||
from .max_word_index_modification import MaxWordIndexModification
|
||||
from .min_word_length import MinWordLength
|
||||
from .repeat_modification import RepeatModification
|
||||
from .stopword_modification import StopwordModification
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from textattack.constraints.pre_transformation import PreTransformationConstraint
|
||||
|
||||
|
||||
class InputColumnModification(PreTransformationConstraint):
|
||||
"""
|
||||
A constraint disallowing the modification of words within a specific input
|
||||
column.
|
||||
|
||||
For example, can prevent modification of 'premise' during
|
||||
entailment.
|
||||
"""
|
||||
|
||||
def __init__(self, matching_column_labels, columns_to_ignore):
|
||||
self.matching_column_labels = matching_column_labels
|
||||
self.columns_to_ignore = columns_to_ignore
|
||||
|
||||
def _get_modifiable_indices(self, current_text):
|
||||
""" Returns the word indices in current_text which are able to be
|
||||
deleted.
|
||||
|
||||
If ``current_text.column_labels`` doesn't match
|
||||
``self.matching_column_labels``, do nothing, and allow all words
|
||||
to be modified.
|
||||
|
||||
If it does match, only allow words to be modified if they are not
|
||||
in columns from ``columns_to_ignore``.
|
||||
"""
|
||||
if current_text.column_labels != self.matching_column_labels:
|
||||
return set(range(len(current_text.words)))
|
||||
|
||||
idx = 0
|
||||
indices_to_modify = set()
|
||||
for column, words in zip(
|
||||
current_text.column_labels, current_text.words_per_input
|
||||
):
|
||||
num_words = len(words)
|
||||
if column not in self.columns_to_ignore:
|
||||
indices_to_modify |= set(range(idx, idx + num_words))
|
||||
idx += num_words
|
||||
return indices_to_modify
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return ["matching_column_labels", "columns_to_ignore"]
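A hedged usage sketch of the new pre-transformation constraint: for an entailment dataset whose columns are labelled premise and hypothesis, the following keeps the premise fixed and only lets the hypothesis be perturbed (the column labels are illustrative):

from textattack.constraints.pre_transformation import InputColumnModification

constraint = InputColumnModification(
    matching_column_labels=["premise", "hypothesis"], columns_to_ignore={"premise"}
)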
|
||||
@@ -1,5 +1,4 @@
|
||||
from textattack.constraints.pre_transformation import PreTransformationConstraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class MaxWordIndexModification(PreTransformationConstraint):
|
||||
@@ -14,3 +13,5 @@ class MaxWordIndexModification(PreTransformationConstraint):
|
||||
""" Returns the word indices in current_text which are able to be deleted """
|
||||
return set(range(min(self.max_length, len(current_text.words))))
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return ["max_length"]
|
||||
|
||||
@@ -2,7 +2,6 @@ from textattack.constraints.pre_transformation import PreTransformationConstrain
|
||||
|
||||
|
||||
class MinWordLength(PreTransformationConstraint):
|
||||
|
||||
def __init__(self, min_length):
|
||||
self.min_length = min_length
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from textattack.constraints import Constraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class PreTransformationConstraint(Constraint):
|
||||
class PreTransformationConstraint(Constraint, ABC):
|
||||
"""
|
||||
An abstract class that represents constraints which are applied before
|
||||
the transformation. These restrict which words are allowed to be modified
|
||||
@@ -23,6 +24,7 @@ class PreTransformationConstraint(Constraint):
|
||||
return set(range(len(current_text.words)))
|
||||
return self._get_modifiable_indices(current_text)
|
||||
|
||||
@abstractmethod
|
||||
def _get_modifiable_indices(current_text):
|
||||
"""
|
||||
Returns the word indices in ``current_text`` which are able to be modified.
|
||||
@@ -32,3 +34,8 @@ class PreTransformationConstraint(Constraint):
|
||||
current_text: The ``AttackedText`` input to consider.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _check_constraint(self):
|
||||
raise RuntimeError(
|
||||
"PreTransformationConstraints do not support `_check_constraint()`."
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from textattack.constraints.pre_transformation import PreTransformationConstraint
|
||||
from textattack.shared.utils import default_class_repr
|
||||
|
||||
|
||||
class RepeatModification(PreTransformationConstraint):
|
||||
|
||||
@@ -8,12 +8,11 @@
|
||||
"""
|
||||
This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn as nn
|
||||
|
||||
|
||||
class InferSentModel(nn.Module):
|
||||
|
||||
@@ -21,7 +21,8 @@ class SentenceEncoder(Constraint):
        compare_with_original (bool): Whether to compare `x_adv` to the previous `x_adv`
            or the original `x`.
        window_size (int): The number of words to use in the similarity
            comparison.
            comparison. `None` indicates no windowing (encoding is based on the
            full input).
    """

    def __init__(
@@ -38,6 +39,9 @@ class SentenceEncoder(Constraint):
        self.window_size = window_size
        self.skip_text_shorter_than_window = skip_text_shorter_than_window

        if not self.window_size:
            self.window_size = float("inf")

        if metric == "cosine":
            self.sim_metric = torch.nn.CosineSimilarity(dim=1)
        elif metric == "angular":
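The ``angular`` branch is cut off in this hunk. Angular similarity is commonly derived from cosine similarity as 1 - arccos(cos_sim)/pi; a standalone sketch of such a metric (the exact helper used by this class is not visible here):

import math

import torch


def angular_sim(emb_a: torch.Tensor, emb_b: torch.Tensor) -> torch.Tensor:
    """Angular similarity in [0, 1], derived from cosine similarity."""
    cos_sim = torch.nn.CosineSimilarity(dim=1)(emb_a, emb_b)
    return 1 - torch.acos(cos_sim.clamp(-1.0, 1.0)) / math.pi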
@@ -68,7 +72,9 @@ class SentenceEncoder(Constraint):
            The similarity between the starting and transformed text using the metric.
        """
        try:
            modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"]))
            modified_index = next(
                iter(transformed_text.attack_attrs["newly_modified_indices"])
            )
        except KeyError:
            raise KeyError(
                "Cannot apply sentence encoder constraint without `newly_modified_indices`"
@@ -107,7 +113,7 @@ class SentenceEncoder(Constraint):
            ``transformed_texts``. If ``transformed_texts`` is empty,
            an empty tensor is returned
        """
        # Return an empty tensor if x_adv_list is empty.
        # Return an empty tensor if transformed_texts is empty.
        # This prevents us from calling .repeat(x, 0), which throws an
        # error on machines with multiple GPUs (pytorch 1.2).
        if len(transformed_texts) == 0:
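The comment change above keeps the long-standing guard: when there is nothing to compare, an empty tensor is returned up front so the later ``.repeat(x, 0)`` call (which can error on multi-GPU PyTorch 1.2) is never reached. In sketch form, with the concrete return value assumed:

# Illustrative guard, mirroring the logic above.
if len(transformed_texts) == 0:
    return torch.tensor([])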
@@ -137,9 +143,9 @@ class SentenceEncoder(Constraint):
                )
            )
            embeddings = self.encode(starting_text_windows + transformed_text_windows)
            starting_embeddings = torch.tensor(embeddings[: len(transformed_texts)]).to(
                utils.device
            )
            if not isinstance(embeddings, torch.Tensor):
                embeddings = torch.tensor(embeddings)
            starting_embeddings = embeddings[: len(transformed_texts)].to(utils.device)
            transformed_embeddings = torch.tensor(
                embeddings[len(transformed_texts) :]
            ).to(utils.device)
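The new lines above normalize whatever ``self.encode`` returns (a list, a numpy array, or a ``torch.Tensor``) into a single tensor before it is split into the starting and transformed halves. The same pattern, condensed into a helper for clarity (hypothetical, not part of the diff):

import torch


def as_device_tensor(embeddings, device):
    """Coerce encoder output to a torch.Tensor on the target device."""
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.tensor(embeddings)
    return embeddings.to(device)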
@@ -147,18 +153,12 @@ class SentenceEncoder(Constraint):
            starting_raw_text = starting_text.text
            transformed_raw_texts = [t.text for t in transformed_texts]
            embeddings = self.encode([starting_raw_text] + transformed_raw_texts)
            if isinstance(embeddings[0], torch.Tensor):
                starting_embedding = embeddings[0].to(utils.device)
            else:
                # If the embedding is not yet a tensor, make it one.
                starting_embedding = torch.tensor(embeddings[0]).to(utils.device)
            if not isinstance(embeddings, torch.Tensor):
                embeddings = torch.tensor(embeddings)

            if isinstance(embeddings, list):
                # If `encode` did not return a Tensor of all embeddings, combine
                # into a tensor.
                transformed_embeddings = torch.stack(embeddings[1:]).to(utils.device)
            else:
                transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.device)
            starting_embedding = embeddings[0].to(utils.device)

            transformed_embeddings = embeddings[1:].to(utils.device)

        # Repeat original embedding to size of perturbed embedding.
        starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat(
@@ -209,7 +209,7 @@ class SentenceEncoder(Constraint):
                "Must provide original text when compare_with_original is true."
            )
        else:
            scores = self._sim_score(current_text, transformed_texts)
            scores = self._sim_score(current_text, transformed_text)
        transformed_text.attack_attrs["similarity_score"] = score
        return score >= self.threshold
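A hedged end-to-end usage sketch: a concrete ``SentenceEncoder`` subclass (e.g. TextAttack's Universal Sentence Encoder constraint) accepts a transformation only when the similarity score meets the threshold checked above. Parameter names follow the docstring in this diff; the exact subclass signature may differ:

from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder

use_constraint = UniversalSentenceEncoder(
    threshold=0.8,  # reject candidates whose similarity score falls below 0.8
    metric="angular",
    compare_with_original=True,
    window_size=15,
    skip_text_shorter_than_window=True,
)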