mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge from master

Jin Yong Yoo · 2020-07-06 10:50:17 -04:00
110 changed files with 1622 additions and 648 deletions

.github/workflows/check-formatting.yml (new file)

@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Formatting with black & isort

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools wheel
        pip install black flake8 isort # Testing packages
        python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
        pip install -e .
    - name: Check code format with black and isort
      run: |
        make lint

.github/workflows/make-docs.yml (new file)

@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Build documentation with Sphinx

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        sudo sed -i 's/azure\.//' /etc/apt/sources.list # workaround for flaky pandoc install
        sudo apt-get update # from here https://github.com/actions/virtual-environments/issues/675
        sudo apt-get install pandoc -o Acquire::Retries=3 # install pandoc
        python -m pip install --upgrade pip setuptools wheel # update python
        pip install ipython --upgrade # needed for Github for whatever reason
        python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
        pip install -e . ".[dev]" # This should install all packages for development
    - name: Build docs with Sphinx and check for errors
      run: |
        sphinx-build -b html docs docs/_build/html -W


@@ -1,7 +1,7 @@
 # This workflows will upload a Python Package using Twine when a release is created
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

-name: Upload Python Package
+name: Upload Python Package to PyPI

 on:
   release:

@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

-name: Github PyTest
+name: Test with PyTest

 on:
   push:
@@ -26,13 +26,9 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip setuptools wheel
-        pip install black isort pytest pytest-xdist
+        pip install pytest pytest-xdist # Testing packages
         python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
         pip install -e .
-    - name: Check code format with black and isort
-      run: |
-        black . --check
-        isort --check-only --recursive tests textattack
     - name: Test with pytest
       run: |
        pytest tests -vx --dist=loadfile -n auto


@@ -179,11 +179,25 @@ Follow these steps to start contributing:
    $ git push -u origin a-descriptive-name-for-my-changes
    ```

-6. Once you are satisfied (**and the checklist below is happy too**), go to the
+6. Add documentation. Our docs are in the `docs/` folder. Thanks to
+   `sphinx-automodule`, this should just be two lines. Our docs will
+   automatically generate from the comments you added to your code. If you're
+   adding an attack recipe, add a reference in `attack_recipes.rst`. If you're
+   adding a transformation, add a reference in `transformation.rst`, etc.
+
+   You can build the docs and view the updates using `make docs`. If you're
+   adding a tutorial or something where you want to update the docs multiple
+   times, you can run `make docs-auto`. This will run a server using
+   `sphinx-autobuild` that should automatically reload whenever you change
+   a file.
+
+7. Once you are satisfied (**and the checklist below is happy too**), go to the
    webpage of your fork on GitHub. Click on 'Pull request' to send your changes
    to the project maintainers for review.

-7. It's ok if maintainers ask you for changes. It happens to core contributors
+8. It's ok if maintainers ask you for changes. It happens to core contributors
    too! So everyone can see the changes in the Pull request, work in your local
    branch and push the changes to your fork. They will automatically appear in
    the pull request.


@@ -3,9 +3,10 @@ format: FORCE ## Run black and isort (rewriting files)
 	isort --atomic --recursive tests textattack

-lint: FORCE ## Run black (in check mode)
+lint: FORCE ## Run black, isort, flake8 (in check mode)
 	black . --check
 	isort --check-only --recursive tests textattack
+	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8

 test: FORCE ## Run tests using pytest
 	python -m pytest --dist=loadfile -n auto
@@ -13,10 +14,13 @@ test: FORCE ## Run tests using pytest
 docs: FORCE ## Build docs using Sphinx.
 	sphinx-build -b html docs docs/_build/html

+docs-check: FORCE ## Build docs using Sphinx; if there is an error, exit with an error code (instead of warning & continuing).
+	sphinx-build -b html docs docs/_build/html -W
+
 docs-auto: FORCE ## Build docs using Sphinx and run hotreload server using Sphinx autobuild.
 	sphinx-autobuild docs docs/_build/html -H 0.0.0.0 -p 8765

-all: format lint test ## Format, lint, and test.
+all: format lint docs-check test ## Format, lint, and test.

 .PHONY: help


@@ -18,7 +18,7 @@
   </a>
 </p>

-<img src="https://github.com/jxmorris12/jxmorris12.github.io/blob/master/files/render1593035135238.gif?raw=true" style="display: block; margin: 0 auto;" />
+<img src="http://jackxmorris.com/files/textattack.gif" alt="TextAttack Demo GIF" style="display: block; margin: 0 auto;" />

 ## About

@@ -97,18 +97,19 @@ We include attack recipes which implement attacks from the literature. You can l

 To run an attack recipe: `textattack attack --recipe [recipe_name]`

-These attacks are for classification tasks, like sentiment classification and entailment:
+Attacks on classification tasks, like sentiment classification and entailment:
 - **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
 - **bae**: BERT masked language model transformation attack from (["BAE: BERT-based Adversarial Examples for Text Classification" (Garg & Ramakrishnan, 2019)](https://arxiv.org/abs/2004.01970)).
 - **bert-attack**: BERT masked language model transformation attack with subword replacements (["BERT-ATTACK: Adversarial Attack Against BERT Using BERT" (Li et al., 2020)](https://arxiv.org/abs/2004.09984)).
 - **deepwordbug**: Greedy replace-1 scoring and multi-transformation character-swap attack (["Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers" (Gao et al., 2018)](https://arxiv.org/abs/1801.04354)).
 - **hotflip**: Beam search and gradient-based word swap (["HotFlip: White-Box Adversarial Examples for Text Classification" (Ebrahimi et al., 2017)](https://arxiv.org/abs/1712.06751)).
+- **input-reduction**: Reducing the input while maintaining the prediction through word importance ranking (["Pathologies of Neural Models Make Interpretation Difficult" (Feng et al., 2018)](https://arxiv.org/pdf/1804.07781.pdf)).
 - **kuleshov**: Greedy search and counterfitted embedding swap (["Adversarial Examples for Natural Language Classification Problems" (Kuleshov et al., 2018)](https://openreview.net/pdf?id=r1QZ3zbAZ)).
 - **pwws**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)).
 - **textbugger**: Greedy attack with word importance ranking and character-based swaps (["TextBugger: Generating Adversarial Text Against Real-world Applications" (Li et al., 2018)](https://arxiv.org/abs/1812.05271)).
 - **textfooler**: Greedy attack with word importance ranking and counter-fitted embedding swap (["Is Bert Really Robust?" (Jin et al., 2019)](https://arxiv.org/abs/1907.11932)).

-The final is for sequence-to-sequence models:
+Attacks on sequence-to-sequence models:
 - **seq2sick**: Greedy attack with goal of changing every word in the output translation. Currently implemented as black-box with plans to change to white-box as done in paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples" (Cheng et al., 2018)](https://arxiv.org/abs/1803.01128)).

 #### Recipe Usage Examples

@@ -122,7 +123,7 @@ textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examp
 *seq2sick (black-box) against T5 fine-tuned for English-German translation:*

 ```bash
-textattack attack --recipe seq2sick --model t5-en2de --num-examples 100
+textattack attack --model t5-en-de --recipe seq2sick --num-examples 100
 ```

 ### Augmenting Text

@@ -284,7 +285,7 @@ The `attack_one` method in an `Attack` takes as input an `AttackedText`, and out

 ### Goal Functions

-A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
+A `GoalFunction` takes as input an `AttackedText` object, scores it, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.

 ### Constraints

@@ -303,6 +304,8 @@ A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a fi
 We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. TextAttack is currently in an "alpha" stage in which we are working to improve its capabilities and design.

+See [CONTRIBUTING.md](https://github.com/QData/TextAttack/blob/master/CONTRIBUTING.md) for detailed information on contributing.
+
 ## Citing TextAttack

 If you use TextAttack for your research, please cite [TextAttack: A Framework for Adversarial Attacks in Natural Language Processing](https://arxiv.org/abs/2005.05909).
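The four components the README hunk describes (goal function, constraints, transformation, search method) compose directly into an `Attack`. A minimal sketch, assuming a TextAttack-wrapped `model` and using only classes that appear in the recipe modules touched by this commit (the no-argument `GreedyWordSwapWIR()` default is an assumption):

```python
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding


def build_simple_attack(model):
    # When has the attack succeeded? Here: any misclassification.
    goal_function = UntargetedClassification(model)
    # Which perturbations are allowed? Don't swap stopwords or the same word twice.
    constraints = [RepeatModification(), StopwordModification()]
    # How is the text perturbed? Counter-fitted embedding swaps.
    transformation = WordSwapEmbedding(max_candidates=50)
    # How do we search the space of perturbations?
    search_method = GreedyWordSwapWIR()
    return Attack(goal_function, constraints, transformation, search_method)
```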


@@ -6,67 +6,81 @@ We provide a number of pre-built attack recipes. To run an attack recipe, run::

    textattack attack --recipe [recipe_name]

 Alzantot Genetic Algorithm (Generating Natural Language Adversarial Examples)
-###########
+###################################################################################
 .. automodule:: textattack.attack_recipes.genetic_algorithm_alzantot_2018
    :members:

 Faster Alzantot Genetic Algorithm (Certified Robustness to Adversarial Word Substitutions)
-###########
+##############################################################################################
 .. automodule:: textattack.attack_recipes.faster_genetic_algorithm_jia_2019
    :members:

 BAE (BAE: BERT-Based Adversarial Examples)
-############
-.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018
+#############################################
+.. automodule:: textattack.attack_recipes.bae_garg_2019
+   :members:

 BERT-Attack: (BERT-Attack: Adversarial Attack Against BERT Using BERT)
-############
-.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018
+#########################################################################
+.. automodule:: textattack.attack_recipes.bert_attack_li_2020
+   :members:

 DeepWordBug (Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers)
-############
+######################################################################################################
 .. automodule:: textattack.attack_recipes.deepwordbug_gao_2018
    :members:

 HotFlip (HotFlip: White-Box Adversarial Examples for Text Classification)
-###########
+##############################################################################
+.. automodule:: textattack.attack_recipes.hotflip_ebrahimi_2017
+   :members:
+
+Input Reduction
+################
 .. automodule:: textattack.attack_recipes.input_reduction_feng_2018
    :members:

 Kuleshov (Adversarial Examples for Natural Language Classification Problems)
-###########
+##############################################################################
 .. automodule:: textattack.attack_recipes.kuleshov_2017
    :members:

+Particle Swarm Optimization (Word-level Textual Adversarial Attacking as Combinatorial Optimization)
+#####################################################################################################
+.. automodule:: textattack.attack_recipes.PSO_zang_2020
+   :members:
+
 PWWS (Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency)
-###########
+###################################################################################################
 .. automodule:: textattack.attack_recipes.pwws_ren_2019
    :members:

 Seq2Sick (Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples)
-###########
+#########################################################################################################
 .. automodule:: textattack.attack_recipes.seq2sick_cheng_2018_blackbox
    :members:

 TextFooler (Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment)
-###########
+########################################################################################################################
 .. automodule:: textattack.attack_recipes.textfooler_jin_2019
    :members:

 TextBugger (TextBugger: Generating Adversarial Text Against Real-world Applications)
-###########
+########################################################################################
 .. automodule:: textattack.attack_recipes.textbugger_li_2018
    :members:


@@ -85,7 +85,7 @@ GPT-2
    :members:

 "Learning To Write" Language Model
-*******
+************************************
 .. automodule:: textattack.constraints.grammaticality.language_models.learning_to_write.learning_to_write
    :members:

@@ -142,7 +142,7 @@ Maximum Words Perturbed
 .. _pre_transformation:

 Pre-Transformation
-----------
+-------------------------

 Pre-transformation constraints determine if a transformation is valid based on
 only the original input and the position of the replacement. These constraints
@@ -151,7 +151,7 @@ constraints can prevent search methods from swapping words at the same index
 twice, or from replacing stopwords.

 Pre-Transformation Constraint
-########################
+###############################
 .. automodule:: textattack.constraints.pre_transformation.pre_transformation_constraint
    :special-members: __call__
    :private-members:

@@ -166,3 +166,13 @@ Repeat Modification
 ########################
 .. automodule:: textattack.constraints.pre_transformation.repeat_modification
    :members:
+
+Input Column Modification
+#############################
+.. automodule:: textattack.constraints.pre_transformation.input_column_modification
+   :members:
+
+Max Word Index Modification
+###############################
+.. automodule:: textattack.constraints.pre_transformation.max_word_index_modification
+   :members:


@@ -69,7 +69,7 @@ Word Swap by Random Character Insertion
    :members:

 Word Swap by Random Character Substitution
----------------------------------------
+-------------------------------------------
 .. automodule:: textattack.transformations.word_swap_random_character_substitution
    :members:


@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
 author = "UVA QData Lab"

 # The full version, including alpha/beta/rc tags
-release = "0.1.2"
+release = "0.1.5"

 # Set master doc to `index.rst`.
 master_doc = "index"


@@ -6,19 +6,10 @@ Datasets
    :members:
    :private-members:

-Classification
-###############
-.. automodule:: textattack.datasets.classification.classification_dataset
+.. automodule:: textattack.datasets.huggingface_nlp_dataset
    :members:

-Entailment
-############
-.. automodule:: textattack.datasets.entailment.entailment_dataset
+.. automodule:: textattack.datasets.translation.ted_multi
    :members:
-
-Translation
-#############
-.. automodule:: textattack.datasets.translation.translation_datasets
-   :members:


@@ -11,7 +11,7 @@ We split models up into two broad categories:

 **Classification models:**

-:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using transformers_.
+:ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using ``transformers``.

 :ref:`LSTM`: a standard LSTM fine-tuned on various datasets.

@@ -20,30 +20,29 @@ We split models up into two broad categories:

 **Text-to-text models:**

-:ref:`T5`: ``T5`` fine-tuned on various datasets using transformers_.
+:ref:`T5`: ``T5`` fine-tuned on various datasets using ``transformers``.

+.. _BERT:
+
 BERT
 ********
-.. _BERT:
 .. automodule:: textattack.models.helpers.bert_for_classification
    :members:

-LSTM
-*******
 .. _LSTM:

+LSTM
+*******
 .. automodule:: textattack.models.helpers.lstm_for_classification
    :members:

-Word-CNN
-************
 .. _CNN:

+Word-CNN
+************
 .. automodule:: textattack.models.helpers.word_cnn_for_classification
    :members:


@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# The TextAttack🐙 ecosystem: search, transformations, and constraints\n",
+    "# The TextAttack ecosystem: search, transformations, and constraints\n",
     "\n",
     "An attack in TextAttack consists of four parts.\n",
     "\n",
@@ -31,9 +31,9 @@
     "This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n",
     "\n",
     "\n",
-    "### Banana word swap 🍌\n",
+    "### Banana word swap \n",
     "\n",
-    "As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'."
+    "As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'. 🍌"
    ]
   },
   {
@@ -308,9 +308,9 @@
    "collapsed": true
   },
   "source": [
-    "### Conclusion 🍌\n",
+    "### Conclusion \n",
     "\n",
-    "We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a single \"banana\", while others needed up to 17 \"banana\" substitutions to change the class score. Wow!"
+    "We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a couple \"banana\"s, while others needed up to 17 \"banana\" substitutions to change the class score. Wow! 🍌"
    ]
  }
 ],
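The transformation this notebook builds is small enough to sketch inline. A minimal version, assuming only the abstract `WordSwap` class described in the cell above:

```python
from textattack.transformations import WordSwap


class BananaWordSwap(WordSwap):
    """Transforms an input by replacing any word with 'banana'."""

    def _get_replacement_words(self, word):
        # Every word gets the same single candidate replacement.
        return ["banana"]
```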


@@ -35,7 +35,6 @@ TextAttack has some other features that make it a pleasure to use:

    Installation <quickstart/installation>
-   Overview <quickstart/overview>
    Command-Line Usage <quickstart/command_line_usage>
    Tutorial 0: TextAttack End-To-End (Train, Eval, Attack) <examples/0_End_to_End.ipynb>
    Tutorial 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>

@@ -76,7 +75,7 @@ TextAttack has some other features that make it a pleasure to use:
    :hidden:
    :caption: Miscellaneous

+   misc/attacked_text
    misc/checkpoints
    misc/loggers
    misc/validators
-   misc/tokenized_text


@@ -0,0 +1,6 @@
===================
Attacked Text
===================

.. automodule:: textattack.shared.attacked_text
   :members:


@@ -1,6 +0,0 @@
-===================
-Tokenized Text
-===================
-
-.. automodule:: textattack.shared.tokenized_text
-   :members:


@@ -22,7 +22,7 @@ examples corresponding to the proper columns.

 For example, given the following as `examples.csv`:

-```csv
+```
 "text",label
 "the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1
 "the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1

@@ -40,7 +40,7 @@ will augment the `text` column with four swaps per augmentation, twice as many a
 output CSV. (All of this will be saved to `augment.csv` by default.)

 After augmentation, here are the contents of `augment.csv`:

-```csv
+```
 text,label
 "the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1
 "the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1

@@ -132,4 +132,4 @@ see some basic information about the dataset.
 For example, use `textattack peek-dataset --dataset-from-nlp glue:mrpc` to see
 information about the MRPC dataset (from the GLUE set of datasets). This will
 print statistics like the number of labels, average number of words, etc.
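The same augmentation the CLI performs is also available from Python. A rough sketch, assuming the `EmbeddingAugmenter` wrapper exposed by `textattack.augmentation` (the class name and its defaults are an assumption and may differ by version):

```python
from textattack.augmentation import EmbeddingAugmenter

# Swap words for nearby counter-fitted embeddings, as in the CSV example above.
augmenter = EmbeddingAugmenter()
print(augmenter.augment("the rock is destined to be the 21st century's new conan"))
```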


@@ -9,7 +9,7 @@ numpy
 pandas>=1.0.1
 scikit-learn
 scipy==1.4.1
-sentence_transformers
+sentence_transformers==0.2.6.1
 torch
 transformers>=3
 tensorflow>=2


@@ -1,9 +1,3 @@
-[flake8]
-ignore = E203, E266, E501, W503
-max-line-length = 120
-per-file-ignores = __init__.py:F401
-mypy_config = mypy.ini
-
 [isort]
 line_length = 88
 skip = __init__.py
@@ -14,3 +8,11 @@ multi_line_output = 3
 include_trailing_comma = True
 use_parentheses = True
 force_grid_wrap = 0
+
+[flake8]
+exclude = .git,__pycache__,wandb,build,dist
+ignore = E203, E266, E501, W503, D203
+max-complexity = 10
+max-line-length = 120
+mypy_config = mypy.ini
+per-file-ignores = __init__.py:F401


@@ -7,9 +7,12 @@ with open("README.md", "r") as fh:
     long_description = fh.read()

 extras = {}
+# Packages required for installing docs.
+extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
+# Packages required for formatting code & running tests.
+extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"]
 # For developers, install development tools along with all optional dependencies.
-extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"]
+extras["dev"] = extras["docs"] + extras["test"]

 setuptools.setup(
     name="textattack",
@@ -27,9 +30,9 @@ setuptools.setup(
             "build*",
             "docs*",
             "dist*",
+            "examples*",
             "outputs*",
             "tests*",
+            "local_test*",
             "wandb*",
         ]
     ),


@@ -28,6 +28,10 @@
     )
     (3): RepeatModification
     (4): StopwordModification
+    (5): InputColumnModification(
+      (matching_column_labels): ['premise', 'hypothesis']
+      (columns_to_ignore): {'premise'}
+    )
   (is_black_box): True
 )
 /.*/


@@ -0,0 +1,57 @@
/.*/Attack(
  (search_method): GreedySearch
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 15
    (embedding_type): paragramcf
  )
  (constraints):
    (0): MaxWordsPerturbed(
      (max_percent): 0.5
    )
    (1): ThoughtVector(
      (embedding_type): paragramcf
      (metric): max_euclidean
      (threshold): -0.2
      (compare_with_original): False
      (window_size): inf
      (skip_text_shorter_than_window): False
    )
    (2): GPT2(
      (max_log_prob_diff): 2.0
    )
    (3): RepeatModification
    (4): StopwordModification
  (is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> Negative (69%)

it 's a charming and often affecting journey .

it 's a loveable and ordinarily affecting journey .

--------------------------------------------- Result 2 ---------------------------------------------
Negative (83%) --> Positive (90%)

unflinchingly bleak and desperate

unflinchingly bleak and desperation

+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 2      |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 25.0%  |
| Average num. words per input: | 6.0    |
| Avg num queries:              | 48.5   |
+-------------------------------+--------+


@@ -39,7 +39,7 @@
 | Original accuracy:            | 100.0% |
 | Accuracy under attack:        | 0.0%   |
 | Attack success rate:          | 100.0% |
-| Average perturbed word %:     | 45.39% |
-| Average num. words per input: | 11.5   |
-| Avg num queries:              | 26.5   |
+| Average perturbed word %:     | 45.0%  |
+| Average num. words per input: | 12.0   |
+| Avg num queries:              | 27.0   |
 +-------------------------------+--------+


@@ -0,0 +1,66 @@
/.*/Attack(
  (search_method): GeneticAlgorithm(
    (pop_size): 60
    (max_iters): 20
    (temp): 0.3
    (give_up_if_no_improvement): False
  )
  (goal_function): UntargetedClassification
  (transformation): WordSwapEmbedding(
    (max_candidates): 8
    (embedding_type): paragramcf
  )
  (constraints):
    (0): MaxWordsPerturbed(
      (max_percent): 0.2
    )
    (1): WordEmbeddingDistance(
      (embedding_type): paragramcf
      (max_mse_dist): 0.5
      (cased): False
      (include_unknown_words): True
    )
    (2): LearningToWriteLanguageModel(
      (max_log_prob_diff): 5.0
    )
    (3): RepeatModification
    (4): StopwordModification
  (is_black_box): True
)
/.*/
--------------------------------------------- Result 1 ---------------------------------------------
Positive (100%) --> Negative (73%)

this kind of hands-on storytelling is ultimately what makes shanghai ghetto move beyond a good , dry , reliable textbook and what allows it to rank with its worthy predecessors .

this kind of hands-on tale is ultimately what makes shanghai ghetto move beyond a good , secs , credible textbook and what allows it to rank with its worthy predecessors .

--------------------------------------------- Result 2 ---------------------------------------------
Positive (80%) --> Negative (97%)

making such a tragedy the backdrop to a love story risks trivializing it , though chouraqui no doubt intended the film to affirm love's power to help people endure almost unimaginable horror .

making such a tragedy the backdrop to a love story risks trivializing it , notwithstanding chouraqui no doubt intended the film to affirm love's power to help people endure almost incomprehensible horror .

--------------------------------------------- Result 3 ---------------------------------------------
Positive (92%) --> [FAILED]

grown-up quibbles are beside the point here . the little girls understand , and mccracken knows that's all that matters .

+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 2      |
| Number of failed attacks:     | 1      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 33.33% |
| Attack success rate:          | 66.67% |
| Average perturbed word %:     | 8.58%  |
| Average num. words per input: | 25.67  |
| Avg num queries:              |/.*/|
+-------------------------------+--------+


@@ -24,11 +24,11 @@
 )
 /.*/
 --------------------------------------------- Result 1 ---------------------------------------------
-Positive (100%) --> Negative (88%)
+Positive (100%) --> Negative (98%)

 exposing the ways we fool ourselves is one hour photo's real strength .

-exposing the ways we fool ourselves is one hour pictures's real kraft .
+exposing the ways we fool ourselves is one stopwatch photo's real kraft .

 --------------------------------------------- Result 2 ---------------------------------------------

@@ -65,7 +65,7 @@ mostly , [goldbacher] just lets her complicated characters be haphazard
 | Original accuracy:            | 100.0% |
 | Accuracy under attack:        | 0.0%   |
 | Attack success rate:          | 100.0% |
-| Average perturbed word %:     | 17.13% |
-| Average num. words per input: | 17.0   |
-| Avg num queries:              | 46.0   |
+| Average perturbed word %:     | 17.56% |
+| Average num. words per input: | 16.25  |
+| Avg num queries:              | 45.5   |
 +-------------------------------+--------+


@@ -35,6 +35,6 @@ I went into "Night of the Hunted" not knowing what to expect at all. I was reall
 | Accuracy under attack:        | 0.0%   |
 | Attack success rate:          | 100.0% |
 | Average perturbed word %:     | 0.62%  |
-| Average num. words per input: | 165.0  |
-| Avg num queries:              | 167.0  |
+| Average num. words per input: | 164.0  |
+| Avg num queries:              | 166.0  |
 +-------------------------------+--------+


@@ -27,11 +27,9 @@
 )
 /.*/
 --------------------------------------------- Result 1 ---------------------------------------------
-Positive (97%) --> Negative (100%)
+Positive (97%) --> [FAILED]

 the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill .

-the story gives ample opportunity for large-scale action and suspense , which director shekhar unwilling supplies with tremendous skill .
-
 --------------------------------------------- Result 2 ---------------------------------------------

@@ -58,13 +56,13 @@ throws in enough clever and unexpected twists to make the formula feel fresh .
 +-------------------------------+--------+
 | Attack Results                |        |
 +-------------------------------+--------+
-| Number of successful attacks: | 2      |
-| Number of failed attacks:     | 2      |
+| Number of successful attacks: | 1      |
+| Number of failed attacks:     | 3      |
 | Number of skipped attacks:    | 0      |
 | Original accuracy:            | 100.0% |
-| Accuracy under attack:        | 50.0%  |
-| Attack success rate:          | 50.0%  |
-| Average perturbed word %:     | 4.55%  |
-| Average num. words per input: | 15.75  |
-| Avg num queries:              | 1.5    |
+| Accuracy under attack:        | 75.0%  |
+| Attack success rate:          | 25.0%  |
+| Average perturbed word %:     | 3.85%  |
+| Average num. words per input: | 15.5   |
+| Avg num queries:              | 1.25   |
 +-------------------------------+--------+


@@ -23,10 +23,13 @@
 --------------------------------------------- Result 2 ---------------------------------------------

-Neutral (100%) --> [FAILED]
+Neutral (100%) --> Entailment (56%)

 Premise: This site includes a list of all award winners and a searchable database of Government Executive articles.
 Hypothesis: The Government Executive articles housed on the website are not able to be searched.

+Premise: This site includes a list of all award winners and a searchable database of Government Executive articles.
+Hypothesis: The Government Executive articles housed on the website are not able-bodied to be searched.
+
 --------------------------------------------- Result 3 ---------------------------------------------

@@ -43,13 +46,13 @@
 +-------------------------------+--------+
 | Attack Results                |        |
 +-------------------------------+--------+
-| Number of successful attacks: | 1      |
-| Number of failed attacks:     | 1      |
+| Number of successful attacks: | 2      |
+| Number of failed attacks:     | 0      |
 | Number of skipped attacks:    | 1      |
 | Original accuracy:            | 66.67% |
-| Accuracy under attack:        | 33.33% |
-| Attack success rate:          | 50.0%  |
-| Average perturbed word %:     | 2.27%  |
-| Average num. words per input: | 29.0   |
-| Avg num queries:              | 447.5  |
+| Accuracy under attack:        | 0.0%   |
+| Attack success rate:          | 100.0% |
+| Average perturbed word %:     | 2.78%  |
+| Average num. words per input: | 28.67  |
+| Avg num queries:              | 182.0  |
 +-------------------------------+--------+


@@ -12,12 +12,27 @@ def attacked_text():
     return textattack.shared.AttackedText(raw_text)


+raw_pokemon_text = "the threat implied in the title pokémon 4ever is terrifying – like locusts in a horde these things will keep coming ."
+
+
+@pytest.fixture
+def pokemon_attacked_text():
+    return textattack.shared.AttackedText(raw_pokemon_text)
+
+
 premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes."
 hypothesis = "The Patan Museum is down the street from the red brick Royal Palace."
 raw_text_pair = collections.OrderedDict(
     [("premise", premise), ("hypothesis", hypothesis)]
 )

+raw_hyphenated_text = "It's a run-of-the-mill kind of farmer's tan."
+
+
+@pytest.fixture
+def hyphenated_text():
+    return textattack.shared.AttackedText(raw_hyphenated_text)
+

 @pytest.fixture
 def attacked_text_pair():

@@ -25,27 +40,13 @@ def attacked_text_pair():

 class TestAttackedText:
-    def test_words(self, attacked_text):
+    def test_words(self, attacked_text, pokemon_attacked_text):
+        # fmt: off
         assert attacked_text.words == [
-            "A",
-            "person",
-            "walks",
-            "up",
-            "stairs",
-            "into",
-            "a",
-            "room",
-            "and",
-            "sees",
-            "beer",
-            "poured",
-            "from",
-            "a",
-            "keg",
-            "and",
-            "people",
-            "talking",
+            "A", "person", "walks", "up", "stairs", "into", "a", "room", "and", "sees", "beer", "poured", "from", "a", "keg", "and", "people", "talking",
         ]
+        assert pokemon_attacked_text.words == ['the', 'threat', 'implied', 'in', 'the', 'title', 'pokémon', '4ever', 'is', 'terrifying', 'like', 'locusts', 'in', 'a', 'horde', 'these', 'things', 'will', 'keep', 'coming']
+        # fmt: on

     def test_window_around_index(self, attacked_text):
         assert attacked_text.text_window_around_index(5, 1) == "into"

@@ -69,8 +70,9 @@ class TestAttackedText:
     def test_window_around_index_end(self, attacked_text):
         assert attacked_text.text_window_around_index(17, 3) == "and people talking"

-    def test_text(self, attacked_text, attacked_text_pair):
+    def test_text(self, attacked_text, pokemon_attacked_text, attacked_text_pair):
         assert attacked_text.text == raw_text
+        assert pokemon_attacked_text.text == raw_pokemon_text
         assert attacked_text_pair.text == "\n".join(raw_text_pair.values())

     def test_printable_text(self, attacked_text, attacked_text_pair):

@@ -140,13 +142,13 @@ class TestAttackedText:
             + "\n"
             + "The Patan Museum is down the street from the red brick Royal Palace."
         )
-        new_text = new_text.insert_text_after_word_index(38, "and shapes")
+        new_text = new_text.insert_text_after_word_index(37, "and shapes")
         assert new_text.text == (
             "Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
             + "\n"
             + "The Patan Museum is down the street from the red brick Royal Palace."
         )
-        new_text = new_text.insert_text_after_word_index(41, "The")
+        new_text = new_text.insert_text_after_word_index(40, "The")
         assert new_text.text == (
             "Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes."
             + "\n"

@@ -163,7 +165,7 @@ class TestAttackedText:
         )
         for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]):
             assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or (
-                new_i == -1
+                new_idx == -1
             )
         new_text = (
             new_text.delete_word_at_index(0)

@@ -180,3 +182,14 @@ class TestAttackedText:
             new_text.text
             == "person walks a very long way up stairs into a room and sees beer poured and people on the couch."
         )
+
+    def test_hyphen_apostrophe_words(self, hyphenated_text):
+        assert hyphenated_text.words == [
+            "It's",
+            "a",
+            "run-of-the-mill",
+            "kind",
+            "of",
+            "farmer's",
+            "tan",
+        ]
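The new fixture doubles as a usage example for `AttackedText`'s word segmentation, which keeps hyphenated words and apostrophe contractions intact. Roughly, assuming a working install:

```python
from textattack.shared import AttackedText

# Hyphens and apostrophes stay inside a single "word".
text = AttackedText("It's a run-of-the-mill kind of farmer's tan.")
assert text.words == ["It's", "a", "run-of-the-mill", "kind", "of", "farmer's", "tan"]
```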


@@ -112,6 +112,24 @@ attack_test_params = [
     ),
     # fmt: on
     #
+    # test: run_attack on LSTM MR using word embedding transformation and genetic algorithm. Simulate alzantot recipe without using expensive LM
+    (
+        "run_attack_faster_alzantot_recipe",
+        (
+            "textattack attack --model lstm-mr --recipe faster-alzantot --num-examples 3 --num-examples-offset 20"
+        ),
+        "tests/sample_outputs/run_attack_faster_alzantot_recipe.txt",
+    ),
+    #
+    # test: run_attack with kuleshov recipe and sst-2 cnn
+    #
+    (
+        "run_attack_kuleshov_nn",
+        (
+            "textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst --attack-n --query-budget 200"
+        ),
+        "tests/sample_outputs/kuleshov_cnn_sst_2.txt",
+    ),
 ]


@@ -37,3 +37,5 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file):
     # Ensure CSV file exists, then delete it.
     assert os.path.exists(outfile)
     os.remove(outfile)
+
+    assert result.returncode == 0


@@ -27,3 +27,5 @@ def test_command_line_list(name, command, sample_output_file):
     print("stderr =>", stderr)

     assert stdout == desired_text
+
+    assert result.returncode == 0


@@ -0,0 +1,56 @@
from textattack.constraints.pre_transformation import (
    InputColumnModification,
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import PSOAlgorithm
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding, WordSwapHowNet


def PSOZang2020(model):
    """
    Zang, Y., Yang, C., Qi, F., Liu, Z., Zhang, M., Liu, Q., & Sun, M. (2019).

    Word-level Textual Adversarial Attacking as Combinatorial Optimization.

    https://www.aclweb.org/anthology/2020.acl-main.540.pdf

    Methodology description quoted from the paper:

    "We propose a novel word substitution-based textual attack model, which reforms
    both the aforementioned two steps. In the first step, we adopt a sememe-based word
    substitution strategy, which can generate more candidate adversarial examples with
    better semantic preservation. In the second step, we utilize particle swarm optimization
    (Eberhart and Kennedy, 1995) as the adversarial example searching algorithm."

    And "Following the settings in Alzantot et al. (2018), we set the max iteration time G to 20."
    """
    #
    # Swap words with their synonyms extracted based on the HowNet.
    #
    transformation = WordSwapHowNet()
    #
    # Don't modify the same word twice or stopwords
    #
    constraints = [RepeatModification(), StopwordModification()]
    #
    # During entailment, we should only edit the hypothesis - keep the premise
    # the same.
    #
    input_column_modification = InputColumnModification(
        ["premise", "hypothesis"], {"premise"}
    )
    constraints.append(input_column_modification)
    #
    # Use untargeted classification for demo, can be switched to targeted one
    #
    goal_function = UntargetedClassification(model)
    #
    # Perform word substitution with a Particle Swarm Optimization (PSO) algorithm.
    #
    search_method = PSOAlgorithm(pop_size=60, max_iters=20)

    return Attack(goal_function, constraints, transformation, search_method)


@@ -4,8 +4,10 @@ from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
 from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
 from .deepwordbug_gao_2018 import DeepWordBugGao2018
 from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
+from .input_reduction_feng_2018 import InputReductionFeng2018
 from .kuleshov_2017 import Kuleshov2017
 from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
 from .textbugger_li_2018 import TextBuggerLi2018
 from .textfooler_jin_2019 import TextFoolerJin2019
 from .pwws_ren_2019 import PWWSRen2019
+from .PSO_zang_2020 import PSOZang2020


@@ -20,14 +20,6 @@ def BERTAttackLi2020(model):

     This is "attack mode" 1 from the paper, BAE-R, word replacement.
     """
-    from textattack.shared.utils import logger
-
-    logger.warn(
-        "WARNING: This BERT-Attack implementation is based off of a"
-        " preliminary draft of the paper, which lacked source code and"
-        " did not include any hyperparameters. Attack reuslts are likely to"
-        " change."
-    )
     # [from correspondence with the author]
     # Candidate size K is set to 48 for all data-sets.
     transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)


@@ -119,6 +119,6 @@ def FasterGeneticAlgorithmJia2019(model):
     #
     # Perform word substitution with a genetic algorithm.
     #
-    search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
+    search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)

     return Attack(goal_function, constraints, transformation, search_method)


@@ -3,6 +3,7 @@ from textattack.constraints.grammaticality.language_models import (
 )
 from textattack.constraints.overlap import MaxWordsPerturbed
 from textattack.constraints.pre_transformation import (
+    InputColumnModification,
     RepeatModification,
     StopwordModification,
 )
@@ -34,6 +35,14 @@ def GeneticAlgorithmAlzantot2018(model):
     #
     constraints = [RepeatModification(), StopwordModification()]
     #
+    # During entailment, we should only edit the hypothesis - keep the premise
+    # the same.
+    #
+    input_column_modification = InputColumnModification(
+        ["premise", "hypothesis"], {"premise"}
+    )
+    constraints.append(input_column_modification)
+    #
     # Maximum words perturbed percentage of 20%
     #
     constraints.append(MaxWordsPerturbed(max_percent=0.2))
@@ -52,6 +61,6 @@ def GeneticAlgorithmAlzantot2018(model):
     #
     # Perform word substitution with a genetic algorithm.
     #
-    search_method = GeneticAlgorithm(pop_size=60, max_iters=20)
+    search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0)

     return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -0,0 +1,40 @@
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import InputReduction
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordDeletion
def InputReductionFeng2018(model):
"""
Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018).
Pathologies of Neural Models Make Interpretations Difficult.
ArXiv, abs/1804.07781.
"""
# At each step, we remove the word with the lowest importance value until
# the model changes its prediction.
transformation = WordDeletion()
constraints = [RepeatModification(), StopwordModification()]
#
# Goal is to maintain the same prediction while using as few words as possible
#
goal_function = InputReduction(model, maximizable=True)
#
# "For each word in an input sentence, we measure its importance by the
# change in the confidence of the original prediction when we remove
# that word from the sentence."
#
# "Instead of looking at the words with high importance values—what
# interpretation methods commonly do—we take a complementary approach
# and study how the model behaves when the supposedly unimportant words are
# removed."
#
search_method = GreedyWordSwapWIR(wir_method="delete")
return Attack(goal_function, constraints, transformation, search_method)
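As a hedged sketch, the recipe composes like any other; because the goal is maximizable, successful runs surface as MaximizedAttackResult objects per the result changes later in this commit (`model` is hypothetical):

from textattack.attack_recipes import InputReductionFeng2018

attack = InputReductionFeng2018(model)  # `model`: a classification model
# The search keeps deleting the least important word until the prediction
# would change, then reports the shortest surviving input.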

View File

@@ -26,5 +26,5 @@ def PWWSRen2019(model):
constraints = [RepeatModification(), StopwordModification()] constraints = [RepeatModification(), StopwordModification()]
goal_function = UntargetedClassification(model) goal_function = UntargetedClassification(model)
# search over words based on a combination of their saliency score, and how efficient the WordSwap transform is # search over words based on a combination of their saliency score, and how efficient the WordSwap transform is
search_method = GreedyWordSwapWIR("pwws", ascending=False) search_method = GreedyWordSwapWIR("pwws")
return Attack(goal_function, constraints, transformation, search_method) return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -27,7 +27,7 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
# Goal is non-overlapping output. # Goal is non-overlapping output.
# #
goal_function = NonOverlappingOutput(model) goal_function = NonOverlappingOutput(model)
# @TODO implement transformation / search method just like they do in # TODO implement transformation / search method just like they do in
# seq2sick. # seq2sick.
transformation = WordSwapEmbedding(max_candidates=50) transformation = WordSwapEmbedding(max_candidates=50)
# #
@@ -42,6 +42,6 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"):
# #
# Greedily swap words with "Word Importance Ranking". # Greedily swap words with "Word Importance Ranking".
# #
search_method = GreedyWordSwapWIR() search_method = GreedyWordSwapWIR(wir_method="unk")
return Attack(goal_function, constraints, transformation, search_method) return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -1,5 +1,6 @@
from textattack.constraints.grammaticality import PartOfSpeech from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import ( from textattack.constraints.pre_transformation import (
InputColumnModification,
RepeatModification, RepeatModification,
StopwordModification, StopwordModification,
) )
@@ -35,6 +36,13 @@ def TextFoolerJin2019(model):
# fmt: on # fmt: on
constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)] constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]
# #
# During entailment, we should only edit the hypothesis - keep the premise
# the same.
#
input_column_modification = InputColumnModification(
["premise", "hypothesis"], {"premise"}
)
constraints.append(input_column_modification)
# Minimum word embedding cosine similarity of 0.5. # Minimum word embedding cosine similarity of 0.5.
# (The paper claims 0.7, but analysis of the released code and some empirical # (The paper claims 0.7, but analysis of the released code and some empirical
# results show that it's 0.5.) # results show that it's 0.5.)

View File

@@ -1,3 +1,4 @@
from .maximized_attack_result import MaximizedAttackResult
from .failed_attack_result import FailedAttackResult from .failed_attack_result import FailedAttackResult
from .skipped_attack_result import SkippedAttackResult from .skipped_attack_result import SkippedAttackResult
from .successful_attack_result import SuccessfulAttackResult from .successful_attack_result import SuccessfulAttackResult

View File

@@ -13,11 +13,11 @@ class AttackResult:
perturbed text. May or may not have been successful. perturbed text. May or may not have been successful.
""" """
def __init__(self, original_result, perturbed_result, num_queries=0): def __init__(self, original_result, perturbed_result):
if original_result is None: if original_result is None:
raise ValueError("Attack original result cannot be None") raise ValueError("Attack original result cannot be None")
elif not isinstance(original_result, GoalFunctionResult): elif not isinstance(original_result, GoalFunctionResult):
raise TypeError(f"Invalid original goal function result: {original_text}") raise TypeError(f"Invalid original goal function result: {original_result}")
if perturbed_result is None: if perturbed_result is None:
raise ValueError("Attack perturbed result cannot be None") raise ValueError("Attack perturbed result cannot be None")
elif not isinstance(perturbed_result, GoalFunctionResult): elif not isinstance(perturbed_result, GoalFunctionResult):
@@ -27,7 +27,7 @@ class AttackResult:
self.original_result = original_result self.original_result = original_result
self.perturbed_result = perturbed_result self.perturbed_result = perturbed_result
self.num_queries = num_queries self.num_queries = perturbed_result.num_queries
# We don't want the AttackedText attributes sticking around clogging up # We don't want the AttackedText attributes sticking around clogging up
# space on our devices. Delete them here, if they're still present, # space on our devices. Delete them here, if they're still present,
@@ -89,27 +89,34 @@ class AttackResult:
i1 = 0 i1 = 0
i2 = 0 i2 = 0
while i1 < len(t1.words) and i2 < len(t2.words): while i1 < t1.num_words or i2 < t2.num_words:
# show deletions # show deletions
while t2.attack_attrs["original_index_map"][i1] == -1: while (
i1 < len(t2.attack_attrs["original_index_map"])
and t2.attack_attrs["original_index_map"][i1] == -1
):
words_1.append(utils.color_text(t1.words[i1], color_1, color_method)) words_1.append(utils.color_text(t1.words[i1], color_1, color_method))
words_1_idxs.append(i1) words_1_idxs.append(i1)
i1 += 1 i1 += 1
# show insertions # show insertions
while i2 < t2.attack_attrs["original_index_map"][i1]: while (
i1 < len(t2.attack_attrs["original_index_map"])
and i2 < t2.attack_attrs["original_index_map"][i1]
):
words_2.append(utils.color_text(t1.words[i2], color_2, color_method)) words_2.append(utils.color_text(t1.words[i2], color_2, color_method))
words_2_idxs.append(i2) words_2_idxs.append(i2)
i2 += 1 i2 += 1
# show swaps # show swaps
word_1 = t1.words[i1] if i1 < t1.num_words and i2 < t2.num_words:
word_2 = t2.words[i2] word_1 = t1.words[i1]
if word_1 != word_2: word_2 = t2.words[i2]
words_1.append(utils.color_text(word_1, color_1, color_method)) if word_1 != word_2:
words_2.append(utils.color_text(word_2, color_2, color_method)) words_1.append(utils.color_text(word_1, color_1, color_method))
words_1_idxs.append(i1) words_2.append(utils.color_text(word_2, color_2, color_method))
words_2_idxs.append(i2) words_1_idxs.append(i1)
i1 += 1 words_2_idxs.append(i2)
i2 += 1 i1 += 1
i2 += 1
t1 = self.original_result.attacked_text.replace_words_at_indices( t1 = self.original_result.attacked_text.replace_words_at_indices(
words_1_idxs, words_1 words_1_idxs, words_1

View File

@@ -6,9 +6,9 @@ from .attack_result import AttackResult
class FailedAttackResult(AttackResult): class FailedAttackResult(AttackResult):
"""The result of a failed attack.""" """The result of a failed attack."""
def __init__(self, original_result, perturbed_result=None, num_queries=0): def __init__(self, original_result, perturbed_result=None):
perturbed_result = perturbed_result or original_result perturbed_result = perturbed_result or original_result
super().__init__(original_result, perturbed_result, num_queries) super().__init__(original_result, perturbed_result)
def str_lines(self, color_method=None): def str_lines(self, color_method=None):
lines = ( lines = (

View File

@@ -0,0 +1,5 @@
from .attack_result import AttackResult
class MaximizedAttackResult(AttackResult):
""" The result of a successful attack. """

View File

@@ -1,2 +1,5 @@
from .attack_command import AttackCommand from .attack_command import AttackCommand
from .attack_resume_command import AttackResumeCommand from .attack_resume_command import AttackResumeCommand
from .run_attack_single_threaded import run as run_attack_single_threaded
from .run_attack_parallel import run as run_attack_parallel

View File

@@ -7,11 +7,13 @@ ATTACK_RECIPE_NAMES = {
"faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019", "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
"deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018", "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
"hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017", "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
"input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
"kuleshov": "textattack.attack_recipes.Kuleshov2017", "kuleshov": "textattack.attack_recipes.Kuleshov2017",
"seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox", "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
"textbugger": "textattack.attack_recipes.TextBuggerLi2018", "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
"textfooler": "textattack.attack_recipes.TextFoolerJin2019", "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
"pwws": "textattack.attack_recipes.PWWSRen2019", "pwws": "textattack.attack_recipes.PWWSRen2019",
"pso": "textattack.attack_recipes.PSOZang2020",
} }
# #
@@ -218,11 +220,22 @@ TEXTATTACK_DATASET_BY_MODEL = {
), ),
# #
# Translation models # Translation models
# TODO add proper `nlp` datasets for translation & summarization "t5-en-de": (
"english_to_german",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
),
"t5-en-fr": (
"english_to_french",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "fr"),
),
"t5-en-ro": (
"english_to_romanian",
("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"),
),
# #
# Summarization models # Summarization models
# #
#'t5-summ': 'textattack.models.summarization.T5Summarization', "t5-summarization": ("summarization", ("gigaword", None, "test")),
} }
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = { BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
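For reference, each registry value is a dotted path to an importable class; a generic sketch of resolving one (importlib is used here for illustration, whereas the CLI code below relies on eval):

import importlib

path = "textattack.attack_recipes.PSOZang2020"
module_name, class_name = path.rsplit(".", 1)
recipe_cls = getattr(importlib.import_module(module_name), class_name)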

View File

@@ -332,7 +332,14 @@ def parse_dataset_from_args(args):
if args.model in HUGGINGFACE_DATASET_BY_MODEL: if args.model in HUGGINGFACE_DATASET_BY_MODEL:
_, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model] _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
elif args.model in TEXTATTACK_DATASET_BY_MODEL: elif args.model in TEXTATTACK_DATASET_BY_MODEL:
_, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model] _, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model]
if dataset[0].startswith("textattack"):
# unsavory way to pass custom dataset classes
# ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de')
dataset = eval(f"{dataset[0]}")(*dataset[1:])
return dataset
else:
args.dataset_from_nlp = dataset
# Automatically detect dataset for models trained with textattack. # Automatically detect dataset for models trained with textattack.
elif args.model and os.path.exists(args.model): elif args.model and os.path.exists(args.model):
model_args_json_path = os.path.join(args.model, "train_args.json") model_args_json_path = os.path.join(args.model, "train_args.json")

View File

@@ -125,7 +125,10 @@ def run(args, checkpoint=None):
pbar.update() pbar.update()
num_results += 1 num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult: if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1 num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult: if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1 num_failures += 1
@@ -170,6 +173,8 @@ def run(args, checkpoint=None):
finish_time = time.time() finish_time = time.time()
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s") textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
return attack_log_manager.results
def pytorch_multiprocessing_workaround(): def pytorch_multiprocessing_workaround():
# This is a fix for a known bug # This is a fix for a known bug

View File

@@ -108,7 +108,10 @@ def run(args, checkpoint=None):
num_results += 1 num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult: if (
type(result) == textattack.attack_results.SuccessfulAttackResult
or type(result) == textattack.attack_results.MaximizedAttackResult
):
num_successes += 1 num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult: if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1 num_failures += 1
@@ -139,6 +142,8 @@ def run(args, checkpoint=None):
finish_time = time.time() finish_time = time.time()
textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s") textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s")
return attack_log_manager.results
if __name__ == "__main__": if __name__ == "__main__":
run(get_args()) run(get_args())

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import textattack
from textattack.shared.utils import default_class_repr from textattack.shared.utils import default_class_repr
@@ -71,10 +72,10 @@ class Constraint(ABC):
transformed_text (AttackedText): The candidate transformed ``AttackedText``. transformed_text (AttackedText): The candidate transformed ``AttackedText``.
reference_text (AttackedText): The ``AttackedText`` to compare against. reference_text (AttackedText): The ``AttackedText`` to compare against.
""" """
if not isinstance(transformed_text, AttackedText): if not isinstance(transformed_text, textattack.shared.AttackedText):
raise TypeError("transformed_text must be of type AttackedText") raise TypeError("transformed_text must be of type AttackedText")
if not isinstance(reference_text, AttackedText): if not isinstance(current_text, textattack.shared.AttackedText):
raise TypeError("reference_text must be of type AttackedText") raise TypeError("current_text must be of type AttackedText")
try: try:
if not self.check_compatibility( if not self.check_compatibility(

View File

@@ -49,7 +49,7 @@ class GoogleLanguageModel(Constraint):
[t.words[word_swap_index] for t in transformed_texts] [t.words[word_swap_index] for t in transformed_texts]
) )
if self.print_step: if self.print_step:
print(prefix, swapped_words, suffix) print(prefix, swapped_words)
probs = self.lm.get_words_probs(prefix, swapped_words) probs = self.lm.get_words_probs(prefix, swapped_words)
return probs return probs

View File

@@ -52,7 +52,7 @@ class GPT2(LanguageModelConstraint):
probs = [] probs = []
for attacked_text in text_list: for attacked_text in text_list:
nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index]) next_word_ids = self.tokenizer.encode(attacked_text.words[word_index])
next_word_prob = predictions[0, -1, next_word_ids[0]] next_word_prob = predictions[0, -1, next_word_ids[0]]
probs.append(next_word_prob) probs.append(next_word_prob)

View File

@@ -3,7 +3,11 @@ from abc import abstractmethod
from textattack.constraints import Constraint from textattack.constraints import Constraint
class LanguageModelConstraint(Constraint): class LanguageModelConstraint(Constraint, ABC):
""" """
Determines if two sentences have a swapped word that has a similar Determines if two sentences have a swapped word that has a similar
probability according to a language model. probability according to a language model.

View File

@@ -11,8 +11,8 @@ from .language_model_helpers import QueryHandler
class LearningToWriteLanguageModel(LanguageModelConstraint): class LearningToWriteLanguageModel(LanguageModelConstraint):
""" A constraint based on the L2W language model. """ A constraint based on the L2W language model.
The RNN-based language model from ``Learning to Write With Cooperative The RNN-based language model from "Learning to Write With Cooperative
Discriminators'' (Holtzman et al, 2018). Discriminators" (Holtzman et al, 2018).
https://arxiv.org/pdf/1805.06087.pdf https://arxiv.org/pdf/1805.06087.pdf

View File

@@ -1,3 +1,11 @@
from .stopword_modification import StopwordModification from .pre_transformation_constraint import PreTransformationConstraint
from .repeat_modification import RepeatModification from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification from .max_word_index_modification import MaxWordIndexModification
from .repeat_modification import RepeatModification
from .stopword_modification import StopwordModification

View File

@@ -0,0 +1,43 @@
from textattack.constraints.pre_transformation import PreTransformationConstraint
class InputColumnModification(PreTransformationConstraint):
"""
A constraint disallowing the modification of words within a specific input
column.
For example, this can prevent modification of the 'premise' column during
entailment.
"""
def __init__(self, matching_column_labels, columns_to_ignore):
self.matching_column_labels = matching_column_labels
self.columns_to_ignore = columns_to_ignore
def _get_modifiable_indices(self, current_text):
""" Returns the word indices in current_text which are able to be
deleted.
If ``current_text.column_labels`` doesn't match
``self.matching_column_labels``, do nothing, and allow all words
to be modified.
If it does match, only allow words to be modified if they are not
in columns from ``columns_to_ignore``.
"""
if current_text.column_labels != self.matching_column_labels:
return set(range(len(current_text.words)))
idx = 0
indices_to_modify = set()
for column, words in zip(
current_text.column_labels, current_text.words_per_input
):
num_words = len(words)
if column not in self.columns_to_ignore:
indices_to_modify |= set(range(idx, idx + num_words))
idx += num_words
return indices_to_modify
def extra_repr_keys(self):
return ["matching_column_labels", "columns_to_ignore"]

View File

@@ -13,3 +13,6 @@ class MaxWordIndexModification(PreTransformationConstraint):
def _get_modifiable_indices(self, current_text): def _get_modifiable_indices(self, current_text):
""" Returns the word indices in current_text which are able to be deleted """ """ Returns the word indices in current_text which are able to be deleted """
return set(range(min(self.max_length, len(current_text.words)))) return set(range(min(self.max_length, len(current_text.words))))
def extra_repr_keys(self):
return ["max_length"]

View File

@@ -57,4 +57,9 @@ class PreTransformationConstraint(ABC):
""" """
return [] return []
__str__ = __repr__ = default_class_repr def _check_constraint(self):
raise RuntimeError(
"PreTransformationConstraints do not support `_check_constraint()`."
)
__str__ = __repr__ = default_class_repr

View File

@@ -73,7 +73,9 @@ class SentenceEncoder(Constraint):
The similarity between the starting and transformed text using the metric. The similarity between the starting and transformed text using the metric.
""" """
try: try:
modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"])) modified_index = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"Cannot apply sentence encoder constraint without `newly_modified_indices`" "Cannot apply sentence encoder constraint without `newly_modified_indices`"
@@ -112,7 +114,7 @@ class SentenceEncoder(Constraint):
``transformed_texts``. If ``transformed_texts`` is empty, ``transformed_texts``. If ``transformed_texts`` is empty,
an empty tensor is returned an empty tensor is returned
""" """
# Return an empty tensor if x_adv_list is empty. # Return an empty tensor if transformed_texts is empty.
# This prevents us from calling .repeat(x, 0), which throws an # This prevents us from calling .repeat(x, 0), which throws an
# error on machines with multiple GPUs (pytorch 1.2). # error on machines with multiple GPUs (pytorch 1.2).
if len(transformed_texts) == 0: if len(transformed_texts) == 0:
@@ -142,9 +144,9 @@ class SentenceEncoder(Constraint):
) )
) )
embeddings = self.encode(starting_text_windows + transformed_text_windows) embeddings = self.encode(starting_text_windows + transformed_text_windows)
starting_embeddings = torch.tensor(embeddings[: len(transformed_texts)]).to( if not isinstance(embeddings, torch.Tensor):
utils.device embeddings = torch.tensor(embeddings)
) starting_embeddings = embeddings[: len(transformed_texts)].to(utils.device)
transformed_embeddings = torch.tensor( transformed_embeddings = torch.tensor(
embeddings[len(transformed_texts) :] embeddings[len(transformed_texts) :]
).to(utils.device) ).to(utils.device)
@@ -152,18 +154,12 @@ class SentenceEncoder(Constraint):
starting_raw_text = starting_text.text starting_raw_text = starting_text.text
transformed_raw_texts = [t.text for t in transformed_texts] transformed_raw_texts = [t.text for t in transformed_texts]
embeddings = self.encode([starting_raw_text] + transformed_raw_texts) embeddings = self.encode([starting_raw_text] + transformed_raw_texts)
if isinstance(embeddings[0], torch.Tensor): if not isinstance(embeddings, torch.Tensor):
starting_embedding = embeddings[0].to(utils.device) embeddings = torch.tensor(embeddings)
else:
# If the embedding is not yet a tensor, make it one.
starting_embedding = torch.tensor(embeddings[0]).to(utils.device)
if isinstance(embeddings, list): starting_embedding = embeddings[0].to(utils.device)
# If `encode` did not return a Tensor of all embeddings, combine
# into a tensor. transformed_embeddings = embeddings[1:].to(utils.device)
transformed_embeddings = torch.stack(embeddings[1:]).to(utils.device)
else:
transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.device)
# Repeat original embedding to size of perturbed embedding. # Repeat original embedding to size of perturbed embedding.
starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat( starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat(

View File

@@ -36,7 +36,7 @@ class ThoughtVector(SentenceEncoder):
return torch.mean(embeddings, dim=0) return torch.mean(embeddings, dim=0)
def encode(self, raw_text_list): def encode(self, raw_text_list):
return [self._get_thought_vector(text) for text in raw_text_list] return torch.stack([self._get_thought_vector(text) for text in raw_text_list])
def extra_repr_keys(self): def extra_repr_keys(self):
"""Set the extra representation of the constraint using these keys. """Set the extra representation of the constraint using these keys.

View File

@@ -51,7 +51,7 @@ class WordEmbeddingDistance(Constraint):
mse_dist_file = "mse_dist.p" mse_dist_file = "mse_dist.p"
cos_sim_file = "cos_sim.p" cos_sim_file = "cos_sim.p"
else: else:
raise ValueError(f"Could not find word embedding {word_embedding}") raise ValueError(f"Could not find word embedding {embedding_type}")
# Download embeddings if they're not cached. # Download embeddings if they're not cached.
word_embeddings_path = utils.download_if_needed(WordEmbeddingDistance.PATH) word_embeddings_path = utils.download_if_needed(WordEmbeddingDistance.PATH)

View File

@@ -1,6 +1,4 @@
from .dataset import TextAttackDataset from .dataset import TextAttackDataset
from .huggingface_nlp_dataset import HuggingFaceNLPDataset from .huggingface_nlp_dataset import HuggingFaceNLPDataset
from . import classification
from . import entailment
from . import translation from . import translation

View File

@@ -1,2 +0,0 @@
from .ag_news import AGNews
from .kaggle_fake_news import KaggleFakeNews

View File

@@ -1,44 +0,0 @@
from textattack.shared import utils
from .classification_dataset import ClassificationDataset
class AGNews(ClassificationDataset):
"""
Loads samples from the AG News Dataset.
AG is a collection of more than 1 million news articles. News articles have
been gathered from more than 2000 news sources by ComeToMyHead in more than
1 year of activity. ComeToMyHead is an academic news search engine which has
been running since July, 2004. The dataset is provided by the academic
community for research purposes in data mining (clustering, classification,
etc), information retrieval (ranking, search, etc), xml, data compression,
data streaming, and any other non-commercial activity. For more information,
please refer to the link
http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html.
The AG's news topic classification dataset was constructed by Xiang Zhang
(xiang.zhang@nyu.edu) from the dataset above. It is used as a text
classification benchmark in the following paper: Xiang Zhang, Junbo Zhao,
Yann LeCun. Character-level Convolutional Networks for Text Classification.
Advances in Neural Information Processing Systems 28 (NIPS 2015).
Labels
0: World
1: Sports
2: Business
3: Sci/Tech
Args:
offset (int): line to start reading from
shuffle (bool): If True, randomly shuffle loaded data
"""
DATA_PATH = "datasets/classification/ag_news.txt"
def __init__(self, offset=0, shuffle=False):
""" Loads a full dataset from disk. """
self._load_classification_text_file(
AGNews.DATA_PATH, offset=offset, shuffle=shuffle
)

View File

@@ -1,13 +0,0 @@
from textattack.datasets import TextAttackDataset
class ClassificationDataset(TextAttackDataset):
"""
A generic class for loading classification data.
"""
def _process_example_from_file(self, raw_line):
tokens = raw_line.strip().split()
label = int(tokens[0])
text = " ".join(tokens[1:])
return (text, label)

View File

@@ -1,24 +0,0 @@
from .classification_dataset import ClassificationDataset
class KaggleFakeNews(ClassificationDataset):
"""
Loads samples from the Kaggle Fake News dataset. https://www.kaggle.com/mrisdal/fake-news
Labels
0: Real Article
1: Fake Article
Args:
offset (int): line to start reading from
shuffle (bool): If True, randomly shuffle loaded data
"""
DATA_PATH = "datasets/classification/fake"
def __init__(self, offset=0, shuffle=False):
""" Loads a full dataset from disk. """
self._load_classification_text_file(
KaggleFakeNews.DATA_PATH, offset=offset, shuffle=shuffle
)

View File

@@ -1 +0,0 @@
from .snli import SNLI

View File

@@ -1,38 +0,0 @@
import collections
from textattack.datasets import TextAttackDataset
from textattack.shared import AttackedText
class EntailmentDataset(TextAttackDataset):
"""
A generic class for loading entailment data.
Labels
0: Entailment
1: Neutral
2: Contradiction
"""
def _label_str_to_int(self, label_str):
if label_str == "entailment":
return 0
elif label_str == "neutral":
return 1
elif label_str == "contradiction":
return 2
else:
raise ValueError(f"Unknown entailment label {label_str}")
def _process_example_from_file(self, raw_line):
line = raw_line.strip()
label, premise, hypothesis = line.split("\t")
try:
label = int(label)
except ValueError:
# If the label is not an integer, it's a label description.
label = self._label_str_to_int(label)
text_input = collections.OrderedDict(
[("premise", premise), ("hypothesis", hypothesis),]
)
return (text_input, label)

View File

@@ -1,25 +0,0 @@
from .entailment_dataset import EntailmentDataset
class SNLI(EntailmentDataset):
"""
Loads samples from the SNLI dataset.
Labels
0: Entailment
1: Neutral
2: Contradiction
Args:
offset (int): line to start reading from
shuffle (bool): If True, randomly shuffle loaded data
"""
DATA_PATH = "datasets/entailment/snli"
def __init__(self, offset=0, shuffle=False):
""" Loads a full dataset from disk. """
self._load_classification_text_file(
SNLI.DATA_PATH, offset=offset, shuffle=shuffle
)

View File

@@ -35,6 +35,12 @@ def get_nlp_dataset_columns(dataset):
elif {"sentence", "label"} <= schema: elif {"sentence", "label"} <= schema:
input_columns = ("sentence",) input_columns = ("sentence",)
output_column = "label" output_column = "label"
elif {"document", "summary"} <= schema:
input_columns = ("document",)
output_column = "summary"
elif {"content", "summary"} <= schema:
input_columns = ("content",)
output_column = "summary"
else: else:
raise ValueError( raise ValueError(
f"Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead." f"Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead."
@@ -47,18 +53,17 @@ class HuggingFaceNLPDataset(TextAttackDataset):
""" Loads a dataset from HuggingFace ``nlp`` and prepares it as a """ Loads a dataset from HuggingFace ``nlp`` and prepares it as a
TextAttack dataset. TextAttack dataset.
name: the dataset name - name: the dataset name
subset: the subset of the main dataset. Dataset will be loaded as - subset: the subset of the main dataset. Dataset will be loaded as ``nlp.load_dataset(name, subset)``.
``nlp.load_dataset(name, subset)``. - label_map: Mapping if output labels should be re-mapped. Useful
label_map: Mapping if output labels should be re-mapped. Useful if model was trained with a different label arrangement than
if model was trained with a different label arrangement than provided in the ``nlp`` version of the dataset.
provided in the ``nlp`` version of the dataset. - output_scale_factor (float): Factor to divide ground-truth outputs by.
output_scale_factor (float): Factor to divide ground-truth outputs by.
Generally, TextAttack goal functions require model outputs Generally, TextAttack goal functions require model outputs
between 0 and 1. Some datasets test the model's *correlation* between 0 and 1. Some datasets test the model's \*correlation\*
with ground-truth output, instead of its accuracy, so these with ground-truth output, instead of its accuracy, so these
outputs may be scaled arbitrarily. outputs may be scaled arbitrarily.
shuffle (bool): Whether to shuffle the dataset on load. - shuffle (bool): Whether to shuffle the dataset on load.
""" """
@@ -72,6 +77,7 @@ class HuggingFaceNLPDataset(TextAttackDataset):
dataset_columns=None, dataset_columns=None,
shuffle=False, shuffle=False,
): ):
self._name = name
self._dataset = nlp.load_dataset(name, subset)[split] self._dataset = nlp.load_dataset(name, subset)[split]
subset_print_str = f", subset {_cb(subset)}" if subset else "" subset_print_str = f", subset {_cb(subset)}" if subset else ""
textattack.shared.logger.info( textattack.shared.logger.info(
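A sketch of loading a dataset through this wrapper; the glue/sst2 config names come from the `nlp` package and are an assumption here:

from textattack.datasets import HuggingFaceNLPDataset

dataset = HuggingFaceNLPDataset("glue", subset="sst2", split="validation")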

View File

@@ -1 +1 @@
from .translation_datasets import * from .ted_multi import TedMultiTranslationDataset

View File

@@ -0,0 +1,38 @@
import collections
import nlp
import numpy as np
from textattack.datasets import HuggingFaceNLPDataset
class TedMultiTranslationDataset(HuggingFaceNLPDataset):
""" Loads examples from the Ted Talk translation dataset using the `nlp`
package.
dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
"""
def __init__(self, source_lang="en", target_lang="de", split="test"):
self._dataset = nlp.load_dataset("ted_multi")[split]
self.examples = self._dataset["translations"]
language_options = set(self.examples[0]["language"])
if source_lang not in language_options:
raise ValueError(
f"Source language {source_lang} invalid. Choices: {sorted(language_options)}"
)
if target_lang not in language_options:
raise ValueError(
f"Target language {target_lang} invalid. Choices: {sorted(language_options)}"
)
self.source_lang = source_lang
self.target_lang = target_lang
self.label_names = ("Translation",)
def _format_raw_example(self, raw_example):
translations = np.array(raw_example["translation"])
languages = np.array(raw_example["language"])
source = translations[languages == self.source_lang][0]
target = translations[languages == self.target_lang][0]
source_dict = collections.OrderedDict([("Source", source)])
return (source_dict, target)
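Usage sketch based on the constructor and _format_raw_example above:

from textattack.datasets.translation import TedMultiTranslationDataset

dataset = TedMultiTranslationDataset(source_lang="en", target_lang="de")
# each formatted example is (OrderedDict([("Source", <en text>)]), <de text>)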

View File

@@ -1,24 +0,0 @@
from textattack.datasets import TextAttackDataset
class NewsTest2013EnglishToGerman(TextAttackDataset):
"""
Loads samples from newstest2013 dataset from the publicly available
WMT2016 translation task. (This is from the 'news' portion of
WMT2016. See http://www.statmt.org/wmt16/ for details.) Dataset
sourced from GluonNLP library.
Samples are loaded as (input, translation) tuples of string pairs.
Args:
offset (int): line to start reading from
shuffle (bool): If True, randomly shuffle loaded data
"""
DATA_PATH = "datasets/translation/NewsTest2013EnglishToGerman"
def __init__(self, offset=0, shuffle=False):
self._load_pickle_file(NewsTest2013EnglishToGerman.DATA_PATH, offset=offset)
if shuffle:
self._shuffle_data()

View File

@@ -1,4 +1,4 @@
from .goal_function_result import GoalFunctionResult from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus
from .classification_goal_function_result import ClassificationGoalFunctionResult from .classification_goal_function_result import ClassificationGoalFunctionResult
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult from .text_to_text_goal_function_result import TextToTextGoalFunctionResult

View File

@@ -1,6 +1,13 @@
import torch import torch
class GoalFunctionResultStatus:
SUCCEEDED = 0
SEARCHING = 1 # In process of searching for a success
MAXIMIZING = 2
SKIPPED = 3
class GoalFunctionResult: class GoalFunctionResult:
""" """
Represents the result of a goal function evaluating a AttackedText object. Represents the result of a goal function evaluating a AttackedText object.
@@ -8,16 +15,29 @@ class GoalFunctionResult:
Args: Args:
attacked_text: The sequence that was evaluated. attacked_text: The sequence that was evaluated.
output: The display-friendly output. output: The display-friendly output.
succeeded: Whether the goal has been achieved. goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
score: A score representing how close the model is to achieving its goal. score: A score representing how close the model is to achieving its goal.
num_queries: How many model queries have been used
ground_truth_output: The ground truth output
""" """
def __init__(self, attacked_text, raw_output, output, succeeded, score): def __init__(
self,
attacked_text,
raw_output,
output,
goal_status,
score,
num_queries,
ground_truth_output,
):
self.attacked_text = attacked_text self.attacked_text = attacked_text
self.raw_output = raw_output self.raw_output = raw_output
self.output = output self.output = output
self.score = score self.score = score
self.succeeded = succeeded self.goal_status = goal_status
self.num_queries = num_queries
self.ground_truth_output = ground_truth_output
if isinstance(self.raw_output, torch.Tensor): if isinstance(self.raw_output, torch.Tensor):
self.raw_output = self.raw_output.cpu() self.raw_output = self.raw_output.cpu()
@@ -25,9 +45,6 @@ class GoalFunctionResult:
if isinstance(self.score, torch.Tensor): if isinstance(self.score, torch.Tensor):
self.score = self.score.item() self.score = self.score.item()
if isinstance(self.succeeded, torch.Tensor):
self.succeeded = self.succeeded.item()
def get_text_color_input(self): def get_text_color_input(self):
""" A string representing the color this result's changed """ A string representing the color this result's changed
portion should be if it represents the original input. portion should be if it represents the original input.

View File

@@ -1,2 +1,3 @@
from .input_reduction import InputReduction
from .untargeted_classification import UntargetedClassification from .untargeted_classification import UntargetedClassification
from .targeted_classification import TargetedClassification from .targeted_classification import TargetedClassification

View File

@@ -52,3 +52,6 @@ class ClassificationGoalFunction(GoalFunction):
def extra_repr_keys(self): def extra_repr_keys(self):
return [] return []
def _get_displayed_output(self, raw_output):
return int(raw_output.argmax())

View File

@@ -0,0 +1,44 @@
from .classification_goal_function import ClassificationGoalFunction
class InputReduction(ClassificationGoalFunction):
"""
Attempts to reduce the input down to as few words as possible while maintaining
the same predicted label.
From Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018).
Pathologies of Neural Models Make Interpretations Difficult.
ArXiv, abs/1804.07781.
"""
def __init__(self, *args, target_num_words=1, **kwargs):
self.target_num_words = target_num_words
super().__init__(*args, **kwargs)
def _is_goal_complete(self, model_output, attacked_text):
return (
self.ground_truth_output == model_output.argmax()
and attacked_text.num_words <= self.target_num_words
)
def _should_skip(self, model_output, attacked_text):
return self.ground_truth_output != model_output.argmax()
def _get_score(self, model_output, attacked_text):
# Give the lowest score possible to inputs which don't maintain the ground truth label.
if self.ground_truth_output != model_output.argmax():
return 0
cur_num_words = attacked_text.num_words
initial_num_words = self.initial_attacked_text.num_words
# The main goal is to reduce the number of words (num_words_score)
# Higher model score for the ground truth label is used as a tiebreaker (model_score)
num_words_score = max(
(initial_num_words - cur_num_words) / initial_num_words, 0
)
model_score = model_output[self.ground_truth_output]
return min(num_words_score + model_score / initial_num_words, 1)
def extra_repr_keys(self):
return ["target_num_words"]

View File

@@ -3,18 +3,18 @@ from .classification_goal_function import ClassificationGoalFunction
class TargetedClassification(ClassificationGoalFunction): class TargetedClassification(ClassificationGoalFunction):
""" """
An targeted attack on classification models which attempts to maximize the A targeted attack on classification models which attempts to maximize the
score of the target label until it is the predicted label. score of the target label. Complete when the target label is the predicted label.
""" """
def __init__(self, model, target_class=0): def __init__(self, *args, target_class=0, **kwargs):
super().__init__(model) super().__init__(*args, **kwargs)
self.target_class = target_class self.target_class = target_class
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output, _):
return ( return (
self.target_class == model_output.argmax() self.target_class == model_output.argmax()
) or ground_truth_output == self.target_class ) or self.ground_truth_output == self.target_class
def _get_score(self, model_output, _): def _get_score(self, model_output, _):
if self.target_class < 0 or self.target_class >= len(model_output): if self.target_class < 0 or self.target_class >= len(model_output):
@@ -24,8 +24,5 @@ class TargetedClassification(ClassificationGoalFunction):
else: else:
return model_output[self.target_class] return model_output[self.target_class]
def _get_displayed_output(self, raw_output):
return int(raw_output.argmax())
def extra_repr_keys(self): def extra_repr_keys(self):
return ["target_class"] return ["target_class"]

View File

@@ -16,23 +16,22 @@ class UntargetedClassification(ClassificationGoalFunction):
self.target_max_score = target_max_score self.target_max_score = target_max_score
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output, _):
if self.target_max_score: if self.target_max_score:
return model_output[ground_truth_output] < self.target_max_score return model_output[self.ground_truth_output] < self.target_max_score
elif (model_output.numel() == 1) and isinstance(ground_truth_output, float): elif (model_output.numel() == 1) and isinstance(
return abs(ground_truth_output - model_output.item()) >= ( self.ground_truth_output, float
):
return abs(self.ground_truth_output - model_output.item()) >= (
self.target_max_score or 0.5 self.target_max_score or 0.5
) )
else: else:
return model_output.argmax() != ground_truth_output return model_output.argmax() != self.ground_truth_output
def _get_score(self, model_output, ground_truth_output): def _get_score(self, model_output, _):
# If the model outputs a single number and the ground truth output is # If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task. # a float, we assume that this is a regression task.
if (model_output.numel() == 1) and isinstance(ground_truth_output, float): if (model_output.numel() == 1) and isinstance(self.ground_truth_output, float):
return abs(model_output.item() - ground_truth_output) return abs(model_output.item() - self.ground_truth_output)
else: else:
return 1 - model_output[ground_truth_output] return 1 - model_output[self.ground_truth_output]
def _get_displayed_output(self, raw_output):
return int(raw_output.argmax())

View File

@@ -1,19 +1,25 @@
from abc import ABC, abstractmethod
import math import math
import lru import lru
import numpy as np import numpy as np
import torch import torch
from textattack.goal_function_results.goal_function_result import (
GoalFunctionResultStatus,
)
from textattack.shared import utils, validators from textattack.shared import utils, validators
from textattack.shared.utils import batch_model_predict, default_class_repr from textattack.shared.utils import batch_model_predict, default_class_repr
class GoalFunction: class GoalFunction(ABC):
""" """
Evaluates how well a perturbed attacked_text object is achieving a specified goal. Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args: Args:
model: The model used for evaluation. model: The model used for evaluation.
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
of success or failure.
query_budget (float): The maximum number of model queries allowed. query_budget (float): The maximum number of model queries allowed.
model_batch_size (int): The batch size for making calls to the model model_batch_size (int): The batch size for making calls to the model
model_cache_size (int): The maximum number of items to keep in the model model_cache_size (int): The maximum number of items to keep in the model
@@ -23,6 +29,7 @@ class GoalFunction:
def __init__( def __init__(
self, self,
model, model,
maximizable=False,
tokenizer=None, tokenizer=None,
use_cache=True, use_cache=True,
query_budget=float("inf"), query_budget=float("inf"),
@@ -33,6 +40,7 @@ class GoalFunction:
self.__class__, model.__class__ self.__class__, model.__class__
) )
self.model = model self.model = model
self.maximizable = maximizable
self.tokenizer = tokenizer self.tokenizer = tokenizer
if not self.tokenizer: if not self.tokenizer:
if hasattr(self.model, "tokenizer"): if hasattr(self.model, "tokenizer"):
@@ -42,7 +50,6 @@ class GoalFunction:
if not hasattr(self.tokenizer, "encode"): if not hasattr(self.tokenizer, "encode"):
raise TypeError("Tokenizer must contain `encode()` method") raise TypeError("Tokenizer must contain `encode()` method")
self.use_cache = use_cache self.use_cache = use_cache
self.num_queries = 0
self.query_budget = query_budget self.query_budget = query_budget
self.model_batch_size = model_batch_size self.model_batch_size = model_batch_size
if self.use_cache: if self.use_cache:
@@ -50,13 +57,16 @@ class GoalFunction:
else: else:
self._call_model_cache = None self._call_model_cache = None
def should_skip(self, attacked_text, ground_truth_output): def init_attack_example(self, attacked_text, ground_truth_output):
""" """
Returns whether or not the goal has already been completed for ``attacked_text``, Called before attacking ``attacked_text`` to 'reset' the goal
due to misprediction by the model. function and set properties for this example.
""" """
model_outputs = self._call_model([attacked_text]) self.initial_attacked_text = attacked_text
return self._is_goal_complete(model_outputs[0], ground_truth_output) self.ground_truth_output = ground_truth_output
self.num_queries = 0
result, _ = self.get_result(attacked_text, check_skip=True)
return result, _
def get_output(self, attacked_text): def get_output(self, attacked_text):
""" """
@@ -64,16 +74,16 @@ class GoalFunction:
""" """
return self._get_displayed_output(self._call_model([attacked_text])[0]) return self._get_displayed_output(self._call_model([attacked_text])[0])
def get_result(self, attacked_text, ground_truth_output): def get_result(self, attacked_text, **kwargs):
""" """
A helper method that queries `self.get_results` with a single A helper method that queries ``self.get_results`` with a single
``AttackedText`` object. ``AttackedText`` object.
""" """
results, search_over = self.get_results([attacked_text], ground_truth_output) results, search_over = self.get_results([attacked_text], **kwargs)
result = results[0] if len(results) else None result = results[0] if len(results) else None
return result, search_over return result, search_over
def get_results(self, attacked_text_list, ground_truth_output): def get_results(self, attacked_text_list, check_skip=False):
""" """
For each attacked_text object in attacked_text_list, returns a result For each attacked_text object in attacked_text_list, returns a result
consisting of whether or not the goal has been achieved, the output for consisting of whether or not the goal has been achieved, the output for
@@ -88,34 +98,55 @@ class GoalFunction:
model_outputs = self._call_model(attacked_text_list) model_outputs = self._call_model(attacked_text_list)
for attacked_text, raw_output in zip(attacked_text_list, model_outputs): for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
displayed_output = self._get_displayed_output(raw_output) displayed_output = self._get_displayed_output(raw_output)
succeeded = self._is_goal_complete(raw_output, ground_truth_output) goal_status = self._get_goal_status(
goal_function_score = self._get_score(raw_output, ground_truth_output) raw_output, attacked_text, check_skip=check_skip
)
goal_function_score = self._get_score(raw_output, attacked_text)
results.append( results.append(
self._goal_function_result_type()( self._goal_function_result_type()(
attacked_text, attacked_text,
raw_output, raw_output,
displayed_output, displayed_output,
succeeded, goal_status,
goal_function_score, goal_function_score,
self.num_queries,
self.ground_truth_output,
) )
) )
return results, self.num_queries == self.query_budget return results, self.num_queries == self.query_budget
def _is_goal_complete(self, model_output, ground_truth_output): def _get_goal_status(self, model_output, attacked_text, check_skip=False):
should_skip = check_skip and self._should_skip(model_output, attacked_text)
if should_skip:
return GoalFunctionResultStatus.SKIPPED
if self.maximizable:
return GoalFunctionResultStatus.MAXIMIZING
if self._is_goal_complete(model_output, attacked_text):
return GoalFunctionResultStatus.SUCCEEDED
return GoalFunctionResultStatus.SEARCHING
@abstractmethod
def _is_goal_complete(self, model_output, attacked_text):
raise NotImplementedError() raise NotImplementedError()
def _get_score(self, model_output, ground_truth_output): def _should_skip(self, model_output, attacked_text):
return self._is_goal_complete(model_output, attacked_text)
@abstractmethod
def _get_score(self, model_output, attacked_text):
raise NotImplementedError() raise NotImplementedError()
def _get_displayed_output(self, raw_output): def _get_displayed_output(self, raw_output):
return raw_output return raw_output
@abstractmethod
def _goal_function_result_type(self): def _goal_function_result_type(self):
""" """
Returns the class of this goal function's results. Returns the class of this goal function's results.
""" """
raise NotImplementedError() raise NotImplementedError()
@abstractmethod
def _process_model_outputs(self, inputs, outputs): def _process_model_outputs(self, inputs, outputs):
""" """
Processes and validates a list of model outputs. Processes and validates a list of model outputs.
@@ -142,7 +173,7 @@ class GoalFunction:
return self._process_model_outputs(attacked_text_list, outputs) return self._process_model_outputs(attacked_text_list, outputs)
def _call_model(self, attacked_text_list): def _call_model(self, attacked_text_list):
""" Gets predictions for a list of `AttackedText` objects. """ Gets predictions for a list of ``AttackedText`` objects.
Gets prediction from cache if possible. If prediction is not in the Gets prediction from cache if possible. If prediction is not in the
cache, queries model and stores prediction in cache. cache, queries model and stores prediction in cache.
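The new per-example lifecycle, sketched with the names from this diff (texts and outputs are placeholders):

initial_result, search_over = goal_function.init_attack_example(
    attacked_text, ground_truth_output
)
if initial_result.goal_status == GoalFunctionResultStatus.SKIPPED:
    pass  # nothing to attack
results, search_over = goal_function.get_results(candidate_texts)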

View File

@@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction):
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3). Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
""" """
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output, _):
return self._get_score(model_output, ground_truth_output) == 1.0 return self._get_score(model_output, self.ground_truth_output) == 1.0
def _get_score(self, model_output, ground_truth_output): def _get_score(self, model_output, _):
num_words_diff = word_difference_score(model_output, ground_truth_output) num_words_diff = word_difference_score(model_output, self.ground_truth_output)
if num_words_diff == 0: if num_words_diff == 0:
return 0.0 return 0.0
else: else:
return num_words_diff / len(get_words_cached(ground_truth_output)) return num_words_diff / len(get_words_cached(self.ground_truth_output))
@functools.lru_cache(maxsize=2 ** 12) @functools.lru_cache(maxsize=2 ** 12)

View File

@@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction):
original_output: the original output of the model original_output: the original output of the model
""" """
def __init__(self, model):
super().__init__(model)
def _goal_function_result_type(self): def _goal_function_result_type(self):
""" Returns the class of this goal function's results. """ """ Returns the class of this goal function's results. """
return TextToTextGoalFunctionResult return TextToTextGoalFunctionResult

View File

@@ -20,9 +20,8 @@ class CSVLogger(Logger):
self._flushed = True self._flushed = True
def log_attack_result(self, result): def log_attack_result(self, result):
if isinstance(result, FailedAttackResult):
return
original_text, perturbed_text = result.diff_color(self.color_method) original_text, perturbed_text = result.diff_color(self.color_method)
result_type = result.__class__.__name__.replace("AttackResult", "")
row = { row = {
"original_text": original_text, "original_text": original_text,
"perturbed_text": perturbed_text, "perturbed_text": perturbed_text,
@@ -30,7 +29,9 @@ class CSVLogger(Logger):
"perturbed_score": result.perturbed_result.score, "perturbed_score": result.perturbed_result.score,
"original_output": result.original_result.output, "original_output": result.original_result.output,
"perturbed_output": result.perturbed_result.output, "perturbed_output": result.perturbed_result.output,
"ground_truth_output": result.original_result.ground_truth_output,
"num_queries": result.num_queries, "num_queries": result.num_queries,
"result_type": result_type,
} }
self.df = self.df.append(row, ignore_index=True) self.df = self.df.append(row, ignore_index=True)
self._flushed = False self._flushed = False

View File

@@ -31,7 +31,7 @@ class T5ForTextToText:
def __init__( def __init__(
self, mode="english_to_german", max_length=20, num_beams=1, early_stopping=True self, mode="english_to_german", max_length=20, num_beams=1, early_stopping=True
): ):
self.model = transformers.AutoModelWithLMHead.from_pretrained("t5-base") self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base")
self.model.to(utils.device) self.model.to(utils.device)
self.model.eval() self.model.eval()
self.tokenizer = T5Tokenizer(mode) self.tokenizer = T5Tokenizer(mode)

View File

@@ -18,11 +18,7 @@ class AutoTokenizer:
""" """
def __init__( def __init__(
self, self, name="bert-base-uncased", max_length=256, use_fast=True,
name="bert-base-uncased",
max_length=256,
pad_to_length=False,
use_fast=True,
): ):
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
name, use_fast=use_fast name, use_fast=use_fast
@@ -43,7 +39,7 @@ class AutoTokenizer:
*input_text, *input_text,
max_length=self.max_length, max_length=self.max_length,
add_special_tokens=True, add_special_tokens=True,
pad_to_max_length=True, padding="max_length",
truncation=True, truncation=True,
) )
return dict(encoded_text) return dict(encoded_text)
@@ -59,7 +55,7 @@ class AutoTokenizer:
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.max_length,
add_special_tokens=True, add_special_tokens=True,
pad_to_max_length=True, padding="max_length",
) )
# Encodings is a `transformers.utils.BatchEncode` object, which # Encodings is a `transformers.utils.BatchEncode` object, which
# is basically a big dictionary that contains a key for all input # is basically a big dictionary that contains a key for all input
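padding="max_length" is the newer `transformers` spelling of the deprecated pad_to_max_length=True; a standalone sketch against the public API:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer.encode_plus(
    "an example sentence",
    max_length=256,
    add_special_tokens=True,
    padding="max_length",
    truncation=True,
)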

View File

@@ -62,7 +62,11 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
normalizers = [] normalizers = []
if unicode_normalizer: if unicode_normalizer:
normalizers += [unicode_normalizer_from_str(unicode_normalizer)] normalizers += [
hf_tokenizers.normalizers.unicode_normalizer_from_str(
unicode_normalizer
)
]
if lowercase: if lowercase:
normalizers += [hf_tokenizers.normalizers.Lowercase()] normalizers += [hf_tokenizers.normalizers.Lowercase()]

View File

@@ -11,10 +11,10 @@ class T5Tokenizer(AutoTokenizer):
    Supports the following modes:
-    * summarization: summarize English text (CNN/Daily Mail dataset)
-    * english_to_german: translate English to German (WMT dataset)
-    * english_to_french: translate English to French (WMT dataset)
-    * english_to_romanian: translate English to Romanian (WMT dataset)
+    * summarization: summarize English text
+    * english_to_german: translate English to German
+    * english_to_french: translate English to French
+    * english_to_romanian: translate English to Romanian
    """
@@ -28,7 +28,7 @@ class T5Tokenizer(AutoTokenizer):
        elif mode == "summarization":
            self.tokenization_prefix = "summarize: "
        else:
-            raise ValueError(f"Invalid t5 tokenizer mode {english_to_german}.")
+            raise ValueError(f"Invalid t5 tokenizer mode {mode}.")
        super().__init__(name="t5-base", max_length=max_length)
@@ -38,12 +38,29 @@ class T5Tokenizer(AutoTokenizer):
        passed into T5.
        """
        if isinstance(text, tuple):
+            if len(text) > 1:
+                raise ValueError(
+                    f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
+                )
            text = text[0]
        if not isinstance(text, str):
            raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
        text_to_encode = self.tokenization_prefix + text
        return super().encode(text_to_encode)
+
+    def batch_encode(self, input_text_list):
+        new_input_text_list = []
+        for text in input_text_list:
+            if isinstance(text, tuple):
+                if len(text) > 1:
+                    raise ValueError(
+                        f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
+                    )
+                text = text[0]
+            new_input_text_list.append(self.tokenization_prefix + text)
+        return super().batch_encode(new_input_text_list)
+
    def decode(self, ids):
        """
        Converts IDs (typically generated by the model) back to a string.
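
To make the prefix handling above concrete, here is a hypothetical usage sketch; the import path and constructor signature are assumptions inferred from the library layout implied by this diff:

from textattack.models.tokenizers import T5Tokenizer  # path assumed

tokenizer = T5Tokenizer(mode="english_to_german", max_length=20)
# encode() prepends the task prefix before delegating to AutoTokenizer,
# i.e. this encodes "translate English to German: Hello, world!":
ids = tokenizer.encode("Hello, world!")
# batch_encode() applies the same prefix to every element, unwrapping
# single-element tuples along the way:
batch = tokenizer.batch_encode(["good morning", ("good night",)])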


@@ -0,0 +1,274 @@
"""
Reimplementation of the search method from "Word-level Textual Adversarial
Attacking as Combinatorial Optimization" by Zang et al.
`<https://www.aclweb.org/anthology/2020.acl-main.540.pdf>`_
`<https://github.com/thunlp/SememePSO-Attack>`_
"""
from copy import deepcopy

import numpy as np

from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod


class PSOAlgorithm(SearchMethod):
    """
    Attacks a model with word substitutions using a Particle Swarm Optimization (PSO) algorithm.
    Some key hyper-parameters are set up according to the original paper:

    "We adjust PSO on the validation set of SST and set ω_1 as 0.8 and ω_2 as 0.2.
    We set the max velocity of the particles V_{max} to 3, which means the changing
    probability of the particles ranges from 0.047 (sigmoid(-3)) to 0.953 (sigmoid(3))."

    Args:
        pop_size (:obj:`int`, optional): The population size. Defaults to 60.
        max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 20.
    """

    def __init__(
        self, pop_size=60, max_iters=20,
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.search_over = False

        self.Omega_1 = 0.8
        self.Omega_2 = 0.2
        self.C1_origin = 0.8
        self.C2_origin = 0.2
        self.V_max = 3.0

    def _generate_population(self, x_orig, neighbors_list, neighbors_len):
        h_score, w_list = self._gen_h_score(x_orig, neighbors_len, neighbors_list)
        return [self._mutate(x_orig, h_score, w_list) for _ in range(self.pop_size)]

    def _mutate(self, x_cur, w_select_probs, w_list):
        rand_idx = np.random.choice(len(w_select_probs), 1, p=w_select_probs)[0]
        return x_cur.replace_word_at_index(rand_idx, w_list[rand_idx])

    def _gen_h_score(self, x, neighbors_len, neighbors_list):
        w_list = []
        prob_list = []
        for i, orig_w in enumerate(x.words):
            if neighbors_len[i] == 0:
                w_list.append(orig_w)
                prob_list.append(0)
                continue
            p, w = self._gen_most_change(x, i, neighbors_list[i])
            w_list.append(w)
            prob_list.append(p)

        prob_list = self._norm(prob_list)
        h_score = np.array(prob_list)
        return h_score, w_list

    def _norm(self, n):
        # Clamp negative scores to zero, then normalize to a probability distribution.
        tn = []
        for i in n:
            if i <= 0:
                tn.append(0)
            else:
                tn.append(i)
        s = np.sum(tn)
        if s == 0:
            # All scores were non-positive; fall back to a uniform distribution.
            for i in range(len(tn)):
                tn[i] = 1
            return [t / len(tn) for t in tn]
        new_n = [t / s for t in tn]
        return new_n

    # for un-targeted attacking
    def _gen_most_change(self, x_cur, pos, replace_list):
        orig_result, self.search_over = self.get_goal_results([x_cur])
        if self.search_over:
            return 0, x_cur.words[pos]
        new_x_list = [x_cur.replace_word_at_index(pos, w) for w in replace_list]
        # new_x_list = self.get_transformations(
        #     x_cur,
        #     original_text=self.original_attacked_text,
        #     indices_to_modify=[pos],
        # )
        new_x_results, self.search_over = self.get_goal_results(new_x_list)
        new_x_scores = np.array([r.score for r in new_x_results])
        # minimize the score of ground truth
        new_x_scores = new_x_scores - orig_result[0].score
        if len(new_x_scores):
            return (
                np.max(new_x_scores),
                new_x_list[np.argsort(new_x_scores)[-1]].words[pos],
            )
        else:
            return 0, x_cur.words[pos]

    def _get_neighbors_list(self, attacked_text):
        """
        Generates the list of candidate substitute words for each word.
        Args:
            attacked_text: The original text
        Returns:
            A list with one entry per word: an array of that word's candidate neighbors
        """
        words = attacked_text.words
        neighbors_list = [[] for _ in range(len(words))]
        transformations = self.get_transformations(
            attacked_text, original_text=self.original_attacked_text
        )
        for transformed_text in transformations:
            try:
                diff_idx = attacked_text.first_word_diff_index(transformed_text)
                neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
            except Exception:
                # No differing word was found; make sure the transformation
                # really left every word unchanged.
                assert len(attacked_text.words) == len(transformed_text.words)
                assert all(
                    w1 == w2
                    for w1, w2 in zip(attacked_text.words, transformed_text.words)
                )
        neighbors_list = [np.array(x) for x in neighbors_list]
        return neighbors_list

    def _equal(self, a, b):
        if a == b:
            return -self.V_max
        else:
            return self.V_max

    def _turn(self, x1, x2, prob, x_len):
        indices_to_replace = []
        words_to_replace = []
        x2_words = x2.words
        for i in range(x_len):
            if np.random.uniform() < prob[i]:
                indices_to_replace.append(i)
                words_to_replace.append(x2_words[i])
        new_text = x1.replace_words_at_indices(indices_to_replace, words_to_replace)
        return new_text

    def _count_change_ratio(self, x1, x2, x_len):
        # Compare word lists elementwise; a plain `list != list` comparison
        # would yield a single bool rather than a per-word mask.
        change_ratio = float(np.sum(np.array(x1.words) != np.array(x2.words))) / float(
            x_len
        )
        return change_ratio

    def _sigmoid(self, n):
        return 1 / (1 + np.exp(-n))

    def _perform_search(self, initial_result):
        self.original_attacked_text = initial_result.attacked_text
        x_len = len(self.original_attacked_text.words)
        self.correct_output = initial_result.output

        # get word substitution candidates and generate population
        neighbors_list = self._get_neighbors_list(self.original_attacked_text)
        neighbors_len = [len(x) for x in neighbors_list]
        pop = self._generate_population(
            self.original_attacked_text, neighbors_list, neighbors_len
        )

        # test population against target model
        pop_results, self.search_over = self.get_goal_results(pop)
        if self.search_over:
            return max(pop_results, key=lambda x: x.score)
        pop_scores = np.array([r.score for r in pop_results])

        # rank the scores from low to high and check if there is a successful attack
        part_elites = deepcopy(pop)
        part_elites_scores = pop_scores
        top_attack = np.argmax(pop_scores)
        all_elite = pop[top_attack]
        all_elite_score = pop_scores[top_attack]
        if pop_results[top_attack].goal_status == GoalFunctionResultStatus.SUCCEEDED:
            return pop_results[top_attack]

        # set up hyper-parameters
        V = np.random.uniform(-self.V_max, self.V_max, self.pop_size)
        V_P = [[V[t] for _ in range(x_len)] for t in range(self.pop_size)]

        # start iterations
        for i in range(self.max_iters):
            Omega = (self.Omega_1 - self.Omega_2) * (
                self.max_iters - i
            ) / self.max_iters + self.Omega_2
            C1 = self.C1_origin - i / self.max_iters * (self.C1_origin - self.C2_origin)
            C2 = self.C2_origin + i / self.max_iters * (self.C1_origin - self.C2_origin)
            P1 = C1
            P2 = C2
            all_elite_words = all_elite.words

            for id in range(self.pop_size):
                # calculate the probability of turning each word
                pop_words = pop[id].words
                part_elites_words = part_elites[id].words
                for dim in range(x_len):
                    V_P[id][dim] = Omega * V_P[id][dim] + (1 - Omega) * (
                        self._equal(pop_words[dim], part_elites_words[dim])
                        + self._equal(pop_words[dim], all_elite_words[dim])
                    )
                turn_prob = [self._sigmoid(V_P[id][d]) for d in range(x_len)]

                if np.random.uniform() < P1:
                    pop[id] = self._turn(part_elites[id], pop[id], turn_prob, x_len)
                if np.random.uniform() < P2:
                    pop[id] = self._turn(all_elite, pop[id], turn_prob, x_len)

            # check if there is any successful attack in the current population
            pop_results, self.search_over = self.get_goal_results(pop)
            if self.search_over:
                return max(pop_results, key=lambda x: x.score)
            pop_scores = np.array([r.score for r in pop_results])
            top_attack = np.argmax(pop_scores)
            if (
                pop_results[top_attack].goal_status
                == GoalFunctionResultStatus.SUCCEEDED
            ):
                return pop_results[top_attack]

            # mutation based on the current change rate
            new_pop = []
            for x in pop:
                change_ratio = self._count_change_ratio(
                    x, self.original_attacked_text, x_len
                )
                # referred from the original source code
                p_change = 1 - 2 * change_ratio
                if np.random.uniform() < p_change:
                    new_h, new_w_list = self._gen_h_score(
                        x, neighbors_len, neighbors_list
                    )
                    new_pop.append(self._mutate(x, new_h, new_w_list))
                else:
                    new_pop.append(x)
            pop = new_pop

            # check if there is any successful attack in the current population
            pop_results, self.search_over = self.get_goal_results(pop)
            if self.search_over:
                return max(pop_results, key=lambda x: x.score)
            pop_scores = np.array([r.score for r in pop_results])
            top_attack = np.argmax(pop_scores)
            if (
                pop_results[top_attack].goal_status
                == GoalFunctionResultStatus.SUCCEEDED
            ):
                return pop_results[top_attack]

            # update the elite if the score is increased
            for k in range(self.pop_size):
                if pop_scores[k] > part_elites_scores[k]:
                    part_elites[k] = pop[k]
                    part_elites_scores[k] = pop_scores[k]
            if pop_scores[top_attack] > all_elite_score:
                all_elite = pop[top_attack]
                all_elite_score = pop_scores[top_attack]

        return initial_result
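
The velocity bound quoted in the class docstring translates directly into the word-flip probability range; a quick numeric check of that claim, using plain numpy (nothing TextAttack-specific):

import numpy as np

def sigmoid(n):
    return 1 / (1 + np.exp(-n))

# With V_max = 3, per-word flip probabilities stay within the quoted range:
print(round(sigmoid(-3.0), 3))  # 0.047
print(round(sigmoid(3.0), 3))   # 0.953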


@@ -3,3 +3,4 @@ from .beam_search import BeamSearch
 from .greedy_search import GreedySearch
 from .greedy_word_swap_wir import GreedyWordSwapWIR
 from .genetic_algorithm import GeneticAlgorithm
+from .PSO_algorithm import PSOAlgorithm


@@ -1,5 +1,6 @@
 import numpy as np

+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.search_methods import SearchMethod

@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
    def _perform_search(self, initial_result):
        beam = [initial_result.attacked_text]
        best_result = initial_result
-        while not best_result.succeeded:
+        while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
            potential_next_beam = []
            for text in beam:
                transformations = self.get_transformations(
@@ -32,9 +33,7 @@ class BeamSearch(SearchMethod):
            if len(potential_next_beam) == 0:
                # If we did not find any possible perturbations, give up.
                return best_result
-            results, search_over = self.get_goal_results(
-                potential_next_beam, initial_result.output
-            )
+            results, search_over = self.get_goal_results(potential_next_beam)
            scores = np.array([r.score for r in results])
            best_result = results[scores.argmax()]
            if search_over:


@@ -10,6 +10,7 @@ from copy import deepcopy
 import numpy as np
 import torch

+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.search_methods import SearchMethod
 from textattack.shared.validators import transformation_consists_of_word_swaps

@@ -19,167 +20,209 @@ class GeneticAlgorithm(SearchMethod):
    Attacks a model with word substitutions using a genetic algorithm.

    Args:
-        pop_size (:obj:`int`, optional): The population size. Defauls to 20.
-        max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 50.
+        pop_size (int): The population size. Defaults to 20.
+        max_iters (int): The maximum number of iterations to use. Defaults to 50.
+        temp (float): Temperature for the softmax function used to normalize the probability distribution when sampling parents.
+            Higher temperature increases the sensitivity to lower-probability candidates.
+        give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
+        max_crossover_retries (int): Maximum number of crossover retries if the resulting child fails to pass the constraints.
+            Setting it to 0 means we immediately take one of the parents at random as the child.
    """

    def __init__(
-        self, pop_size=20, max_iters=50, temp=0.3, give_up_if_no_improvement=False
+        self,
+        pop_size=20,
+        max_iters=50,
+        temp=0.3,
+        give_up_if_no_improvement=False,
+        max_crossover_retries=20,
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.temp = temp
        self.give_up_if_no_improvement = give_up_if_no_improvement
-        self.search_over = False
+        self.max_crossover_retries = max_crossover_retries
+
+        # internal flag to indicate if search should end immediately
+        self._search_over = False

-    def _replace_at_index(self, pop_member, idx):
-        """
-        Select the best replacement for word at position (idx)
-        in (pop_member) to maximize score.
-        Args:
-            pop_member: The population member being perturbed.
-            idx: The index at which to replace a word.
-        Returns:
-            Whether a replacement which increased the score was found.
-        """
-        transformations = self.get_transformations(
-            pop_member.attacked_text,
-            original_text=self.original_attacked_text,
-            indices_to_modify=[idx],
-        )
-        if not len(transformations):
-            return False
-        orig_result, self.search_over = self.get_goal_results(
-            [pop_member.attacked_text], self.correct_output
-        )
-        if self.search_over:
-            return False
-        new_x_results, self.search_over = self.get_goal_results(
-            transformations, self.correct_output
-        )
-        new_x_scores = torch.Tensor([r.score for r in new_x_results])
-        new_x_scores = new_x_scores - orig_result[0].score
-        if len(new_x_scores) and new_x_scores.max() > 0:
-            pop_member.attacked_text = transformations[new_x_scores.argmax()]
-            return True
-        return False
-
-    def _perturb(self, pop_member):
+    def _perturb(self, pop_member, original_result):
        """
-        Replaces a word in pop_member that has not been modified.
+        Replaces a word in pop_member that has not been modified, in place.
        Args:
-            pop_member: The population member being perturbed.
+            pop_member (PopulationMember): The population member being perturbed.
+            original_result (GoalFunctionResult): Result of original sample being attacked
+        Returns: None
        """
-        x_len = pop_member.neighbors_len.shape[0]
-        neighbors_len = deepcopy(pop_member.neighbors_len)
-        non_zero_indices = np.sum(np.sign(pop_member.neighbors_len))
+        num_words = pop_member.num_candidates_per_word.shape[0]
+        num_candidates_per_word = np.copy(pop_member.num_candidates_per_word)
+        non_zero_indices = np.count_nonzero(num_candidates_per_word)
        if non_zero_indices == 0:
            return
        iterations = 0
-        while iterations < non_zero_indices and not self.search_over:
-            w_select_probs = neighbors_len / np.sum(neighbors_len)
-            rand_idx = np.random.choice(x_len, 1, p=w_select_probs)[0]
-            if self._replace_at_index(pop_member, rand_idx):
-                pop_member.neighbors_len[rand_idx] = 0
-                break
-            neighbors_len[rand_idx] = 0
+        while iterations < non_zero_indices:
+            w_select_probs = num_candidates_per_word / np.sum(num_candidates_per_word)
+            rand_idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
+            transformations = self.get_transformations(
+                pop_member.attacked_text,
+                original_text=original_result.attacked_text,
+                indices_to_modify=[rand_idx],
+            )
+            if not len(transformations):
+                iterations += 1
+                continue
+            new_results, self._search_over = self.get_goal_results(transformations)
+            if self._search_over:
+                break
+            diff_scores = (
+                torch.Tensor([r.score for r in new_results]) - pop_member.result.score
+            )
+            if len(diff_scores) and diff_scores.max() > 0:
+                idx = diff_scores.argmax()
+                pop_member.attacked_text = transformations[idx]
+                pop_member.num_candidates_per_word[rand_idx] = 0
+                pop_member.result = new_results[idx]
+                break
+            num_candidates_per_word[rand_idx] = 0
            iterations += 1

-    def _generate_population(self, neighbors_len, initial_result):
-        """
-        Generates a population of texts each with one word replaced
-        Args:
-            neighbors_len: A list of the number of candidate neighbors for each word.
-            initial_result: The result to instantiate the population with
-        Returns:
-            The population.
-        """
-        pop = []
-        for _ in range(self.pop_size):
-            pop_member = PopulationMember(
-                self.original_attacked_text, deepcopy(neighbors_len), initial_result
-            )
-            self._perturb(pop_member)
-            pop.append(pop_member)
-        return pop
-
-    def _crossover(self, pop_member1, pop_member2):
+    def _crossover(self, pop_member1, pop_member2, original_result):
        """
        Generates a crossover between pop_member1 and pop_member2.
+        If the child fails to satisfy the constraints, we re-try crossover for a fixed number of times,
+        before taking one of the parents at random as the resulting child.
        Args:
-            pop_member1: The first population member.
-            pop_member2: The second population member.
+            pop_member1 (PopulationMember): The first population member.
+            pop_member2 (PopulationMember): The second population member.
        Returns:
            A population member containing the crossover.
        """
-        indices_to_replace = []
-        words_to_replace = []
        x1_text = pop_member1.attacked_text
-        x2_words = pop_member2.attacked_text.words
-        new_neighbors_len = deepcopy(pop_member1.neighbors_len)
-        for i in range(len(x1_text.words)):
-            if np.random.uniform() < 0.5:
-                indices_to_replace.append(i)
-                words_to_replace.append(x2_words[i])
-                new_neighbors_len[i] = pop_member2.neighbors_len[i]
-        new_text = x1_text.replace_words_at_indices(
-            indices_to_replace, words_to_replace
-        )
-        return PopulationMember(new_text, deepcopy(new_neighbors_len))
+        x2_text = pop_member2.attacked_text
+        x2_words = x2_text.words

-    def _get_neighbors_len(self, attacked_text):
-        """
-        Generates the neighbors_len list
-        Args:
-            attacked_text: The original text
-        Returns:
-            A list of number of candidate neighbors for each word
-        """
-        words = attacked_text.words
-        neighbors_list = [[] for _ in range(len(words))]
-        transformations = self.get_transformations(
-            attacked_text, original_text=self.original_attacked_text
-        )
-        for transformed_text in transformations:
-            diff_idx = attacked_text.first_word_diff_index(transformed_text)
-            neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
-        neighbors_list = [np.array(x) for x in neighbors_list]
-        neighbors_len = np.array([len(x) for x in neighbors_list])
-        return neighbors_len
+        num_tries = 0
+        passed_constraints = False
+        while num_tries < self.max_crossover_retries + 1:
+            indices_to_replace = []
+            words_to_replace = []
+            num_candidates_per_word = np.copy(pop_member1.num_candidates_per_word)
+            for i in range(len(x1_text.words)):
+                if np.random.uniform() < 0.5:
+                    indices_to_replace.append(i)
+                    words_to_replace.append(x2_words[i])
+                    num_candidates_per_word[i] = pop_member2.num_candidates_per_word[i]
+            new_text = x1_text.replace_words_at_indices(
+                indices_to_replace, words_to_replace
+            )
+            if "last_transformation" in x1_text.attack_attrs:
+                new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
+                    "last_transformation"
+                ]
+                filtered = self.filter_transformations(
+                    [new_text], x1_text, original_text=original_result.attacked_text
+                )
+            elif "last_transformation" in x2_text.attack_attrs:
+                new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
+                    "last_transformation"
+                ]
+                filtered = self.filter_transformations(
+                    [new_text], x1_text, original_text=original_result.attacked_text
+                )
+            else:
+                # In this case, neither x_1 nor x_2 has been transformed,
+                # meaning that new_text == original_text
+                filtered = [new_text]
+            if filtered:
+                new_text = filtered[0]
+                passed_constraints = True
+                break
+            num_tries += 1
+
+        if not passed_constraints:
+            # If we cannot find a child that passes the constraints,
+            # we just randomly pick one of the parents to be the child for the next iteration.
+            new_text = (
+                pop_member1.attacked_text
+                if np.random.uniform() < 0.5
+                else pop_member2.attacked_text
+            )
+
+        new_results, self._search_over = self.get_goal_results([new_text])
+        return PopulationMember(new_text, num_candidates_per_word, new_results[0])
+
+    def _initialize_population(self, initial_result):
+        """
+        Initialize a population of texts, each with one word replaced
+        Args:
+            initial_result (GoalFunctionResult): The result to instantiate the population with
+        Returns:
+            The population.
+        """
+        words = initial_result.attacked_text.words
+        num_candidates_per_word = np.zeros(len(words))
+        transformations = self.get_transformations(
+            initial_result.attacked_text, original_text=initial_result.attacked_text
+        )
+        for transformed_text in transformations:
+            diff_idx = initial_result.attacked_text.first_word_diff_index(
+                transformed_text
+            )
+            num_candidates_per_word[diff_idx] += 1
+
+        # Just b/c there are no candidates now doesn't mean we never want to select the word for perturbation
+        # Therefore, we give small non-zero probability for words with no candidates
+        # Epsilon is some small number to approximately assign 1% probability
+        num_total_candidates = np.sum(num_candidates_per_word)
+        epsilon = max(1, int(num_total_candidates * 0.01))
+        for i in range(len(num_candidates_per_word)):
+            if num_candidates_per_word[i] == 0:
+                num_candidates_per_word[i] = epsilon
+
+        population = []
+        for _ in range(self.pop_size):
+            pop_member = PopulationMember(
+                initial_result.attacked_text,
+                np.copy(num_candidates_per_word),
+                initial_result,
+            )
+            # Perturb `pop_member` in-place
+            self._perturb(pop_member, initial_result)
+            population.append(pop_member)
+        return population

    def _perform_search(self, initial_result):
-        self.original_attacked_text = initial_result.attacked_text
-        self.correct_output = initial_result.output
-        neighbors_len = self._get_neighbors_len(self.original_attacked_text)
-        pop = self._generate_population(neighbors_len, initial_result)
-        cur_score = initial_result.score
+        self._search_over = False
+        population = self._initialize_population(initial_result)
+        current_score = initial_result.score
        for i in range(self.max_iters):
-            pop_results, self.search_over = self.get_goal_results(
-                [pm.attacked_text for pm in pop], self.correct_output
-            )
-            if self.search_over:
-                if not len(pop_results):
-                    return pop[0].result
-                return max(pop_results, key=lambda x: x.score)
-            for idx, result in enumerate(pop_results):
-                pop[idx].result = pop_results[idx]
-            pop = sorted(pop, key=lambda x: -x.result.score)
-            pop_scores = torch.Tensor([r.score for r in pop_results])
-            logits = ((-pop_scores) / self.temp).exp()
-            select_probs = (logits / logits.sum()).cpu().numpy()
-            if pop[0].result.succeeded:
-                return pop[0].result
-            if pop[0].result.score > cur_score:
-                cur_score = pop[0].result.score
+            population = sorted(population, key=lambda x: x.result.score, reverse=True)
+            if (
+                self._search_over
+                or population[0].result.goal_status
+                == GoalFunctionResultStatus.SUCCEEDED
+            ):
+                break
+
+            if population[0].result.score > current_score:
+                current_score = population[0].result.score
            elif self.give_up_if_no_improvement:
                break
-            elite = [pop[0]]
+
+            pop_scores = torch.Tensor([pm.result.score for pm in population])
+            logits = ((-pop_scores) / self.temp).exp()
+            select_probs = (logits / logits.sum()).cpu().numpy()
            parent1_idx = np.random.choice(
                self.pop_size, size=self.pop_size - 1, p=select_probs
            )
@@ -187,16 +230,27 @@ class GeneticAlgorithm(SearchMethod):
            parent2_idx = np.random.choice(
                self.pop_size, size=self.pop_size - 1, p=select_probs
            )
-            children = [
-                self._crossover(pop[parent1_idx[idx]], pop[parent2_idx[idx]])
-                for idx in range(self.pop_size - 1)
-            ]
-            for c in children:
-                self._perturb(c)
-            pop = elite + children
-        return pop[0].result
+            children = []
+            for idx in range(self.pop_size - 1):
+                child = self._crossover(
+                    population[parent1_idx[idx]],
+                    population[parent2_idx[idx]],
+                    initial_result,
+                )
+                if self._search_over:
+                    break
+                self._perturb(child, initial_result)
+                children.append(child)
+                # We need two `search_over` checks b/c value might change both in
+                # `crossover` method and `perturb` method.
+                if self._search_over:
+                    break
+            population = [population[0]] + children
+        return population[0].result

    def check_transformation_compatibility(self, transformation):
        """
@@ -214,10 +268,10 @@ class PopulationMember:
    Args:
        attacked_text: The ``AttackedText`` of the population member.
-        neighbors_len: A list of the number of candidate neighbors list for each word.
+        num_candidates_per_word (numpy.array): The number of candidate neighbors for each word.
    """

-    def __init__(self, attacked_text, neighbors_len, result=None):
+    def __init__(self, attacked_text, num_candidates_per_word, result):
        self.attacked_text = attacked_text
-        self.neighbors_len = neighbors_len
+        self.num_candidates_per_word = num_candidates_per_word
        self.result = result
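
For reference, a hypothetical construction of the revised search method, spelling out the defaults to show where the new `max_crossover_retries` knob sits (the import is the package export added in the `__init__.py` diff above):

from textattack.search_methods import GeneticAlgorithm

search_method = GeneticAlgorithm(
    pop_size=20,
    max_iters=50,
    temp=0.3,
    give_up_if_no_improvement=False,
    max_crossover_retries=20,  # bound retries before falling back to a parent
)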


@@ -10,8 +10,11 @@ import numpy as np
 import torch
 from torch.nn.functional import softmax

+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.search_methods import SearchMethod
-from textattack.shared.validators import transformation_consists_of_word_swaps
+from textattack.shared.validators import (
+    transformation_consists_of_word_swaps_and_deletions,
+)


 class GreedyWordSwapWIR(SearchMethod):
@@ -20,22 +23,23 @@ class GreedyWordSwapWIR(SearchMethod):
    order of index, after ranking indices by importance.

    Args:
-<<<<<<< HEAD
        wir_method (str): Method for ranking most important words. Available choices: `unk`, `delete`, `pwws`, and `random`.
        ascending (bool): if True, ranks words from least-to-most important. (Default
            ranking shows the most important word first.)
-=======
-        wir_method: method for ranking most important words
->>>>>>> master
    """

-    def __init__(self, wir_method="unk", ascending=False):
+    def __init__(self, wir_method="unk"):
        self.wir_method = wir_method
-        self.ascending = ascending

    def _get_index_order(self, initial_result, texts):
        """ Queries model for list of attacked text objects ``text`` and
        ranks in order of descending score.
        """
-        leave_one_results, search_over = self.get_goal_results(
-            texts, initial_result.output
-        )
+        leave_one_results, search_over = self.get_goal_results(texts)
        leave_one_scores = np.array([result.score for result in leave_one_results])
        return leave_one_scores, search_over
@@ -98,10 +102,7 @@ class GreedyWordSwapWIR(SearchMethod):
        search_over = False
        if self.wir_method != "random":
-            if self.ascending:
-                index_order = (leave_one_scores).argsort()
-            else:
-                index_order = (-leave_one_scores).argsort()
+            index_order = (-leave_one_scores).argsort()

        i = 0
        results = None
@@ -114,9 +115,7 @@ class GreedyWordSwapWIR(SearchMethod):
            i += 1
            if len(transformed_text_candidates) == 0:
                continue
-            results, search_over = self.get_goal_results(
-                transformed_text_candidates, initial_result.output
-            )
+            results, search_over = self.get_goal_results(transformed_text_candidates)
            results = sorted(results, key=lambda x: -x.score)
            # Skip swaps which don't improve the score
            if results[0].score > cur_result.score:
@@ -124,12 +123,12 @@ class GreedyWordSwapWIR(SearchMethod):
            else:
                continue
            # If we succeeded, return the index with best similarity.
-            if cur_result.succeeded:
+            if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
                best_result = cur_result
                # @TODO: Use vectorwise operations
                max_similarity = -float("inf")
                for result in results:
-                    if not result.succeeded:
+                    if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
                        break
                    candidate = result.attacked_text
                    try:
@@ -149,9 +148,9 @@ class GreedyWordSwapWIR(SearchMethod):
    def check_transformation_compatibility(self, transformation):
        """
-        Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swap transformations.
+        Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swap and deletion transformations.
        """
-        return transformation_consists_of_word_swaps(transformation)
+        return transformation_consists_of_word_swaps_and_deletions(transformation)

    def extra_repr_keys(self):
        return ["wir_method"]


@@ -22,6 +22,10 @@ class SearchMethod(ABC):
            raise AttributeError(
                "Search Method must have access to get_goal_results method"
            )
+        if not hasattr(self, "filter_transformations"):
+            raise AttributeError(
+                "Search Method must have access to filter_transformations method"
+            )
        return self._perform_search(initial_result)

    @abstractmethod
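
These `hasattr` checks enforce a duck-typed contract: the `Attack` object injects `get_transformations`, `get_goal_results`, and (new in this change) `filter_transformations` onto the search method before `__call__` runs. A toy sketch of a custom search method written against that contract (the class and its strategy are hypothetical, not part of this diff):

from textattack.search_methods import SearchMethod

class FirstImprovementSearch(SearchMethod):
    """ Toy example: take the best candidate from one round of transformations,
    if it improves on the initial score. """

    def _perform_search(self, initial_result):
        # get_transformations / get_goal_results / filter_transformations are
        # attached by Attack before this method is ever called.
        candidates = self.get_transformations(
            initial_result.attacked_text, original_text=initial_result.attacked_text
        )
        if not candidates:
            return initial_result
        results, _ = self.get_goal_results(candidates)
        best = max(results, key=lambda r: r.score)
        return best if best.score > initial_result.score else initial_result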


@@ -7,9 +7,11 @@ import numpy as np
 import textattack
 from textattack.attack_results import (
     FailedAttackResult,
+    MaximizedAttackResult,
     SkippedAttackResult,
     SuccessfulAttackResult,
 )
+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.shared import AttackedText, utils
@@ -56,7 +58,7 @@ class Attack:
            self.transformation
        ):
            raise ValueError(
-                "SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
+                f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
            )
        self.constraints = []
@@ -74,7 +76,12 @@ class Attack:
        # Give search method access to functions for getting transformations and evaluating them
        self.search_method.get_transformations = self.get_transformations
-        self.search_method.get_goal_results = self.goal_function.get_results
+        # The search method only needs access to the first argument. The second is only used
+        # by the attack class when checking whether to skip the sample
+        self.search_method.get_goal_results = lambda attacked_text_list: self.goal_function.get_results(
+            attacked_text_list
+        )
+        self.search_method.filter_transformations = self.filter_transformations

    def get_transformations(self, current_text, original_text=None, **kwargs):
        """
@@ -102,7 +109,7 @@ class Attack:
                **kwargs,
            )
        )
-        return self._filter_transformations(
+        return self.filter_transformations(
            transformed_texts, current_text, original_text
        )
@@ -138,7 +145,7 @@ class Attack:
                self.constraints_cache[(current_text, filtered_text)] = True
        return filtered_texts

-    def _filter_transformations(
+    def filter_transformations(
        self, transformed_texts, current_text, original_text=None
    ):
        """
@@ -180,17 +187,18 @@ class Attack:
            initial_result: The initial ``GoalFunctionResult`` from which to perturb.
        Returns:
-            Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
+            A ``SuccessfulAttackResult``, ``FailedAttackResult``,
+            or ``MaximizedAttackResult``.
        """
        final_result = self.search_method(initial_result)
-        if final_result.succeeded:
-            return SuccessfulAttackResult(
-                initial_result, final_result, self.goal_function.num_queries
-            )
+        if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
+            return SuccessfulAttackResult(initial_result, final_result,)
+        elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
+            return FailedAttackResult(initial_result, final_result,)
+        elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
+            return MaximizedAttackResult(initial_result, final_result,)
        else:
-            return FailedAttackResult(
-                initial_result, final_result, self.goal_function.num_queries
-            )
+            raise ValueError(f"Unrecognized goal status {final_result.goal_status}")

    def _get_examples_from_dataset(self, dataset, indices=None):
        """
@@ -222,14 +230,9 @@ class Attack:
                attacked_text = AttackedText(
                    text, attack_attrs={"label_names": label_names}
                )
-                self.goal_function.num_queries = 0
-                goal_function_result, _ = self.goal_function.get_result(
+                goal_function_result, _ = self.goal_function.init_attack_example(
                    attacked_text, ground_truth_output
                )
-                if goal_function_result.succeeded:
-                    # Store the true output on the goal function so that the
-                    # SkippedAttackResult has the correct output, not the incorrect.
-                    goal_function_result.output = ground_truth_output
                yield goal_function_result

            except IndexError:
@@ -250,7 +253,7 @@ class Attack:
        examples = self._get_examples_from_dataset(dataset, indices=indices)

        for goal_function_result in examples:
-            if goal_function_result.succeeded:
+            if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
                yield SkippedAttackResult(goal_function_result)
            else:
                result = self.attack_one(goal_function_result)


@@ -42,9 +42,11 @@ class AttackedText:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
+        # Find words in input lazily.
+        self._words = None
+        self._words_per_input = None
        # Format text inputs.
        self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
-        self.words = words_from_text(self.text)
        if attack_attrs is None:
            self.attack_attrs = dict()
        elif isinstance(attack_attrs, dict):
@@ -53,7 +55,7 @@ class AttackedText:
            raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}")
        # Indices of words from the *original* text. Allows us to map
        # indices between original text and this text, and vice-versa.
-        self.attack_attrs.setdefault("original_index_map", np.arange(len(self.words)))
+        self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words))
        # A list of all indices in *this* text that have been modified.
        self.attack_attrs.setdefault("modified_indices", set())
@@ -97,7 +99,7 @@ class AttackedText:
    def text_window_around_index(self, index, window_size):
        """ The text window of ``window_size`` words centered around ``index``. """
-        length = len(self.words)
+        length = self.num_words
        half_size = (window_size - 1) / 2.0
        if index - half_size < 0:
            start = 0
@@ -177,7 +179,7 @@ class AttackedText:
        """ Takes indices of words from original string and converts them to
        indices of the same words in the current string.
-        Uses information from ``self.attack_attrs['original_index_map'],
+        Uses information from ``self.attack_attrs['original_index_map']``,
        which maps word indices from the original to perturbed text.
        """
        if len(self.attack_attrs["original_index_map"]) == 0:
@@ -344,6 +346,29 @@ class AttackedText:
        """ The tuple of inputs to be passed to the tokenizer. """
        return tuple(self._text_input.values())

+    @property
+    def column_labels(self):
+        """ Returns the labels for this text's columns. For single-sequence
+        inputs, this simply returns ['text'].
+        """
+        return list(self._text_input.keys())
+
+    @property
+    def words_per_input(self):
+        """ Returns a list of lists of words corresponding to each input. """
+        if not self._words_per_input:
+            self._words_per_input = [
+                words_from_text(_input) for _input in self._text_input.values()
+            ]
+        return self._words_per_input
+
+    @property
+    def words(self):
+        if not self._words:
+            self._words = words_from_text(self.text)
+        return self._words
+
    @property
    def text(self):
        """ Represents full text input. Multiple inputs are joined with a line
@@ -351,6 +376,11 @@ class AttackedText:
        """
        return "\n".join(self._text_input.values())

+    @property
+    def num_words(self):
+        """ Returns the number of words in the sequence. """
+        return len(self.words)
+
    def printable_text(self, key_color="bold", key_color_method=None):
        """ Represents full text input. Adds field descriptions.


@@ -6,6 +6,7 @@ import time
 from textattack.attack_results import (
     FailedAttackResult,
+    MaximizedAttackResult,
     SkippedAttackResult,
     SuccessfulAttackResult,
 )
@@ -101,6 +102,11 @@ class Checkpoint:
                f"(Number of failed attacks): {self.num_failed_attacks}", 2
            )
        )
+        breakdown_lines.append(
+            utils.add_indent(
+                f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
+            )
+        )
        breakdown_lines.append(
            utils.add_indent(
                f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
@@ -140,6 +146,12 @@ class Checkpoint:
            isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
        )

+    @property
+    def num_maximized_attacks(self):
+        return sum(
+            isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
+        )
+
    @property
    def num_remaining_attacks(self):
        if self.args.attack_n:


@@ -11,6 +11,9 @@ import requests
 import torch
 import tqdm

+# Hide an error message from `tokenizers` if this process is forked.
+os.environ["TOKENIZERS_PARALLELISM"] = "True"
+
 def path_in_cache(file_path):
     try:

Some files were not shown because too many files have changed in this diff.