diff --git a/.github/workflows/check-formatting.yml b/.github/workflows/check-formatting.yml new file mode 100644 index 00000000..f62d5caa --- /dev/null +++ b/.github/workflows/check-formatting.yml @@ -0,0 +1,34 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Formatting with black & isort + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + pip install black flake8 isort # Testing packages + python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 + pip install -e . + - name: Check code format with black and isort + run: | + make lint diff --git a/.github/workflows/make-docs.yml b/.github/workflows/make-docs.yml new file mode 100644 index 00000000..67f7c63f --- /dev/null +++ b/.github/workflows/make-docs.yml @@ -0,0 +1,37 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Build documentation with Sphinx + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo sed -i 's/azure\.//' /etc/apt/sources.list # workaround for flaky pandoc install + sudo apt-get update # from here https://github.com/actions/virtual-environments/issues/675 + sudo apt-get install pandoc -o Acquire::Retries=3 # install pandoc + python -m pip install --upgrade pip setuptools wheel # update python + pip install ipython --upgrade # needed for Github for whatever reason + python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 + pip install -e ".[dev]" # This should install all packages for development
".[dev]" # This should install all packages for development + - name: Build docs with Sphinx and check for errors + run: | + sphinx-build -b html docs docs/_build/html -W diff --git a/.github/workflows/python-publish.yml b/.github/workflows/publish-to-pypi.yml similarity index 95% rename from .github/workflows/python-publish.yml rename to .github/workflows/publish-to-pypi.yml index 7e958c6e..fe037a12 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -1,7 +1,7 @@ # This workflows will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries -name: Upload Python Package +name: Upload Python Package to PyPI on: release: diff --git a/.github/workflows/python-test.yml b/.github/workflows/run-pytest.yml similarity index 81% rename from .github/workflows/python-test.yml rename to .github/workflows/run-pytest.yml index 448bc1f3..ee16e7c7 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/run-pytest.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Github PyTest +name: Test with PyTest on: push: @@ -26,13 +26,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel - pip install black isort pytest pytest-xdist + pip install pytest pytest-xdist # Testing packages python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537 pip install -e . - - name: Check code format with black and isort - run: | - black . --check - isort --check-only --recursive tests textattack - name: Test with pytest run: | pytest tests -vx --dist=loadfile -n auto diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 858ef797..7eed7873 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -179,11 +179,25 @@ Follow these steps to start contributing: $ git push -u origin a-descriptive-name-for-my-changes ``` -6. Once you are satisfied (**and the checklist below is happy too**), go to the +6. Add documentation. + + Our docs are in the `docs/` folder. Thanks to `sphinx-automodule`, this + should just be two lines. Our docs will automatically generate from the + comments you added to your code. If you're adding an attack recipe, add a + reference in `attack_recipes.rst`. If you're adding a transformation, add + a reference in `transformation.rst`, etc. + + You can build the docs and view the updates using `make docs`. If you're + adding a tutorial or something where you want to update the docs multiple + times, you can run `make docs-auto`. This will run a server using + `sphinx-autobuild` that should automatically reload whenever you change + a file. + +7. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. -7. It's ok if maintainers ask you for changes. It happens to core contributors +8. It's ok if maintainers ask you for changes. It happens to core contributors too! So everyone can see the changes in the Pull request, work in your local branch and push the changes to your fork. They will automatically appear in the pull request. 
diff --git a/Makefile b/Makefile index f22a6848..dcee03c2 100644 --- a/Makefile +++ b/Makefile @@ -3,9 +3,10 @@ format: FORCE ## Run black and isort (rewriting files) isort --atomic --recursive tests textattack -lint: FORCE ## Run black (in check mode) +lint: FORCE ## Run black, isort, flake8 (in check mode) black . --check isort --check-only --recursive tests textattack + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=./.*,build,dist # catch certain syntax errors using flake8 test: FORCE ## Run tests using pytest python -m pytest --dist=loadfile -n auto @@ -13,10 +14,13 @@ test: FORCE ## Run tests using pytest docs: FORCE ## Build docs using Sphinx. sphinx-build -b html docs docs/_build/html +docs-check: FORCE ## Build docs using Sphinx. If there is an error, exit with an error code (instead of warning & continuing). + sphinx-build -b html docs docs/_build/html -W + docs-auto: FORCE ## Build docs using Sphinx and run hotreload server using Sphinx autobuild. sphinx-autobuild docs docs/_build/html -H 0.0.0.0 -p 8765 -all: format lint test ## Format, lint, and test. +all: format lint docs-check test ## Format, lint, check docs, and test. .PHONY: help diff --git a/README.md b/README.md index fe78536f..6a38fa36 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@

- +TextAttack Demo GIF ## About @@ -97,18 +97,19 @@ We include attack recipes which implement attacks from the literature. You can l To run an attack recipe: `textattack attack --recipe [recipe_name]` -These attacks are for classification tasks, like sentiment classification and entailment: +Attacks on classification tasks, like sentiment classification and entailment: - **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)). - **bae**: BERT masked language model transformation attack from (["BAE: BERT-based Adversarial Examples for Text Classification" (Garg & Ramakrishnan, 2019)](https://arxiv.org/abs/2004.01970)). - **bert-attack**: BERT masked language model transformation attack with subword replacements (["BERT-ATTACK: Adversarial Attack Against BERT Using BERT" (Li et al., 2020)](https://arxiv.org/abs/2004.09984)). - **deepwordbug**: Greedy replace-1 scoring and multi-transformation character-swap attack (["Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers" (Gao et al., 2018)](https://arxiv.org/abs/1801.04354)). - **hotflip**: Beam search and gradient-based word swap (["HotFlip: White-Box Adversarial Examples for Text Classification" (Ebrahimi et al., 2017)](https://arxiv.org/abs/1712.06751)). +- **input-reduction**: Reducing the input while maintaining the prediction through word importance ranking (["Pathologies of Neural Models Make Interpretations Difficult" (Feng et al., 2018)](https://arxiv.org/pdf/1804.07781.pdf)). - **kuleshov**: Greedy search and counterfitted embedding swap (["Adversarial Examples for Natural Language Classification Problems" (Kuleshov et al., 2018)](https://openreview.net/pdf?id=r1QZ3zbAZ)). - **pwws**: Greedy attack with word importance ranking based on word saliency and synonym swap scores (["Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency" (Ren et al., 2019)](https://www.aclweb.org/anthology/P19-1103/)). - **textbugger**: Greedy attack with word importance ranking and character-based swaps (["TextBugger: Generating Adversarial Text Against Real-world Applications" (Li et al., 2018)](https://arxiv.org/abs/1812.05271)). - **textfooler**: Greedy attack with word importance ranking and counter-fitted embedding swap (["Is Bert Really Robust?" (Jin et al., 2019)](https://arxiv.org/abs/1907.11932)). -The final is for sequence-to-sequence models: +Attacks on sequence-to-sequence models: - **seq2sick**: Greedy attack with goal of changing every word in the output translation. Currently implemented as black-box with plans to change to white-box as done in paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples" (Cheng et al., 2018)](https://arxiv.org/abs/1803.01128)). #### Recipe Usage Examples @@ -122,7 +123,7 @@ textattack attack --model bert-base-uncased-sst2 --recipe textfooler --num-examp *seq2sick (black-box) against T5 fine-tuned for English-German translation:* ```bash -textattack attack --recipe seq2sick --model t5-en2de --num-examples 100 +textattack attack --model t5-en-de --recipe seq2sick --num-examples 100 ``` ### Augmenting Text @@ -284,7 +285,7 @@ The `attack_one` method in an `Attack` takes as input an `AttackedText`, and out ### Goal Functions -A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`.
+A `GoalFunction` takes as input an `AttackedText` object, scores it, and determines whether the attack has succeeded, returning a `GoalFunctionResult`. ### Constraints @@ -303,6 +304,8 @@ A `SearchMethod` takes as input an initial `GoalFunctionResult` and returns a fi We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. TextAttack is currently in an "alpha" stage in which we are working to improve its capabilities and design. +See [CONTRIBUTING.md](https://github.com/QData/TextAttack/blob/master/CONTRIBUTING.md) for detailed information on contributing. + ## Citing TextAttack If you use TextAttack for your research, please cite [TextAttack: A Framework for Adversarial Attacks in Natural Language Processing](https://arxiv.org/abs/2005.05909). diff --git a/docs/attacks/attack_recipes.rst b/docs/attacks/attack_recipes.rst index f53a349c..e1499532 100644 --- a/docs/attacks/attack_recipes.rst +++ b/docs/attacks/attack_recipes.rst @@ -6,67 +6,81 @@ We provide a number of pre-built attack recipes. To run an attack recipe, run:: textattack attack --recipe [recipe_name] Alzantot Genetic Algorithm (Generating Natural Language Adversarial Examples) -########### +################################################################################### .. automodule:: textattack.attack_recipes.genetic_algorithm_alzantot_2018 :members: Faster Alzantot Genetic Algorithm (Certified Robustness to Adversarial Word Substitutions) -########### +############################################################################################## .. automodule:: textattack.attack_recipes.faster_genetic_algorithm_jia_2019 :members: BAE (BAE: BERT-Based Adversarial Examples) -############ +############################################# -.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018 +.. automodule:: textattack.attack_recipes.bae_garg_2019 + :members: BERT-Attack: (BERT-Attack: Adversarial Attack Against BERT Using BERT) -############ +######################################################################### -.. automodule:: textattack.attack_recipes.deepwordbug_gao_2018 +.. automodule:: textattack.attack_recipes.bert_attack_li_2020 + :members: DeepWordBug (Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers) -############ +###################################################################################################### .. automodule:: textattack.attack_recipes.deepwordbug_gao_2018 :members: HotFlip (HotFlip: White-Box Adversarial Examples for Text Classification) -########### +############################################################################## + +.. automodule:: textattack.attack_recipes.hotflip_ebrahimi_2017 + :members: + +Input Reduction +################ .. automodule:: textattack.attack_recipes.input_reduction_feng_2018 :members: Kuleshov (Adversarial Examples for Natural Language Classification Problems) -########### +############################################################################## .. automodule:: textattack.attack_recipes.kuleshov_2017 :members: +Particle Swarm Optimization (Word-level Textual Adversarial Attacking as Combinatorial Optimization) +##################################################################################################### + +.. 
automodule:: textattack.attack_recipes.PSO_zang_2020 + :members: + PWWS (Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency) -########### +################################################################################################### .. automodule:: textattack.attack_recipes.pwws_ren_2019 :members: Seq2Sick (Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples) -########### +######################################################################################################### .. automodule:: textattack.attack_recipes.seq2sick_cheng_2018_blackbox :members: TextFooler (Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment) -########### +######################################################################################################################## .. automodule:: textattack.attack_recipes.textfooler_jin_2019 :members: TextBugger (TextBugger: Generating Adversarial Text Against Real-world Applications) -########### +######################################################################################## .. automodule:: textattack.attack_recipes.textbugger_li_2018 :members: diff --git a/docs/attacks/constraint.rst b/docs/attacks/constraint.rst index 9af35f0d..db11f159 100644 --- a/docs/attacks/constraint.rst +++ b/docs/attacks/constraint.rst @@ -85,7 +85,7 @@ GPT-2 :members: "Learning To Write" Language Model -******* +************************************ .. automodule:: textattack.constraints.grammaticality.language_models.learning_to_write.learning_to_write :members: @@ -142,7 +142,7 @@ Maximum Words Perturbed .. _pre_transformation: Pre-Transformation ----------- +------------------------- Pre-transformation constraints determine if a transformation is valid based on only the original input and the position of the replacement. These constraints @@ -151,7 +151,7 @@ constraints can prevent search methods from swapping words at the same index twice, or from replacing stopwords. Pre-Transformation Constraint -######################## +############################### .. automodule:: textattack.constraints.pre_transformation.pre_transformation_constraint :special-members: __call__ :private-members: @@ -166,3 +166,13 @@ Repeat Modification ######################## .. automodule:: textattack.constraints.pre_transformation.repeat_modification :members: + +Input Column Modification +############################# +.. automodule:: textattack.constraints.pre_transformation.input_column_modification + :members: + +Max Word Index Modification +############################### +.. automodule:: textattack.constraints.pre_transformation.max_word_index_modification + :members: diff --git a/docs/attacks/transformation.rst b/docs/attacks/transformation.rst index d37f205a..b61a8eab 100644 --- a/docs/attacks/transformation.rst +++ b/docs/attacks/transformation.rst @@ -69,7 +69,7 @@ Word Swap by Random Character Insertion :members: Word Swap by Random Character Substitution ---------------------------------------- +------------------------------------------- .. automodule:: textattack.transformations.word_swap_random_character_substitution :members: diff --git a/docs/conf.py b/docs/conf.py index feb823ca..383855af 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab" author = "UVA QData Lab" # The full version, including alpha/beta/rc tags -release = "0.1.2" +release = "0.1.5" # Set master doc to `index.rst`. 
master_doc = "index" diff --git a/docs/datasets_models/datasets.rst b/docs/datasets_models/datasets.rst index bcd5b407..158d87d7 100644 --- a/docs/datasets_models/datasets.rst +++ b/docs/datasets_models/datasets.rst @@ -6,19 +6,10 @@ Datasets :members: :private-members: -Classification -############### -.. automodule:: textattack.datasets.classification.classification_dataset +.. automodule:: textattack.datasets.huggingface_nlp_dataset :members: -Entailment -############ -.. automodule:: textattack.datasets.entailment.entailment_dataset +.. automodule:: textattack.datasets.translation.ted_multi :members: -Translation -############# -.. automodule:: textattack.datasets.translation.translation_datasets - :members: - diff --git a/docs/datasets_models/models.rst b/docs/datasets_models/models.rst index 15bbeb45..e3069415 100644 --- a/docs/datasets_models/models.rst +++ b/docs/datasets_models/models.rst @@ -11,7 +11,7 @@ We split models up into two broad categories: **Classification models:** - :ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using transformers_. + :ref:`BERT`: ``bert-base-uncased`` fine-tuned on various datasets using ``transformers``. :ref:`LSTM`: a standard LSTM fine-tuned on various datasets. @@ -20,30 +20,29 @@ We split models up into two broad categories: **Text-to-text models:** - :ref:`T5`: ``T5`` fine-tuned on various datasets using transformers_. + :ref:`T5`: ``T5`` fine-tuned on various datasets using ``transformers``. +.. _BERT: BERT ******** -.. _BERT: - .. automodule:: textattack.models.helpers.bert_for_classification :members: -LSTM -******* .. _LSTM: +LSTM +******* .. automodule:: textattack.models.helpers.lstm_for_classification :members: -Word-CNN -************ .. _CNN: +Word-CNN +************ .. automodule:: textattack.models.helpers.word_cnn_for_classification :members: diff --git a/docs/examples/1_Introduction_and_Transformations.ipynb b/docs/examples/1_Introduction_and_Transformations.ipynb index f39fb9b4..fb37db23 100644 --- a/docs/examples/1_Introduction_and_Transformations.ipynb +++ b/docs/examples/1_Introduction_and_Transformations.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# The TextAttack🐙 ecosystem: search, transformations, and constraints\n", + "# The TextAttack ecosystem: search, transformations, and constraints\n", "\n", "An attack in TextAttack consists of four parts.\n", "\n", @@ -31,9 +31,9 @@ "This lesson explains how to create a custom transformation. In TextAttack, many transformations involve *word swaps*: they take a word and try and find suitable substitutes. Some attacks focus on replacing characters with neighboring characters to create \"typos\" (these don't intend to preserve the grammaticality of inputs). Other attacks rely on semantics: they take a word and try to replace it with semantic equivalents.\n", "\n", "\n", - "### Banana word swap 🍌\n", + "### Banana word swap \n", "\n", - "As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'." 
+ "As an introduction to writing transformations for TextAttack, we're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract `WordSwap` class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend `WordSwap` and implement a single method, `_get_replacement_words`, to indicate to replace each word with 'banana'. 🍌" ] }, { @@ -308,9 +308,9 @@ "collapsed": true }, "source": [ - "### Conclusion 🍌\n", + "### Conclusion n", "\n", - "We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a single \"banana\", while others needed up to 17 \"banana\" substitutions to change the class score. Wow!" + "We can examine these examples for a good idea of how many words had to be changed to \"banana\" to change the prediction score from the correct class to another class. The examples without perturbed words were originally misclassified, so they were skipped by the attack. Looks like some examples needed only a couple \"banana\"s, while others needed up to 17 \"banana\" substitutions to change the class score. Wow! 🍌" ] } ], diff --git a/docs/index.rst b/docs/index.rst index 52f88a92..c2abc5f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,7 +35,6 @@ TextAttack has some other features that make it a pleasure to use: Installation - Overview Command-Line Usage Tutorial 0: TextAttack End-To-End (Train, Eval, Attack) Tutorial 1: Transformations @@ -76,7 +75,7 @@ TextAttack has some other features that make it a pleasure to use: :hidden: :caption: Miscellaneous + misc/attacked_text misc/checkpoints misc/loggers misc/validators - misc/tokenized_text diff --git a/docs/misc/attacked_text.rst b/docs/misc/attacked_text.rst new file mode 100644 index 00000000..acfbadca --- /dev/null +++ b/docs/misc/attacked_text.rst @@ -0,0 +1,6 @@ +=================== +Attacked Text +=================== + +.. automodule:: textattack.shared.attacked_text + :members: diff --git a/docs/misc/tokenized_text.rst b/docs/misc/tokenized_text.rst deleted file mode 100644 index b39a4da5..00000000 --- a/docs/misc/tokenized_text.rst +++ /dev/null @@ -1,6 +0,0 @@ -=================== -Tokenized Text -=================== - -.. automodule:: textattack.shared.tokenized_text - :members: \ No newline at end of file diff --git a/docs/quickstart/command_line_usage.md b/docs/quickstart/command_line_usage.md index b294813b..3edd17e1 100644 --- a/docs/quickstart/command_line_usage.md +++ b/docs/quickstart/command_line_usage.md @@ -22,7 +22,7 @@ examples corresponding to the proper columns. For example, given the following as `examples.csv`: -```csv +``` "text",label "the rock is destined to be the 21st century's new conan and that he's going to make a splash even greater than arnold schwarzenegger , jean- claud van damme or steven segal.", 1 "the gorgeously elaborate continuation of 'the lord of the rings' trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth .", 1 @@ -40,7 +40,7 @@ will augment the `text` column with four swaps per augmentation, twice as many a output CSV. (All of this will be saved to `augment.csv` by default.) 
After augmentation, here are the contents of `augment.csv`: -```csv +``` text,label "the rock is destined to be the 21st century's newest conan and that he's gonna to make a splashing even stronger than arnold schwarzenegger , jean- claud van damme or steven segal.",1 "the rock is destined to be the 21tk century's novel conan and that he's going to make a splat even greater than arnold schwarzenegger , jean- claud van damme or stevens segal.",1 @@ -132,4 +132,4 @@ see some basic information about the dataset. For example, use `textattack peek-dataset --dataset-from-nlp glue:mrpc` to see information about the MRPC dataset (from the GLUE set of datasets). This will -print statistics like the number of labels, average number of words, etc. \ No newline at end of file +print statistics like the number of labels, average number of words, etc. diff --git a/requirements.txt b/requirements.txt index cc9a66e1..6ce04467 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ numpy pandas>=1.0.1 scikit-learn scipy==1.4.1 -sentence_transformers +sentence_transformers==0.2.6.1 torch transformers>=3 tensorflow>=2 diff --git a/setup.cfg b/setup.cfg index 66bc6eb5..66905bf6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,3 @@ -[flake8] -ignore = E203, E266, E501, W503 -max-line-length = 120 -per-file-ignores = __init__.py:F401 -mypy_config = mypy.ini - [isort] line_length = 88 skip = __init__.py @@ -14,3 +8,11 @@ multi_line_output = 3 include_trailing_comma = True use_parentheses = True force_grid_wrap = 0 + +[flake8] +exclude = .git,__pycache__,wandb,build,dist +ignore = E203, E266, E501, W503, D203 +max-complexity = 10 +max-line-length = 120 +mypy_config = mypy.ini +per-file-ignores = __init__.py:F401 diff --git a/setup.py b/setup.py index 748cfe9b..6a5ecdaa 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,12 @@ with open("README.md", "r") as fh: long_description = fh.read() extras = {} +# Packages required for installing docs. +extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"] +# Packages required for formatting code & running tests. +extras["test"] = ["black", "isort", "flake8", "pytest", "pytest-xdist"] # For developers, install development tools along with all optional dependencies. 
-extras["dev"] = ["black", "isort", "pytest", "pytest-xdist"] - +extras["dev"] = extras["docs"] + extras["test"] setuptools.setup( name="textattack", @@ -27,9 +30,9 @@ setuptools.setup( "build*", "docs*", "dist*", + "examples*", "outputs*", "tests*", - "local_test*", "wandb*", ] ), diff --git a/tests/sample_outputs/interactive_mode.txt b/tests/sample_outputs/interactive_mode.txt index d41e8900..e8709933 100644 --- a/tests/sample_outputs/interactive_mode.txt +++ b/tests/sample_outputs/interactive_mode.txt @@ -28,6 +28,10 @@ ) (3): RepeatModification (4): StopwordModification + (5): InputColumnModification( + (matching_column_labels): ['premise', 'hypothesis'] + (columns_to_ignore): {'premise'} + ) (is_black_box): True ) /.*/ diff --git a/tests/sample_outputs/kuleshov_cnn_sst_2.txt b/tests/sample_outputs/kuleshov_cnn_sst_2.txt new file mode 100644 index 00000000..70027a72 --- /dev/null +++ b/tests/sample_outputs/kuleshov_cnn_sst_2.txt @@ -0,0 +1,57 @@ +/.*/Attack( + (search_method): GreedySearch + (goal_function): UntargetedClassification + (transformation): WordSwapEmbedding( + (max_candidates): 15 + (embedding_type): paragramcf + ) + (constraints): + (0): MaxWordsPerturbed( + (max_percent): 0.5 + ) + (1): ThoughtVector( + (embedding_type): paragramcf + (metric): max_euclidean + (threshold): -0.2 + (compare_with_original): False + (window_size): inf + (skip_text_shorter_than_window): False + ) + (2): GPT2( + (max_log_prob_diff): 2.0 + ) + (3): RepeatModification + (4): StopwordModification + (is_black_box): True +) +/.*/ +--------------------------------------------- Result 1 --------------------------------------------- +Positive (100%) --> Negative (69%) + +it 's a charming and often affecting journey . + +it 's a loveable and ordinarily affecting journey . + + +--------------------------------------------- Result 2 --------------------------------------------- +Negative (83%) --> Positive (90%) + +unflinchingly bleak and desperate + +unflinchingly bleak and desperation + + + ++-------------------------------+--------+ +| Attack Results | | ++-------------------------------+--------+ +| Number of successful attacks: | 2 | +| Number of failed attacks: | 0 | +| Number of skipped attacks: | 0 | +| Original accuracy: | 100.0% | +| Accuracy under attack: | 0.0% | +| Attack success rate: | 100.0% | +| Average perturbed word %: | 25.0% | +| Average num. words per input: | 6.0 | +| Avg num queries: | 48.5 | ++-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt b/tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt index c600d859..8f9a18c2 100644 --- a/tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt +++ b/tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt @@ -39,7 +39,7 @@ | Original accuracy: | 100.0% | | Accuracy under attack: | 0.0% | | Attack success rate: | 100.0% | -| Average perturbed word %: | 45.39% | -| Average num. words per input: | 11.5 | -| Avg num queries: | 26.5 | +| Average perturbed word %: | 45.0% | +| Average num. 
words per input: | 12.0 | +| Avg num queries: | 27.0 | +-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt b/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt new file mode 100644 index 00000000..cc7ab7c8 --- /dev/null +++ b/tests/sample_outputs/run_attack_faster_alzantot_recipe.txt @@ -0,0 +1,66 @@ +/.*/Attack( + (search_method): GeneticAlgorithm( + (pop_size): 60 + (max_iters): 20 + (temp): 0.3 + (give_up_if_no_improvement): False + ) + (goal_function): UntargetedClassification + (transformation): WordSwapEmbedding( + (max_candidates): 8 + (embedding_type): paragramcf + ) + (constraints): + (0): MaxWordsPerturbed( + (max_percent): 0.2 + ) + (1): WordEmbeddingDistance( + (embedding_type): paragramcf + (max_mse_dist): 0.5 + (cased): False + (include_unknown_words): True + ) + (2): LearningToWriteLanguageModel( + (max_log_prob_diff): 5.0 + ) + (3): RepeatModification + (4): StopwordModification + (is_black_box): True +) +/.*/ +--------------------------------------------- Result 1 --------------------------------------------- +Positive (100%) --> Negative (73%) + +this kind of hands-on storytelling is ultimately what makes shanghai ghetto move beyond a good , dry , reliable textbook and what allows it to rank with its worthy predecessors . + +this kind of hands-on tale is ultimately what makes shanghai ghetto move beyond a good , secs , credible textbook and what allows it to rank with its worthy predecessors . + + +--------------------------------------------- Result 2 --------------------------------------------- +Positive (80%) --> Negative (97%) + +making such a tragedy the backdrop to a love story risks trivializing it , though chouraqui no doubt intended the film to affirm love's power to help people endure almost unimaginable horror . + +making such a tragedy the backdrop to a love story risks trivializing it , notwithstanding chouraqui no doubt intended the film to affirm love's power to help people endure almost incomprehensible horror . + + +--------------------------------------------- Result 3 --------------------------------------------- +Positive (92%) --> [FAILED] + +grown-up quibbles are beside the point here . the little girls understand , and mccracken knows that's all that matters . + + + ++-------------------------------+--------+ +| Attack Results | | ++-------------------------------+--------+ +| Number of successful attacks: | 2 | +| Number of failed attacks: | 1 | +| Number of skipped attacks: | 0 | +| Original accuracy: | 100.0% | +| Accuracy under attack: | 33.33% | +| Attack success rate: | 66.67% | +| Average perturbed word %: | 8.58% | +| Average num. words per input: | 25.67 | +| Avg num queries: |/.*/| ++-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt b/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt index 49791b8e..7baf6ec1 100644 --- a/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt +++ b/tests/sample_outputs/run_attack_flair_pos_tagger_bert_score.txt @@ -24,11 +24,11 @@ ) /.*/ --------------------------------------------- Result 1 --------------------------------------------- -Positive (100%) --> Negative (88%) +Positive (100%) --> Negative (98%) -exposing the ways we fool ourselves is one hour photo's real strength . +exposing the ways we fool ourselves is one hour photo's real strength . -exposing the ways we fool ourselves is one hour pictures's real kraft . 
+exposing the ways we fool ourselves is one stopwatch photo's real kraft . --------------------------------------------- Result 2 --------------------------------------------- @@ -65,7 +65,7 @@ mostly , [goldbacher] just lets her complicated characters be haphazard | Original accuracy: | 100.0% | | Accuracy under attack: | 0.0% | | Attack success rate: | 100.0% | -| Average perturbed word %: | 17.13% | -| Average num. words per input: | 17.0 | -| Avg num queries: | 46.0 | +| Average perturbed word %: | 17.56% | +| Average num. words per input: | 16.25 | +| Avg num queries: | 45.5 | +-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_from_file.txt b/tests/sample_outputs/run_attack_from_file.txt index 94c6c7a0..3ea17bbb 100644 --- a/tests/sample_outputs/run_attack_from_file.txt +++ b/tests/sample_outputs/run_attack_from_file.txt @@ -35,6 +35,6 @@ I went into "Night of the Hunted" not knowing what to expect at all. I was reall | Accuracy under attack: | 0.0% | | Attack success rate: | 100.0% | | Average perturbed word %: | 0.62% | -| Average num. words per input: | 165.0 | -| Avg num queries: | 167.0 | +| Average num. words per input: | 164.0 | +| Avg num queries: | 166.0 | +-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt index 2c1a8a82..2c232b3e 100644 --- a/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt +++ b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt @@ -27,11 +27,9 @@ ) /.*/ --------------------------------------------- Result 1 --------------------------------------------- -Positive (97%) --> Negative (100%) +Positive (97%) --> [FAILED] -the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill . - -the story gives ample opportunity for large-scale action and suspense , which director shekhar unwilling supplies with tremendous skill . +the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur supplies with tremendous skill . --------------------------------------------- Result 2 --------------------------------------------- @@ -58,13 +56,13 @@ throws in enough clever and unexpected twists to make the formula feel fresh . +-------------------------------+--------+ | Attack Results | | +-------------------------------+--------+ -| Number of successful attacks: | 2 | -| Number of failed attacks: | 2 | +| Number of successful attacks: | 1 | +| Number of failed attacks: | 3 | | Number of skipped attacks: | 0 | | Original accuracy: | 100.0% | -| Accuracy under attack: | 50.0% | -| Attack success rate: | 50.0% | -| Average perturbed word %: | 4.55% | -| Average num. words per input: | 15.75 | -| Avg num queries: | 1.5 | +| Accuracy under attack: | 75.0% | +| Attack success rate: | 25.0% | +| Average perturbed word %: | 3.85% | +| Average num. 
words per input: | 15.5 | +| Avg num queries: | 1.25 | +-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt b/tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt index e90e3c4f..e7f3f657 100644 --- a/tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt +++ b/tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n.txt @@ -23,10 +23,13 @@ --------------------------------------------- Result 2 --------------------------------------------- -Neutral (100%) --> [FAILED] +Neutral (100%) --> Entailment (56%) Premise: This site includes a list of all award winners and a searchable database of Government Executive articles. -Hypothesis: The Government Executive articles housed on the website are not able to be searched. +Hypothesis: The Government Executive articles housed on the website are not able to be searched. + +Premise: This site includes a list of all award winners and a searchable database of Government Executive articles. +Hypothesis: The Government Executive articles housed on the website are not able-bodied to be searched. --------------------------------------------- Result 3 --------------------------------------------- @@ -43,13 +46,13 @@ +-------------------------------+--------+ | Attack Results | | +-------------------------------+--------+ -| Number of successful attacks: | 1 | -| Number of failed attacks: | 1 | +| Number of successful attacks: | 2 | +| Number of failed attacks: | 0 | | Number of skipped attacks: | 1 | | Original accuracy: | 66.67% | -| Accuracy under attack: | 33.33% | -| Attack success rate: | 50.0% | -| Average perturbed word %: | 2.27% | -| Average num. words per input: | 29.0 | -| Avg num queries: | 447.5 | +| Accuracy under attack: | 0.0% | +| Attack success rate: | 100.0% | +| Average perturbed word %: | 2.78% | +| Average num. words per input: | 28.67 | +| Avg num queries: | 182.0 | +-------------------------------+--------+ diff --git a/tests/test_attacked_text.py b/tests/test_attacked_text.py index f249bd1a..d7bb4498 100644 --- a/tests/test_attacked_text.py +++ b/tests/test_attacked_text.py @@ -12,12 +12,27 @@ def attacked_text(): return textattack.shared.AttackedText(raw_text) +raw_pokemon_text = "the threat implied in the title pokémon 4ever is terrifying – like locusts in a horde these things will keep coming ." + + +@pytest.fixture +def pokemon_attacked_text(): + return textattack.shared.AttackedText(raw_pokemon_text) + + premise = "Among these are the red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes." hypothesis = "The Patan Museum is down the street from the red brick Royal Palace." raw_text_pair = collections.OrderedDict( [("premise", premise), ("hypothesis", hypothesis)] ) +raw_hyphenated_text = "It's a run-of-the-mill kind of farmer's tan." 
+ + +@pytest.fixture +def hyphenated_text(): + return textattack.shared.AttackedText(raw_hyphenated_text) + @pytest.fixture def attacked_text_pair(): @@ -25,27 +40,13 @@ def attacked_text_pair(): class TestAttackedText: - def test_words(self, attacked_text): + def test_words(self, attacked_text, pokemon_attacked_text): + # fmt: off assert attacked_text.words == [ - "A", - "person", - "walks", - "up", - "stairs", - "into", - "a", - "room", - "and", - "sees", - "beer", - "poured", - "from", - "a", - "keg", - "and", - "people", - "talking", + "A", "person", "walks", "up", "stairs", "into", "a", "room", "and", "sees", "beer", "poured", "from", "a", "keg", "and", "people", "talking", ] + assert pokemon_attacked_text.words == ['the', 'threat', 'implied', 'in', 'the', 'title', 'pokémon', '4ever', 'is', 'terrifying', 'like', 'locusts', 'in', 'a', 'horde', 'these', 'things', 'will', 'keep', 'coming'] + # fmt: on def test_window_around_index(self, attacked_text): assert attacked_text.text_window_around_index(5, 1) == "into" @@ -69,8 +70,9 @@ class TestAttackedText: def test_window_around_index_end(self, attacked_text): assert attacked_text.text_window_around_index(17, 3) == "and people talking" - def test_text(self, attacked_text, attacked_text_pair): + def test_text(self, attacked_text, pokemon_attacked_text, attacked_text_pair): assert attacked_text.text == raw_text + assert pokemon_attacked_text.text == raw_pokemon_text assert attacked_text_pair.text == "\n".join(raw_text_pair.values()) def test_printable_text(self, attacked_text, attacked_text_pair): @@ -140,13 +142,13 @@ class TestAttackedText: + "\n" + "The Patan Museum is down the street from the red brick Royal Palace." ) - new_text = new_text.insert_text_after_word_index(38, "and shapes") + new_text = new_text.insert_text_after_word_index(37, "and shapes") assert new_text.text == ( "Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes." + "\n" + "The Patan Museum is down the street from the red brick Royal Palace." ) - new_text = new_text.insert_text_after_word_index(41, "The") + new_text = new_text.insert_text_after_word_index(40, "The") assert new_text.text == ( "Among these are the old decrepit red brick Royal Palace, which now houses the Patan Museum (Nepal's finest and most modern museum), and, facing the palace across the narrow brick plaza, eight temples of different styles and sizes and shapes." + "\n" @@ -163,7 +165,7 @@ class TestAttackedText: ) for old_idx, new_idx in enumerate(new_text.attack_attrs["original_index_map"]): assert (attacked_text.words[old_idx] == new_text.words[new_idx]) or ( - new_i == -1 + new_idx == -1 ) new_text = ( new_text.delete_word_at_index(0) @@ -180,3 +182,14 @@ class TestAttackedText: new_text.text == "person walks a very long way up stairs into a room and sees beer poured and people on the couch." 
) + + def test_hyphen_apostrophe_words(self, hyphenated_text): + assert hyphenated_text.words == [ + "It's", + "a", + "run-of-the-mill", + "kind", + "of", + "farmer's", + "tan", + ] diff --git a/tests/test_command_line/test_attack.py b/tests/test_command_line/test_attack.py index 3741246e..dc0d2bb4 100644 --- a/tests/test_command_line/test_attack.py +++ b/tests/test_command_line/test_attack.py @@ -112,6 +112,24 @@ attack_test_params = [ ), # fmt: on # + # test: run_attack on LSTM MR using word embedding transformation and genetic algorithm. Simulate alzantot recipe without using expensive LM + ( + "run_attack_faster_alzantot_recipe", + ( + "textattack attack --model lstm-mr --recipe faster-alzantot --num-examples 3 --num-examples-offset 20" + ), + "tests/sample_outputs/run_attack_faster_alzantot_recipe.txt", + ), + # + # test: run_attack with kuleshov recipe and sst-2 cnn + # + ( + "run_attack_kuleshov_nn", + ( + "textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst --attack-n --query-budget 200" + ), + "tests/sample_outputs/kuleshov_cnn_sst_2.txt", + ), ] diff --git a/tests/test_command_line/test_augment.py b/tests/test_command_line/test_augment.py index c8ea3605..0648e9f1 100644 --- a/tests/test_command_line/test_augment.py +++ b/tests/test_command_line/test_augment.py @@ -37,3 +37,5 @@ def test_command_line_augmentation(name, command, outfile, sample_output_file): # Ensure CSV file exists, then delete it. assert os.path.exists(outfile) os.remove(outfile) + + assert result.returncode == 0 diff --git a/tests/test_command_line/test_list.py b/tests/test_command_line/test_list.py index f4ac8287..96e4dcd4 100644 --- a/tests/test_command_line/test_list.py +++ b/tests/test_command_line/test_list.py @@ -27,3 +27,5 @@ def test_command_line_list(name, command, sample_output_file): print("stderr =>", stderr) assert stdout == desired_text + + assert result.returncode == 0 diff --git a/textattack/attack_recipes/PSO_zang_2020.py b/textattack/attack_recipes/PSO_zang_2020.py new file mode 100644 index 00000000..1cfb67e5 --- /dev/null +++ b/textattack/attack_recipes/PSO_zang_2020.py @@ -0,0 +1,56 @@ +from textattack.constraints.pre_transformation import ( + InputColumnModification, + RepeatModification, + StopwordModification, +) +from textattack.goal_functions import UntargetedClassification +from textattack.search_methods import PSOAlgorithm +from textattack.shared.attack import Attack +from textattack.transformations import WordSwapEmbedding, WordSwapHowNet + + +def PSOZang2020(model): + """ + Zang, Y., Yang, C., Qi, F., Liu, Z., Zhang, M., Liu, Q., & Sun, M. (2019). + + Word-level Textual Adversarial Attacking as Combinatorial Optimization. + + https://www.aclweb.org/anthology/2020.acl-main.540.pdf + + Methodology description quoted from the paper: + + "We propose a novel word substitution-based textual attack model, which reforms + both the aforementioned two steps. In the first step, we adopt a sememe-based word + substitution strategy, which can generate more candidate adversarial examples with + better semantic preservation. In the second step, we utilize particle swarm optimization + (Eberhart and Kennedy, 1995) as the adversarial example searching algorithm." + + And "Following the settings in Alzantot et al. (2018), we set the max iteration time G to 20." + """ + # + # Swap words with their synonyms extracted based on the HowNet. 
+ # + transformation = WordSwapHowNet() + # + # Don't modify the same word twice or stopwords + # + constraints = [RepeatModification(), StopwordModification()] + # + # + # During entailment, we should only edit the hypothesis - keep the premise + # the same. + # + input_column_modification = InputColumnModification( + ["premise", "hypothesis"], {"premise"} + ) + constraints.append(input_column_modification) + # + # Use untargeted classification for demo, can be switched to targeted one + # + goal_function = UntargetedClassification(model) + # + # Perform word substitution with a Particle Swarm Optimization (PSO) algorithm. + # + search_method = PSOAlgorithm(pop_size=60, max_iters=20) + + return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/__init__.py b/textattack/attack_recipes/__init__.py index 6f966749..f86092bb 100644 --- a/textattack/attack_recipes/__init__.py +++ b/textattack/attack_recipes/__init__.py @@ -4,8 +4,10 @@ from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018 from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019 from .deepwordbug_gao_2018 import DeepWordBugGao2018 from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017 +from .input_reduction_feng_2018 import InputReductionFeng2018 from .kuleshov_2017 import Kuleshov2017 from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox from .textbugger_li_2018 import TextBuggerLi2018 from .textfooler_jin_2019 import TextFoolerJin2019 from .pwws_ren_2019 import PWWSRen2019 +from .PSO_zang_2020 import PSOZang2020 diff --git a/textattack/attack_recipes/bert_attack_li_2020.py b/textattack/attack_recipes/bert_attack_li_2020.py index 955b35d2..2d5f6a51 100644 --- a/textattack/attack_recipes/bert_attack_li_2020.py +++ b/textattack/attack_recipes/bert_attack_li_2020.py @@ -20,14 +20,6 @@ def BERTAttackLi2020(model): This is "attack mode" 1 from the paper, BAE-R, word replacement. """ - from textattack.shared.utils import logger - - logger.warn( - "WARNING: This BERT-Attack implementation is based off of a" - " preliminary draft of the paper, which lacked source code and" - " did not include any hyperparameters. Attack reuslts are likely to" - " change." - ) # [from correspondence with the author] # Candidate size K is set to 48 for all data-sets. transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48) diff --git a/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py b/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py index fe09c62d..73cfe843 100644 --- a/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py +++ b/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py @@ -119,6 +119,6 @@ def FasterGeneticAlgorithmJia2019(model): # # Perform word substitution with a genetic algorithm. 
# - search_method = GeneticAlgorithm(pop_size=60, max_iters=20) + search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0) return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/genetic_algorithm_alzantot_2018.py b/textattack/attack_recipes/genetic_algorithm_alzantot_2018.py index 6ff8bfec..5f6b2866 100644 --- a/textattack/attack_recipes/genetic_algorithm_alzantot_2018.py +++ b/textattack/attack_recipes/genetic_algorithm_alzantot_2018.py @@ -3,6 +3,7 @@ from textattack.constraints.grammaticality.language_models import ( ) from textattack.constraints.overlap import MaxWordsPerturbed from textattack.constraints.pre_transformation import ( + InputColumnModification, RepeatModification, StopwordModification, ) @@ -34,6 +35,14 @@ def GeneticAlgorithmAlzantot2018(model): # constraints = [RepeatModification(), StopwordModification()] # + # During entailment, we should only edit the hypothesis - keep the premise + # the same. + # + input_column_modification = InputColumnModification( + ["premise", "hypothesis"], {"premise"} + ) + constraints.append(input_column_modification) + # # Maximum words perturbed percentage of 20% # constraints.append(MaxWordsPerturbed(max_percent=0.2)) @@ -52,6 +61,6 @@ def GeneticAlgorithmAlzantot2018(model): # # Perform word substitution with a genetic algorithm. # - search_method = GeneticAlgorithm(pop_size=60, max_iters=20) + search_method = GeneticAlgorithm(pop_size=60, max_iters=20, max_crossover_retries=0) return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/input_reduction_feng_2018.py b/textattack/attack_recipes/input_reduction_feng_2018.py new file mode 100644 index 00000000..aa921277 --- /dev/null +++ b/textattack/attack_recipes/input_reduction_feng_2018.py @@ -0,0 +1,40 @@ +from textattack.constraints.pre_transformation import ( + RepeatModification, + StopwordModification, +) +from textattack.goal_functions import InputReduction +from textattack.search_methods import GreedyWordSwapWIR +from textattack.shared.attack import Attack +from textattack.transformations import WordDeletion + + +def InputReductionFeng2018(model): + """ + Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018). + + Pathologies of Neural Models Make Interpretations Difficult. + + ArXiv, abs/1804.07781. + """ + # At each step, we remove the word with the lowest importance value until + # the model changes its prediction. + transformation = WordDeletion() + + constraints = [RepeatModification(), StopwordModification()] + # + # Goal is to reduce the input while maintaining the same prediction + # + goal_function = InputReduction(model, maximizable=True) + # + # "For each word in an input sentence, we measure its importance by the + # change in the confidence of the original prediction when we remove + # that word from the sentence." + # + # "Instead of looking at the words with high importance values—what + # interpretation methods commonly do—we take a complementary approach + # and study how the model behaves when the supposedly unimportant words are + # removed."
+ # + search_method = GreedyWordSwapWIR(wir_method="delete") + + return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/pwws_ren_2019.py b/textattack/attack_recipes/pwws_ren_2019.py index db3e5a89..1c0431a5 100644 --- a/textattack/attack_recipes/pwws_ren_2019.py +++ b/textattack/attack_recipes/pwws_ren_2019.py @@ -26,5 +26,5 @@ def PWWSRen2019(model): constraints = [RepeatModification(), StopwordModification()] goal_function = UntargetedClassification(model) # search over words based on a combination of their saliency score, and how efficient the WordSwap transform is - search_method = GreedyWordSwapWIR("pwws", ascending=False) + search_method = GreedyWordSwapWIR("pwws") return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py b/textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py index 4a2c61d5..5e3fdf24 100644 --- a/textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py +++ b/textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py @@ -27,7 +27,7 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"): # Goal is non-overlapping output. # goal_function = NonOverlappingOutput(model) - # @TODO implement transformation / search method just like they do in + # TODO implement transformation / search method just like they do in # seq2sick. transformation = WordSwapEmbedding(max_candidates=50) # @@ -42,6 +42,6 @@ def Seq2SickCheng2018BlackBox(model, goal_function="non_overlapping"): # # Greedily swap words with "Word Importance Ranking". # - search_method = GreedyWordSwapWIR() + search_method = GreedyWordSwapWIR(wir_method="unk") return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/attack_recipes/textfooler_jin_2019.py b/textattack/attack_recipes/textfooler_jin_2019.py index 44d4b369..5f142580 100644 --- a/textattack/attack_recipes/textfooler_jin_2019.py +++ b/textattack/attack_recipes/textfooler_jin_2019.py @@ -1,5 +1,6 @@ from textattack.constraints.grammaticality import PartOfSpeech from textattack.constraints.pre_transformation import ( + InputColumnModification, RepeatModification, StopwordModification, ) @@ -35,6 +36,13 @@ def TextFoolerJin2019(model): # fmt: on constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)] # + # During entailment, we should only edit the hypothesis - keep the premise + # the same. + # + input_column_modification = InputColumnModification( + ["premise", "hypothesis"], {"premise"} + ) + constraints.append(input_column_modification) # Minimum word embedding cosine similarity of 0.5. # (The paper claims 0.7, but analysis of the released code and some empirical # results show that it's 0.5.) diff --git a/textattack/attack_results/__init__.py b/textattack/attack_results/__init__.py index 7e47b4d7..6d809276 100644 --- a/textattack/attack_results/__init__.py +++ b/textattack/attack_results/__init__.py @@ -1,3 +1,4 @@ +from .maximized_attack_result import MaximizedAttackResult from .failed_attack_result import FailedAttackResult from .skipped_attack_result import SkippedAttackResult from .successful_attack_result import SuccessfulAttackResult diff --git a/textattack/attack_results/attack_result.py b/textattack/attack_results/attack_result.py index 184af9f8..4ea3b5b4 100644 --- a/textattack/attack_results/attack_result.py +++ b/textattack/attack_results/attack_result.py @@ -13,11 +13,11 @@ class AttackResult: perturbed text. 
May or may not have been successful. """ - def __init__(self, original_result, perturbed_result, num_queries=0): + def __init__(self, original_result, perturbed_result): if original_result is None: raise ValueError("Attack original result cannot be None") elif not isinstance(original_result, GoalFunctionResult): - raise TypeError(f"Invalid original goal function result: {original_text}") + raise TypeError(f"Invalid original goal function result: {original_result}") if perturbed_result is None: raise ValueError("Attack perturbed result cannot be None") elif not isinstance(perturbed_result, GoalFunctionResult): @@ -27,7 +27,7 @@ class AttackResult: self.original_result = original_result self.perturbed_result = perturbed_result - self.num_queries = num_queries + self.num_queries = perturbed_result.num_queries # We don't want the AttackedText attributes sticking around clogging up # space on our devices. Delete them here, if they're still present, @@ -89,27 +89,34 @@ class AttackResult: i1 = 0 i2 = 0 - while i1 < len(t1.words) and i2 < len(t2.words): + while i1 < t1.num_words or i2 < t2.num_words: # show deletions - while t2.attack_attrs["original_index_map"][i1] == -1: + while ( + i1 < len(t2.attack_attrs["original_index_map"]) + and t2.attack_attrs["original_index_map"][i1] == -1 + ): words_1.append(utils.color_text(t1.words[i1], color_1, color_method)) words_1_idxs.append(i1) i1 += 1 # show insertions - while i2 < t2.attack_attrs["original_index_map"][i1]: + while ( + i1 < len(t2.attack_attrs["original_index_map"]) + and i2 < t2.attack_attrs["original_index_map"][i1] + ): words_2.append(utils.color_text(t1.words[i2], color_2, color_method)) words_2_idxs.append(i2) i2 += 1 # show swaps - word_1 = t1.words[i1] - word_2 = t2.words[i2] - if word_1 != word_2: - words_1.append(utils.color_text(word_1, color_1, color_method)) - words_2.append(utils.color_text(word_2, color_2, color_method)) - words_1_idxs.append(i1) - words_2_idxs.append(i2) - i1 += 1 - i2 += 1 + if i1 < t1.num_words and i2 < t2.num_words: + word_1 = t1.words[i1] + word_2 = t2.words[i2] + if word_1 != word_2: + words_1.append(utils.color_text(word_1, color_1, color_method)) + words_2.append(utils.color_text(word_2, color_2, color_method)) + words_1_idxs.append(i1) + words_2_idxs.append(i2) + i1 += 1 + i2 += 1 t1 = self.original_result.attacked_text.replace_words_at_indices( words_1_idxs, words_1 diff --git a/textattack/attack_results/failed_attack_result.py b/textattack/attack_results/failed_attack_result.py index a3d37833..df2d524c 100644 --- a/textattack/attack_results/failed_attack_result.py +++ b/textattack/attack_results/failed_attack_result.py @@ -6,9 +6,9 @@ from .attack_result import AttackResult class FailedAttackResult(AttackResult): """The result of a failed attack.""" - def __init__(self, original_result, perturbed_result=None, num_queries=0): + def __init__(self, original_result, perturbed_result=None): perturbed_result = perturbed_result or original_result - super().__init__(original_result, perturbed_result, num_queries) + super().__init__(original_result, perturbed_result) def str_lines(self, color_method=None): lines = ( diff --git a/textattack/attack_results/maximized_attack_result.py b/textattack/attack_results/maximized_attack_result.py new file mode 100644 index 00000000..96e3e22d --- /dev/null +++ b/textattack/attack_results/maximized_attack_result.py @@ -0,0 +1,5 @@ +from .attack_result import AttackResult + + +class MaximizedAttackResult(AttackResult): + """ The result of a maximized attack.
""" diff --git a/textattack/commands/attack/__init__.py b/textattack/commands/attack/__init__.py index afa629b7..1bcec084 100644 --- a/textattack/commands/attack/__init__.py +++ b/textattack/commands/attack/__init__.py @@ -1,2 +1,5 @@ from .attack_command import AttackCommand from .attack_resume_command import AttackResumeCommand + +from .run_attack_single_threaded import run as run_attack_single_threaded +from .run_attack_parallel import run as run_attack_parallel diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py index 8c8bd732..ee546879 100644 --- a/textattack/commands/attack/attack_args.py +++ b/textattack/commands/attack/attack_args.py @@ -7,11 +7,13 @@ ATTACK_RECIPE_NAMES = { "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019", "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018", "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017", + "input-reduction": "textattack.attack_recipes.InputReductionFeng2018", "kuleshov": "textattack.attack_recipes.Kuleshov2017", "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox", "textbugger": "textattack.attack_recipes.TextBuggerLi2018", "textfooler": "textattack.attack_recipes.TextFoolerJin2019", "pwws": "textattack.attack_recipes.PWWSRen2019", + "pso": "textattack.attack_recipes.PSOZang2020", } # @@ -218,11 +220,22 @@ TEXTATTACK_DATASET_BY_MODEL = { ), # # Translation models - # TODO add proper `nlp` datasets for translation & summarization + "t5-en-de": ( + "english_to_german", + ("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"), + ), + "t5-en-fr": ( + "english_to_french", + ("textattack.datasets.translation.TedMultiTranslationDataset", "en", "fr"), + ), + "t5-en-ro": ( + "english_to_romanian", + ("textattack.datasets.translation.TedMultiTranslationDataset", "en", "de"), + ), # # Summarization models # - #'t5-summ': 'textattack.models.summarization.T5Summarization', + "t5-summarization": ("summarization", ("gigaword", None, "test")), } BLACK_BOX_TRANSFORMATION_CLASS_NAMES = { diff --git a/textattack/commands/attack/attack_args_helpers.py b/textattack/commands/attack/attack_args_helpers.py index bd5343ae..9461f944 100644 --- a/textattack/commands/attack/attack_args_helpers.py +++ b/textattack/commands/attack/attack_args_helpers.py @@ -332,7 +332,14 @@ def parse_dataset_from_args(args): if args.model in HUGGINGFACE_DATASET_BY_MODEL: _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model] elif args.model in TEXTATTACK_DATASET_BY_MODEL: - _, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model] + _, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model] + if dataset[0].startswith("textattack"): + # unsavory way to pass custom dataset classes + # ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de') + dataset = eval(f"{dataset[0]}")(*dataset[1:]) + return dataset + else: + args.dataset_from_nlp = dataset # Automatically detect dataset for models trained with textattack. 
elif args.model and os.path.exists(args.model): model_args_json_path = os.path.join(args.model, "train_args.json") diff --git a/textattack/commands/attack/run_attack_parallel.py b/textattack/commands/attack/run_attack_parallel.py index 6aa365a9..e1fbc0fb 100644 --- a/textattack/commands/attack/run_attack_parallel.py +++ b/textattack/commands/attack/run_attack_parallel.py @@ -125,7 +125,10 @@ def run(args, checkpoint=None): pbar.update() num_results += 1 - if type(result) == textattack.attack_results.SuccessfulAttackResult: + if ( + type(result) == textattack.attack_results.SuccessfulAttackResult + or type(result) == textattack.attack_results.MaximizedAttackResult + ): num_successes += 1 if type(result) == textattack.attack_results.FailedAttackResult: num_failures += 1 @@ -170,6 +173,8 @@ def run(args, checkpoint=None): finish_time = time.time() textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s") + return attack_log_manager.results + def pytorch_multiprocessing_workaround(): # This is a fix for a known bug diff --git a/textattack/commands/attack/run_attack_single_threaded.py b/textattack/commands/attack/run_attack_single_threaded.py index fd1e9cc2..9cfdfc92 100644 --- a/textattack/commands/attack/run_attack_single_threaded.py +++ b/textattack/commands/attack/run_attack_single_threaded.py @@ -108,7 +108,10 @@ def run(args, checkpoint=None): num_results += 1 - if type(result) == textattack.attack_results.SuccessfulAttackResult: + if ( + type(result) == textattack.attack_results.SuccessfulAttackResult + or type(result) == textattack.attack_results.MaximizedAttackResult + ): num_successes += 1 if type(result) == textattack.attack_results.FailedAttackResult: num_failures += 1 @@ -139,6 +142,8 @@ def run(args, checkpoint=None): finish_time = time.time() textattack.shared.logger.info(f"Attack time: {time.time() - load_time}s") + return attack_log_manager.results + if __name__ == "__main__": run(get_args()) diff --git a/textattack/constraints/constraint.py b/textattack/constraints/constraint.py index f97d5945..fa12b6bb 100644 --- a/textattack/constraints/constraint.py +++ b/textattack/constraints/constraint.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +import textattack from textattack.shared.utils import default_class_repr @@ -71,10 +72,10 @@ class Constraint(ABC): transformed_text (AttackedText): The candidate transformed ``AttackedText``. reference_text (AttackedText): The ``AttackedText`` to compare against. 
""" - if not isinstance(transformed_text, AttackedText): + if not isinstance(transformed_text, textattack.shared.AttackedText): raise TypeError("transformed_text must be of type AttackedText") - if not isinstance(reference_text, AttackedText): - raise TypeError("reference_text must be of type AttackedText") + if not isinstance(current_text, textattack.shared.AttackedText): + raise TypeError("current_text must be of type AttackedText") try: if not self.check_compatibility( diff --git a/textattack/constraints/grammaticality/language_models/google_language_model/google_language_model.py b/textattack/constraints/grammaticality/language_models/google_language_model/google_language_model.py index 9f987806..3fd1da1e 100644 --- a/textattack/constraints/grammaticality/language_models/google_language_model/google_language_model.py +++ b/textattack/constraints/grammaticality/language_models/google_language_model/google_language_model.py @@ -49,7 +49,7 @@ class GoogleLanguageModel(Constraint): [t.words[word_swap_index] for t in transformed_texts] ) if self.print_step: - print(prefix, swapped_words, suffix) + print(prefix, swapped_words) probs = self.lm.get_words_probs(prefix, swapped_words) return probs diff --git a/textattack/constraints/grammaticality/language_models/gpt2.py b/textattack/constraints/grammaticality/language_models/gpt2.py index c5eecf10..db39be84 100644 --- a/textattack/constraints/grammaticality/language_models/gpt2.py +++ b/textattack/constraints/grammaticality/language_models/gpt2.py @@ -52,7 +52,7 @@ class GPT2(LanguageModelConstraint): probs = [] for attacked_text in text_list: - nxt_word_ids = self.tokenizer.encode(attacked_text.words[word_index]) + next_word_ids = self.tokenizer.encode(attacked_text.words[word_index]) next_word_prob = predictions[0, -1, next_word_ids[0]] probs.append(next_word_prob) diff --git a/textattack/constraints/grammaticality/language_models/language_model_constraint.py b/textattack/constraints/grammaticality/language_models/language_model_constraint.py index 893ea239..6ceffed6 100644 --- a/textattack/constraints/grammaticality/language_models/language_model_constraint.py +++ b/textattack/constraints/grammaticality/language_models/language_model_constraint.py @@ -3,7 +3,11 @@ from abc import abstractmethod from textattack.constraints import Constraint +<<<<<<< HEAD class LanguageModelConstraint(Constraint): +======= +class LanguageModelConstraint(Constraint, ABC): +>>>>>>> master """ Determines if two sentences have a swapped word that has a similar probability according to a language model. diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py b/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py index 3f1da0be..a082dcf6 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py @@ -11,8 +11,8 @@ from .language_model_helpers import QueryHandler class LearningToWriteLanguageModel(LanguageModelConstraint): """ A constraint based on the L2W language model. - The RNN-based language model from ``Learning to Write With Cooperative - Discriminators'' (Holtzman et al, 2018). + The RNN-based language model from "Learning to Write With Cooperative + Discriminators" (Holtzman et al, 2018). 
    https://arxiv.org/pdf/1805.06087.pdf
diff --git a/textattack/constraints/pre_transformation/__init__.py b/textattack/constraints/pre_transformation/__init__.py
index e8c173d7..23f18589 100644
--- a/textattack/constraints/pre_transformation/__init__.py
+++ b/textattack/constraints/pre_transformation/__init__.py
@@ -1,3 +1,6 @@
-from .stopword_modification import StopwordModification
-from .repeat_modification import RepeatModification
+from .pre_transformation_constraint import PreTransformationConstraint
+
+from .input_column_modification import InputColumnModification
 from .max_word_index_modification import MaxWordIndexModification
+from .repeat_modification import RepeatModification
+from .stopword_modification import StopwordModification
diff --git a/textattack/constraints/pre_transformation/input_column_modification.py b/textattack/constraints/pre_transformation/input_column_modification.py
new file mode 100644
index 00000000..0e907f2b
--- /dev/null
+++ b/textattack/constraints/pre_transformation/input_column_modification.py
@@ -0,0 +1,43 @@
+from textattack.constraints.pre_transformation import PreTransformationConstraint
+
+
+class InputColumnModification(PreTransformationConstraint):
+    """
+    A constraint disallowing the modification of words within a specific input
+    column.
+
+    For example, can prevent modification of 'premise' during
+    entailment.
+    """
+
+    def __init__(self, matching_column_labels, columns_to_ignore):
+        self.matching_column_labels = matching_column_labels
+        self.columns_to_ignore = columns_to_ignore
+
+    def _get_modifiable_indices(self, current_text):
+        """ Returns the word indices in current_text which are able to be
+        modified.
+
+        If ``current_text.column_labels`` doesn't match
+        ``self.matching_column_labels``, do nothing, and allow all words
+        to be modified.
+
+        If it does match, only allow words to be modified if they are not
+        in columns from ``columns_to_ignore``.
+ """ + if current_text.column_labels != self.matching_column_labels: + return set(range(len(current_text.words))) + + idx = 0 + indices_to_modify = set() + for column, words in zip( + current_text.column_labels, current_text.words_per_input + ): + num_words = len(words) + if column not in self.columns_to_ignore: + indices_to_modify |= set(range(idx, idx + num_words)) + idx += num_words + return indices_to_modify + + def extra_repr_keys(self): + return ["matching_column_labels", "columns_to_ignore"] diff --git a/textattack/constraints/pre_transformation/max_word_index_modification.py b/textattack/constraints/pre_transformation/max_word_index_modification.py index b13890d0..efe7f4c2 100644 --- a/textattack/constraints/pre_transformation/max_word_index_modification.py +++ b/textattack/constraints/pre_transformation/max_word_index_modification.py @@ -13,3 +13,6 @@ class MaxWordIndexModification(PreTransformationConstraint): def _get_modifiable_indices(self, current_text): """ Returns the word indices in current_text which are able to be deleted """ return set(range(min(self.max_length, len(current_text.words)))) + + def extra_repr_keys(self): + return ["max_length"] diff --git a/textattack/constraints/pre_transformation_constraint.py b/textattack/constraints/pre_transformation_constraint.py index f235b789..8b5084ac 100644 --- a/textattack/constraints/pre_transformation_constraint.py +++ b/textattack/constraints/pre_transformation_constraint.py @@ -57,4 +57,9 @@ class PreTransformationConstraint(ABC): """ return [] - __str__ = __repr__ = default_class_repr \ No newline at end of file + def _check_constraint(self): + raise RuntimeError( + "PreTransformationConstraints do not support `_check_constraint()`." + ) + + __str__ = __repr__ = default_class_repr diff --git a/textattack/constraints/semantics/sentence_encoders/sentence_encoder.py b/textattack/constraints/semantics/sentence_encoders/sentence_encoder.py index ebe98483..69329c71 100644 --- a/textattack/constraints/semantics/sentence_encoders/sentence_encoder.py +++ b/textattack/constraints/semantics/sentence_encoders/sentence_encoder.py @@ -73,7 +73,9 @@ class SentenceEncoder(Constraint): The similarity between the starting and transformed text using the metric. """ try: - modified_index = next(iter(x_adv.attack_attrs["newly_modified_indices"])) + modified_index = next( + iter(transformed_text.attack_attrs["newly_modified_indices"]) + ) except KeyError: raise KeyError( "Cannot apply sentence encoder constraint without `newly_modified_indices`" @@ -112,7 +114,7 @@ class SentenceEncoder(Constraint): ``transformed_texts``. If ``transformed_texts`` is empty, an empty tensor is returned """ - # Return an empty tensor if x_adv_list is empty. + # Return an empty tensor if transformed_texts is empty. # This prevents us from calling .repeat(x, 0), which throws an # error on machines with multiple GPUs (pytorch 1.2). 
if len(transformed_texts) == 0: @@ -142,9 +144,9 @@ class SentenceEncoder(Constraint): ) ) embeddings = self.encode(starting_text_windows + transformed_text_windows) - starting_embeddings = torch.tensor(embeddings[: len(transformed_texts)]).to( - utils.device - ) + if not isinstance(embeddings, torch.Tensor): + embeddings = torch.tensor(embeddings) + starting_embeddings = embeddings[: len(transformed_texts)].to(utils.device) transformed_embeddings = torch.tensor( embeddings[len(transformed_texts) :] ).to(utils.device) @@ -152,18 +154,12 @@ class SentenceEncoder(Constraint): starting_raw_text = starting_text.text transformed_raw_texts = [t.text for t in transformed_texts] embeddings = self.encode([starting_raw_text] + transformed_raw_texts) - if isinstance(embeddings[0], torch.Tensor): - starting_embedding = embeddings[0].to(utils.device) - else: - # If the embedding is not yet a tensor, make it one. - starting_embedding = torch.tensor(embeddings[0]).to(utils.device) + if not isinstance(embeddings, torch.Tensor): + embeddings = torch.tensor(embeddings) - if isinstance(embeddings, list): - # If `encode` did not return a Tensor of all embeddings, combine - # into a tensor. - transformed_embeddings = torch.stack(embeddings[1:]).to(utils.device) - else: - transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.device) + starting_embedding = embeddings[0].to(utils.device) + + transformed_embeddings = embeddings[1:].to(utils.device) # Repeat original embedding to size of perturbed embedding. starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat( diff --git a/textattack/constraints/semantics/sentence_encoders/thought_vector.py b/textattack/constraints/semantics/sentence_encoders/thought_vector.py index cd46c39f..5704b5c0 100644 --- a/textattack/constraints/semantics/sentence_encoders/thought_vector.py +++ b/textattack/constraints/semantics/sentence_encoders/thought_vector.py @@ -36,7 +36,7 @@ class ThoughtVector(SentenceEncoder): return torch.mean(embeddings, dim=0) def encode(self, raw_text_list): - return [self._get_thought_vector(text) for text in raw_text_list] + return torch.stack([self._get_thought_vector(text) for text in raw_text_list]) def extra_repr_keys(self): """Set the extra representation of the constraint using these keys. diff --git a/textattack/constraints/semantics/word_embedding_distance.py b/textattack/constraints/semantics/word_embedding_distance.py index b60687e9..ffe3d670 100644 --- a/textattack/constraints/semantics/word_embedding_distance.py +++ b/textattack/constraints/semantics/word_embedding_distance.py @@ -51,7 +51,7 @@ class WordEmbeddingDistance(Constraint): mse_dist_file = "mse_dist.p" cos_sim_file = "cos_sim.p" else: - raise ValueError(f"Could not find word embedding {word_embedding}") + raise ValueError(f"Could not find word embedding {embedding_type}") # Download embeddings if they're not cached. word_embeddings_path = utils.download_if_needed(WordEmbeddingDistance.PATH) diff --git a/textattack/datasets/__init__.py b/textattack/datasets/__init__.py index 607109f6..eff50f47 100644 --- a/textattack/datasets/__init__.py +++ b/textattack/datasets/__init__.py @@ -1,6 +1,4 @@ from .dataset import TextAttackDataset from .huggingface_nlp_dataset import HuggingFaceNLPDataset -from . import classification -from . import entailment from . 
import translation diff --git a/textattack/datasets/classification/__init__.py b/textattack/datasets/classification/__init__.py deleted file mode 100644 index 0f39d181..00000000 --- a/textattack/datasets/classification/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .ag_news import AGNews -from .kaggle_fake_news import KaggleFakeNews diff --git a/textattack/datasets/classification/ag_news.py b/textattack/datasets/classification/ag_news.py deleted file mode 100644 index 068a0cb1..00000000 --- a/textattack/datasets/classification/ag_news.py +++ /dev/null @@ -1,44 +0,0 @@ -from textattack.shared import utils - -from .classification_dataset import ClassificationDataset - - -class AGNews(ClassificationDataset): - """ - Loads samples from the AG News Dataset. - - AG is a collection of more than 1 million news articles. News articles have - been gathered from more than 2000 news sources by ComeToMyHead in more than - 1 year of activity. ComeToMyHead is an academic news search engine which has - been running since July, 2004. The dataset is provided by the academic - community for research purposes in data mining (clustering, classification, - etc), information retrieval (ranking, search, etc), xml, data compression, - data streaming, and any other non-commercial activity. For more information, - please refer to the link - http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html. - - The AG's news topic classification dataset was constructed by Xiang Zhang - (xiang.zhang@nyu.edu) from the dataset above. It is used as a text - classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, - Yann LeCun. Character-level Convolutional Networks for Text Classification. - Advances in Neural Information Processing Systems 28 (NIPS 2015). - - Labels - 0: World - 1: Sports - 2: Business - 3: Sci/Tech - - Args: - offset (int): line to start reading from - shuffle (bool): If True, randomly shuffle loaded data - - """ - - DATA_PATH = "datasets/classification/ag_news.txt" - - def __init__(self, offset=0, shuffle=False): - """ Loads a full dataset from disk. """ - self._load_classification_text_file( - AGNews.DATA_PATH, offset=offset, shuffle=shuffle - ) diff --git a/textattack/datasets/classification/classification_dataset.py b/textattack/datasets/classification/classification_dataset.py deleted file mode 100644 index e3471142..00000000 --- a/textattack/datasets/classification/classification_dataset.py +++ /dev/null @@ -1,13 +0,0 @@ -from textattack.datasets import TextAttackDataset - - -class ClassificationDataset(TextAttackDataset): - """ - A generic class for loading classification data. - """ - - def _process_example_from_file(self, raw_line): - tokens = raw_line.strip().split() - label = int(tokens[0]) - text = " ".join(tokens[1:]) - return (text, label) diff --git a/textattack/datasets/classification/kaggle_fake_news.py b/textattack/datasets/classification/kaggle_fake_news.py deleted file mode 100644 index 1dab452c..00000000 --- a/textattack/datasets/classification/kaggle_fake_news.py +++ /dev/null @@ -1,24 +0,0 @@ -from .classification_dataset import ClassificationDataset - - -class KaggleFakeNews(ClassificationDataset): - """ - Loads samples from the Kaggle Fake News dataset. 
https://www.kaggle.com/mrisdal/fake-news - - Labels - 0: Real Article - 1: Fake Article - - Args: - offset (int): line to start reading from - shuffle (bool): If True, randomly shuffle loaded data - - """ - - DATA_PATH = "datasets/classification/fake" - - def __init__(self, offset=0, shuffle=False): - """ Loads a full dataset from disk. """ - self._load_classification_text_file( - KaggleFakeNews.DATA_PATH, offset=offset, shuffle=shuffle - ) diff --git a/textattack/datasets/entailment/__init__.py b/textattack/datasets/entailment/__init__.py deleted file mode 100644 index f296c8da..00000000 --- a/textattack/datasets/entailment/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .snli import SNLI diff --git a/textattack/datasets/entailment/entailment_dataset.py b/textattack/datasets/entailment/entailment_dataset.py deleted file mode 100644 index b7cca93f..00000000 --- a/textattack/datasets/entailment/entailment_dataset.py +++ /dev/null @@ -1,38 +0,0 @@ -import collections - -from textattack.datasets import TextAttackDataset -from textattack.shared import AttackedText - - -class EntailmentDataset(TextAttackDataset): - """ - A generic class for loading entailment data. - - Labels - 0: Entailment - 1: Neutral - 2: Contradiction - """ - - def _label_str_to_int(self, label_str): - if label_str == "entailment": - return 0 - elif label_str == "neutral": - return 1 - elif label_str == "contradiction": - return 2 - else: - raise ValueError(f"Unknown entailment label {label_str}") - - def _process_example_from_file(self, raw_line): - line = raw_line.strip() - label, premise, hypothesis = line.split("\t") - try: - label = int(label) - except ValueError: - # If the label is not an integer, it's a label description. - label = self._label_str_to_int(label) - text_input = collections.OrderedDict( - [("premise", premise), ("hypothesis", hypothesis),] - ) - return (text_input, label) diff --git a/textattack/datasets/entailment/snli.py b/textattack/datasets/entailment/snli.py deleted file mode 100644 index 9137535c..00000000 --- a/textattack/datasets/entailment/snli.py +++ /dev/null @@ -1,25 +0,0 @@ -from .entailment_dataset import EntailmentDataset - - -class SNLI(EntailmentDataset): - """ - Loads samples from the SNLI dataset. - - Labels - 0: Entailment - 1: Neutral - 2: Contradiction - - Args: - offset (int): line to start reading from - shuffle (bool): If True, randomly shuffle loaded data - - """ - - DATA_PATH = "datasets/entailment/snli" - - def __init__(self, offset=0, shuffle=False): - """ Loads a full dataset from disk. """ - self._load_classification_text_file( - SNLI.DATA_PATH, offset=offset, shuffle=shuffle - ) diff --git a/textattack/datasets/huggingface_nlp_dataset.py b/textattack/datasets/huggingface_nlp_dataset.py index d0da279d..1bd8fb37 100644 --- a/textattack/datasets/huggingface_nlp_dataset.py +++ b/textattack/datasets/huggingface_nlp_dataset.py @@ -35,6 +35,12 @@ def get_nlp_dataset_columns(dataset): elif {"sentence", "label"} <= schema: input_columns = ("sentence",) output_column = "label" + elif {"document", "summary"} <= schema: + input_columns = ("document",) + output_column = "summary" + elif {"content", "summary"} <= schema: + input_columns = ("content",) + output_column = "summary" else: raise ValueError( f"Unsupported dataset schema {schema}. Try loading dataset manually (from a file) instead." @@ -47,18 +53,17 @@ class HuggingFaceNLPDataset(TextAttackDataset): """ Loads a dataset from HuggingFace ``nlp`` and prepares it as a TextAttack dataset. 
- name: the dataset name - subset: the subset of the main dataset. Dataset will be loaded as - ``nlp.load_dataset(name, subset)``. - label_map: Mapping if output labels should be re-mapped. Useful - if model was trained with a different label arrangement than - provided in the ``nlp`` version of the dataset. - output_scale_factor (float): Factor to divide ground-truth outputs by. + - name: the dataset name + - subset: the subset of the main dataset. Dataset will be loaded as ``nlp.load_dataset(name, subset)``. + - label_map: Mapping if output labels should be re-mapped. Useful + if model was trained with a different label arrangement than + provided in the ``nlp`` version of the dataset. + - output_scale_factor (float): Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs - between 0 and 1. Some datasets test the model's *correlation* + between 0 and 1. Some datasets test the model's \*correlation\* with ground-truth output, instead of its accuracy, so these outputs may be scaled arbitrarily. - shuffle (bool): Whether to shuffle the dataset on load. + - shuffle (bool): Whether to shuffle the dataset on load. """ @@ -72,6 +77,7 @@ class HuggingFaceNLPDataset(TextAttackDataset): dataset_columns=None, shuffle=False, ): + self._name = name self._dataset = nlp.load_dataset(name, subset)[split] subset_print_str = f", subset {_cb(subset)}" if subset else "" textattack.shared.logger.info( diff --git a/textattack/datasets/translation/__init__.py b/textattack/datasets/translation/__init__.py index fdff0cd3..ba38a247 100644 --- a/textattack/datasets/translation/__init__.py +++ b/textattack/datasets/translation/__init__.py @@ -1 +1 @@ -from .translation_datasets import * +from .ted_multi import TedMultiTranslationDataset diff --git a/textattack/datasets/translation/ted_multi.py b/textattack/datasets/translation/ted_multi.py new file mode 100644 index 00000000..561ecbaf --- /dev/null +++ b/textattack/datasets/translation/ted_multi.py @@ -0,0 +1,38 @@ +import collections + +import nlp +import numpy as np + +from textattack.datasets import HuggingFaceNLPDataset + + +class TedMultiTranslationDataset(HuggingFaceNLPDataset): + """ Loads examples from the Ted Talk translation dataset using the `nlp` + package. + + dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/ + """ + + def __init__(self, source_lang="en", target_lang="de", split="test"): + self._dataset = nlp.load_dataset("ted_multi")[split] + self.examples = self._dataset["translations"] + language_options = set(self.examples[0]["language"]) + if source_lang not in language_options: + raise ValueError( + f"Source language {source_lang} invalid. Choices: {sorted(language_options)}" + ) + if target_lang not in language_options: + raise ValueError( + f"Target language {target_lang} invalid. 
Choices: {sorted(language_options)}" + ) + self.source_lang = source_lang + self.target_lang = target_lang + self.label_names = ("Translation",) + + def _format_raw_example(self, raw_example): + translations = np.array(raw_example["translation"]) + languages = np.array(raw_example["language"]) + source = translations[languages == self.source_lang][0] + target = translations[languages == self.target_lang][0] + source_dict = collections.OrderedDict([("Source", source)]) + return (source_dict, target) diff --git a/textattack/datasets/translation/translation_datasets.py b/textattack/datasets/translation/translation_datasets.py deleted file mode 100644 index 8c97d95e..00000000 --- a/textattack/datasets/translation/translation_datasets.py +++ /dev/null @@ -1,24 +0,0 @@ -from textattack.datasets import TextAttackDataset - - -class NewsTest2013EnglishToGerman(TextAttackDataset): - """ - Loads samples from newstest2013 dataset from the publicly available - WMT2016 translation task. (This is from the 'news' portion of - WMT2016. See http://www.statmt.org/wmt16/ for details.) Dataset - sourced from GluonNLP library. - - Samples are loaded as (input, translation) tuples of string pairs. - - Args: - offset (int): line to start reading from - shuffle (bool): If True, randomly shuffle loaded data - - """ - - DATA_PATH = "datasets/translation/NewsTest2013EnglishToGerman" - - def __init__(self, offset=0, shuffle=False): - self._load_pickle_file(NewsTest2013EnglishToGerman.DATA_PATH, offset=offset) - if shuffle: - self._shuffle_data() diff --git a/textattack/goal_function_results/__init__.py b/textattack/goal_function_results/__init__.py index 004d763e..04b431ed 100644 --- a/textattack/goal_function_results/__init__.py +++ b/textattack/goal_function_results/__init__.py @@ -1,4 +1,4 @@ -from .goal_function_result import GoalFunctionResult +from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus from .classification_goal_function_result import ClassificationGoalFunctionResult from .text_to_text_goal_function_result import TextToTextGoalFunctionResult diff --git a/textattack/goal_function_results/goal_function_result.py b/textattack/goal_function_results/goal_function_result.py index 9da0688e..1ebac848 100644 --- a/textattack/goal_function_results/goal_function_result.py +++ b/textattack/goal_function_results/goal_function_result.py @@ -1,6 +1,13 @@ import torch +class GoalFunctionResultStatus: + SUCCEEDED = 0 + SEARCHING = 1 # In process of searching for a success + MAXIMIZING = 2 + SKIPPED = 3 + + class GoalFunctionResult: """ Represents the result of a goal function evaluating a AttackedText object. @@ -8,16 +15,29 @@ class GoalFunctionResult: Args: attacked_text: The sequence that was evaluated. output: The display-friendly output. - succeeded: Whether the goal has been achieved. + goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal. score: A score representing how close the model is to achieving its goal. 
+ num_queries: How many model queries have been used + ground_truth_output: The ground truth output """ - def __init__(self, attacked_text, raw_output, output, succeeded, score): + def __init__( + self, + attacked_text, + raw_output, + output, + goal_status, + score, + num_queries, + ground_truth_output, + ): self.attacked_text = attacked_text self.raw_output = raw_output self.output = output self.score = score - self.succeeded = succeeded + self.goal_status = goal_status + self.num_queries = num_queries + self.ground_truth_output = ground_truth_output if isinstance(self.raw_output, torch.Tensor): self.raw_output = self.raw_output.cpu() @@ -25,9 +45,6 @@ class GoalFunctionResult: if isinstance(self.score, torch.Tensor): self.score = self.score.item() - if isinstance(self.succeeded, torch.Tensor): - self.succeeded = self.succeeded.item() - def get_text_color_input(self): """ A string representing the color this result's changed portion should be if it represents the original input. diff --git a/textattack/goal_functions/classification/__init__.py b/textattack/goal_functions/classification/__init__.py index bfb7f327..8f9ba8d9 100644 --- a/textattack/goal_functions/classification/__init__.py +++ b/textattack/goal_functions/classification/__init__.py @@ -1,2 +1,3 @@ +from .input_reduction import InputReduction from .untargeted_classification import UntargetedClassification from .targeted_classification import TargetedClassification diff --git a/textattack/goal_functions/classification/classification_goal_function.py b/textattack/goal_functions/classification/classification_goal_function.py index 5647f6bb..f6dcb55e 100644 --- a/textattack/goal_functions/classification/classification_goal_function.py +++ b/textattack/goal_functions/classification/classification_goal_function.py @@ -52,3 +52,6 @@ class ClassificationGoalFunction(GoalFunction): def extra_repr_keys(self): return [] + + def _get_displayed_output(self, raw_output): + return int(raw_output.argmax()) diff --git a/textattack/goal_functions/classification/input_reduction.py b/textattack/goal_functions/classification/input_reduction.py new file mode 100644 index 00000000..96186f7b --- /dev/null +++ b/textattack/goal_functions/classification/input_reduction.py @@ -0,0 +1,44 @@ +from .classification_goal_function import ClassificationGoalFunction + + +class InputReduction(ClassificationGoalFunction): + """ + Attempts to reduce the input down to as few words as possible while maintaining + the same predicted label. + + From Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018). + Pathologies of Neural Models Make Interpretations Difficult. + ArXiv, abs/1804.07781. + """ + + def __init__(self, *args, target_num_words=1, **kwargs): + self.target_num_words = target_num_words + super().__init__(*args, **kwargs) + + def _is_goal_complete(self, model_output, attacked_text): + return ( + self.ground_truth_output == model_output.argmax() + and attacked_text.num_words <= self.target_num_words + ) + + def _should_skip(self, model_output, attacked_text): + return self.ground_truth_output != model_output.argmax() + + def _get_score(self, model_output, attacked_text): + # Give the lowest score possible to inputs which don't maintain the ground truth label. 
+        if self.ground_truth_output != model_output.argmax():
+            return 0
+
+        cur_num_words = attacked_text.num_words
+        initial_num_words = self.initial_attacked_text.num_words
+
+        # The main goal is to reduce the number of words (num_words_score)
+        # Higher model score for the ground truth label is used as a tiebreaker (model_score)
+        num_words_score = max(
+            (initial_num_words - cur_num_words) / initial_num_words, 0
+        )
+        model_score = model_output[self.ground_truth_output]
+        return min(num_words_score + model_score / initial_num_words, 1)
+
+    def extra_repr_keys(self):
+        return ["target_num_words"]
diff --git a/textattack/goal_functions/classification/targeted_classification.py b/textattack/goal_functions/classification/targeted_classification.py
index 68e11abe..fe223465 100644
--- a/textattack/goal_functions/classification/targeted_classification.py
+++ b/textattack/goal_functions/classification/targeted_classification.py
@@ -3,18 +3,18 @@ from .classification_goal_function import ClassificationGoalFunction
 
 class TargetedClassification(ClassificationGoalFunction):
     """
-    An targeted attack on classification models which attempts to maximize the
-    score of the target label until it is the predicted label.
+    A targeted attack on classification models which attempts to maximize the
+    score of the target label. Complete when the target label is the predicted label.
     """
 
-    def __init__(self, model, target_class=0):
-        super().__init__(model)
+    def __init__(self, *args, target_class=0, **kwargs):
+        super().__init__(*args, **kwargs)
         self.target_class = target_class
 
-    def _is_goal_complete(self, model_output, ground_truth_output):
+    def _is_goal_complete(self, model_output, _):
         return (
             self.target_class == model_output.argmax()
-        ) or ground_truth_output == self.target_class
+        ) or self.ground_truth_output == self.target_class
 
     def _get_score(self, model_output, _):
         if self.target_class < 0 or self.target_class >= len(model_output):
@@ -24,8 +24,5 @@ class TargetedClassification(ClassificationGoalFunction):
         else:
             return model_output[self.target_class]
 
-    def _get_displayed_output(self, raw_output):
-        return int(raw_output.argmax())
-
     def extra_repr_keys(self):
         return ["target_class"]
diff --git a/textattack/goal_functions/classification/untargeted_classification.py b/textattack/goal_functions/classification/untargeted_classification.py
index e837a13f..5ceb9eb1 100644
--- a/textattack/goal_functions/classification/untargeted_classification.py
+++ b/textattack/goal_functions/classification/untargeted_classification.py
@@ -16,23 +16,22 @@ class UntargetedClassification(ClassificationGoalFunction):
         self.target_max_score = target_max_score
         super().__init__(*args, **kwargs)
 
-    def _is_goal_complete(self, model_output, ground_truth_output):
+    def _is_goal_complete(self, model_output, _):
         if self.target_max_score:
-            return model_output[ground_truth_output] < self.target_max_score
-        elif (model_output.numel() == 1) and isinstance(ground_truth_output, float):
-            return abs(ground_truth_output - model_output.item()) >= (
+            return model_output[self.ground_truth_output] < self.target_max_score
+        elif (model_output.numel() == 1) and isinstance(
+            self.ground_truth_output, float
+        ):
+            return abs(self.ground_truth_output - model_output.item()) >= (
                 self.target_max_score or 0.5
             )
         else:
-            return model_output.argmax() != ground_truth_output
+            return model_output.argmax() != self.ground_truth_output
 
-    def _get_score(self, model_output, ground_truth_output):
+    def _get_score(self, model_output, _):
        # If the model outputs a
single number and the ground truth output is # a float, we assume that this is a regression task. - if (model_output.numel() == 1) and isinstance(ground_truth_output, float): - return abs(model_output.item() - ground_truth_output) + if (model_output.numel() == 1) and isinstance(self.ground_truth_output, float): + return abs(model_output.item() - self.ground_truth_output) else: - return 1 - model_output[ground_truth_output] - - def _get_displayed_output(self, raw_output): - return int(raw_output.argmax()) + return 1 - model_output[self.ground_truth_output] diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index ec93c047..8aabad58 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -1,19 +1,25 @@ +from abc import ABC, abstractmethod import math import lru import numpy as np import torch +from textattack.goal_function_results.goal_function_result import ( + GoalFunctionResultStatus, +) from textattack.shared import utils, validators from textattack.shared.utils import batch_model_predict, default_class_repr -class GoalFunction: +class GoalFunction(ABC): """ Evaluates how well a perturbed attacked_text object is achieving a specified goal. Args: model: The model used for evaluation. + maximizable: Whether the goal function is maximizable, as opposed to a boolean result + of success or failure. query_budget (float): The maximum number of model queries allowed. model_batch_size (int): The batch size for making calls to the model model_cache_size (int): The maximum number of items to keep in the model @@ -23,6 +29,7 @@ class GoalFunction: def __init__( self, model, + maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf"), @@ -33,6 +40,7 @@ class GoalFunction: self.__class__, model.__class__ ) self.model = model + self.maximizable = maximizable self.tokenizer = tokenizer if not self.tokenizer: if hasattr(self.model, "tokenizer"): @@ -42,7 +50,6 @@ class GoalFunction: if not hasattr(self.tokenizer, "encode"): raise TypeError("Tokenizer must contain `encode()` method") self.use_cache = use_cache - self.num_queries = 0 self.query_budget = query_budget self.model_batch_size = model_batch_size if self.use_cache: @@ -50,13 +57,16 @@ class GoalFunction: else: self._call_model_cache = None - def should_skip(self, attacked_text, ground_truth_output): + def init_attack_example(self, attacked_text, ground_truth_output): """ - Returns whether or not the goal has already been completed for ``attacked_text``, - due to misprediction by the model. + Called before attacking ``attacked_text`` to 'reset' the goal + function and set properties for this example. """ - model_outputs = self._call_model([attacked_text]) - return self._is_goal_complete(model_outputs[0], ground_truth_output) + self.initial_attacked_text = attacked_text + self.ground_truth_output = ground_truth_output + self.num_queries = 0 + result, _ = self.get_result(attacked_text, check_skip=True) + return result, _ def get_output(self, attacked_text): """ @@ -64,16 +74,16 @@ class GoalFunction: """ return self._get_displayed_output(self._call_model([attacked_text])[0]) - def get_result(self, attacked_text, ground_truth_output): + def get_result(self, attacked_text, **kwargs): """ - A helper method that queries `self.get_results` with a single + A helper method that queries ``self.get_results`` with a single ``AttackedText`` object. 
""" - results, search_over = self.get_results([attacked_text], ground_truth_output) + results, search_over = self.get_results([attacked_text], **kwargs) result = results[0] if len(results) else None return result, search_over - def get_results(self, attacked_text_list, ground_truth_output): + def get_results(self, attacked_text_list, check_skip=False): """ For each attacked_text object in attacked_text_list, returns a result consisting of whether or not the goal has been achieved, the output for @@ -88,34 +98,55 @@ class GoalFunction: model_outputs = self._call_model(attacked_text_list) for attacked_text, raw_output in zip(attacked_text_list, model_outputs): displayed_output = self._get_displayed_output(raw_output) - succeeded = self._is_goal_complete(raw_output, ground_truth_output) - goal_function_score = self._get_score(raw_output, ground_truth_output) + goal_status = self._get_goal_status( + raw_output, attacked_text, check_skip=check_skip + ) + goal_function_score = self._get_score(raw_output, attacked_text) results.append( self._goal_function_result_type()( attacked_text, raw_output, displayed_output, - succeeded, + goal_status, goal_function_score, + self.num_queries, + self.ground_truth_output, ) ) return results, self.num_queries == self.query_budget - def _is_goal_complete(self, model_output, ground_truth_output): + def _get_goal_status(self, model_output, attacked_text, check_skip=False): + should_skip = check_skip and self._should_skip(model_output, attacked_text) + if should_skip: + return GoalFunctionResultStatus.SKIPPED + if self.maximizable: + return GoalFunctionResultStatus.MAXIMIZING + if self._is_goal_complete(model_output, attacked_text): + return GoalFunctionResultStatus.SUCCEEDED + return GoalFunctionResultStatus.SEARCHING + + @abstractmethod + def _is_goal_complete(self, model_output, attacked_text): raise NotImplementedError() - def _get_score(self, model_output, ground_truth_output): + def _should_skip(self, model_output, attacked_text): + return self._is_goal_complete(model_output, attacked_text) + + @abstractmethod + def _get_score(self, model_output, attacked_text): raise NotImplementedError() def _get_displayed_output(self, raw_output): return raw_output + @abstractmethod def _goal_function_result_type(self): """ Returns the class of this goal function's results. """ raise NotImplementedError() + @abstractmethod def _process_model_outputs(self, inputs, outputs): """ Processes and validates a list of model outputs. @@ -142,7 +173,7 @@ class GoalFunction: return self._process_model_outputs(attacked_text_list, outputs) def _call_model(self, attacked_text_list): - """ Gets predictions for a list of `AttackedText` objects. + """ Gets predictions for a list of ``AttackedText`` objects. Gets prediction from cache if possible. If prediction is not in the cache, queries model and stores prediction in cache. diff --git a/textattack/goal_functions/text/non_overlapping_output.py b/textattack/goal_functions/text/non_overlapping_output.py index 4b0b060c..fcd8652b 100644 --- a/textattack/goal_functions/text/non_overlapping_output.py +++ b/textattack/goal_functions/text/non_overlapping_output.py @@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction): Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3). 
""" - def _is_goal_complete(self, model_output, ground_truth_output): - return self._get_score(model_output, ground_truth_output) == 1.0 + def _is_goal_complete(self, model_output, _): + return self._get_score(model_output, self.ground_truth_output) == 1.0 - def _get_score(self, model_output, ground_truth_output): - num_words_diff = word_difference_score(model_output, ground_truth_output) + def _get_score(self, model_output, _): + num_words_diff = word_difference_score(model_output, self.ground_truth_output) if num_words_diff == 0: return 0.0 else: - return num_words_diff / len(get_words_cached(ground_truth_output)) + return num_words_diff / len(get_words_cached(self.ground_truth_output)) @functools.lru_cache(maxsize=2 ** 12) diff --git a/textattack/goal_functions/text/text_to_text_goal_function.py b/textattack/goal_functions/text/text_to_text_goal_function.py index 1e92b0e8..45d915f8 100644 --- a/textattack/goal_functions/text/text_to_text_goal_function.py +++ b/textattack/goal_functions/text/text_to_text_goal_function.py @@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction): original_output: the original output of the model """ - def __init__(self, model): - super().__init__(model) - def _goal_function_result_type(self): """ Returns the class of this goal function's results. """ return TextToTextGoalFunctionResult diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py index e77eac97..1dc52097 100644 --- a/textattack/loggers/csv_logger.py +++ b/textattack/loggers/csv_logger.py @@ -20,9 +20,8 @@ class CSVLogger(Logger): self._flushed = True def log_attack_result(self, result): - if isinstance(result, FailedAttackResult): - return original_text, perturbed_text = result.diff_color(self.color_method) + result_type = result.__class__.__name__.replace("AttackResult", "") row = { "original_text": original_text, "perturbed_text": perturbed_text, @@ -30,7 +29,9 @@ class CSVLogger(Logger): "perturbed_score": result.perturbed_result.score, "original_output": result.original_result.output, "perturbed_output": result.perturbed_result.output, + "ground_truth_output": result.original_result.ground_truth_output, "num_queries": result.num_queries, + "result_type": result_type, } self.df = self.df.append(row, ignore_index=True) self._flushed = False diff --git a/textattack/models/t5_for_text_to_text.py b/textattack/models/t5_for_text_to_text.py index 6bb184f0..44d5340e 100644 --- a/textattack/models/t5_for_text_to_text.py +++ b/textattack/models/t5_for_text_to_text.py @@ -31,7 +31,7 @@ class T5ForTextToText: def __init__( self, mode="english_to_german", max_length=20, num_beams=1, early_stopping=True ): - self.model = transformers.AutoModelWithLMHead.from_pretrained("t5-base") + self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base") self.model.to(utils.device) self.model.eval() self.tokenizer = T5Tokenizer(mode) diff --git a/textattack/models/tokenizers/auto_tokenizer.py b/textattack/models/tokenizers/auto_tokenizer.py index a337d51e..8af56f29 100644 --- a/textattack/models/tokenizers/auto_tokenizer.py +++ b/textattack/models/tokenizers/auto_tokenizer.py @@ -18,11 +18,7 @@ class AutoTokenizer: """ def __init__( - self, - name="bert-base-uncased", - max_length=256, - pad_to_length=False, - use_fast=True, + self, name="bert-base-uncased", max_length=256, use_fast=True, ): self.tokenizer = transformers.AutoTokenizer.from_pretrained( name, use_fast=use_fast @@ -43,7 +39,7 @@ class AutoTokenizer: *input_text, max_length=self.max_length, add_special_tokens=True, 
-                pad_to_max_length=True,
+                padding="max_length",
                 truncation=True,
             )
             return dict(encoded_text)
@@ -59,7 +55,7 @@ class AutoTokenizer:
             truncation=True,
             max_length=self.max_length,
             add_special_tokens=True,
-            pad_to_max_length=True,
+            padding="max_length",
         )
         # Encodings is a `transformers.utils.BatchEncode` object, which
         # is basically a big dictionary that contains a key for all input
diff --git a/textattack/models/tokenizers/glove_tokenizer.py b/textattack/models/tokenizers/glove_tokenizer.py
index 821820a0..e5d83e15 100644
--- a/textattack/models/tokenizers/glove_tokenizer.py
+++ b/textattack/models/tokenizers/glove_tokenizer.py
@@ -62,7 +62,11 @@ class WordLevelTokenizer(hf_tokenizers.implementations.BaseTokenizer):
         normalizers = []
 
         if unicode_normalizer:
-            normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
+            normalizers += [
+                hf_tokenizers.normalizers.unicode_normalizer_from_str(
+                    unicode_normalizer
+                )
+            ]
 
         if lowercase:
             normalizers += [hf_tokenizers.normalizers.Lowercase()]
diff --git a/textattack/models/tokenizers/t5_tokenizer.py b/textattack/models/tokenizers/t5_tokenizer.py
index 267eebbf..98888d14 100644
--- a/textattack/models/tokenizers/t5_tokenizer.py
+++ b/textattack/models/tokenizers/t5_tokenizer.py
@@ -11,10 +11,10 @@ class T5Tokenizer(AutoTokenizer):
 
     Supports the following modes:
 
-    * summarization: summarize English text (CNN/Daily Mail dataset)
-    * english_to_german: translate English to German (WMT dataset)
-    * english_to_french: translate English to French (WMT dataset)
-    * english_to_romanian: translate English to Romanian (WMT dataset)
+    * summarization: summarize English text
+    * english_to_german: translate English to German
+    * english_to_french: translate English to French
+    * english_to_romanian: translate English to Romanian
@@ -28,7 +28,7 @@ class T5Tokenizer(AutoTokenizer):
         elif mode == "summarization":
             self.tokenization_prefix = "summarize: "
         else:
-            raise ValueError(f"Invalid t5 tokenizer mode {english_to_german}.")
+            raise ValueError(f"Invalid t5 tokenizer mode {mode}.")
 
         super().__init__(name="t5-base", max_length=max_length)
 
@@ -38,12 +38,29 @@ class T5Tokenizer(AutoTokenizer):
             passed into T5.
         """
         if isinstance(text, tuple):
+            if len(text) > 1:
+                raise ValueError(
+                    f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
+                )
             text = text[0]
         if not isinstance(text, str):
             raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
         text_to_encode = self.tokenization_prefix + text
         return super().encode(text_to_encode)
 
+    def batch_encode(self, input_text_list):
+        new_input_text_list = []
+        for text in input_text_list:
+            if isinstance(text, tuple):
+                if len(text) > 1:
+                    raise ValueError(
+                        f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
+                    )
+                text = text[0]
+            new_input_text_list.append(self.tokenization_prefix + text)
+
+        return super().batch_encode(new_input_text_list)
+
     def decode(self, ids):
         """ Converts IDs (typically generated by the model) back to a string.
diff --git a/textattack/search_methods/PSO_algorithm.py b/textattack/search_methods/PSO_algorithm.py
new file mode 100644
index 00000000..8eace429
--- /dev/null
+++ b/textattack/search_methods/PSO_algorithm.py
@@ -0,0 +1,274 @@
+"""
+Reimplementation of search method from Word-level Textual Adversarial Attacking as Combinatorial Optimization
+by Zang et al.
+`[Paper] <https://www.aclweb.org/anthology/2020.acl-main.540/>`__
+`[Code] <https://github.com/thunlp/SememePSO-Attack>`__
+"""
+
+from copy import deepcopy
+
+import numpy as np
+
+from textattack.goal_function_results import GoalFunctionResultStatus
+from textattack.search_methods import SearchMethod
+
+
+class PSOAlgorithm(SearchMethod):
+    """
+    Attacks a model with word substitutions using a Particle Swarm Optimization (PSO) algorithm.
+    Some key hyper-parameters are set up according to the original paper:
+
+    "We adjust PSO on the validation set of SST and set ω_1 as 0.8 and ω_2 as 0.2.
+    We set the max velocity of the particles V_{max} to 3, which means the changing
+    probability of the particles ranges from 0.047 (sigmoid(-3)) to 0.953 (sigmoid(3))."
+
+    Args:
+        pop_size (:obj:`int`, optional): The population size. Defaults to 60.
+        max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 20.
+    """
+
+    def __init__(
+        self, pop_size=60, max_iters=20,
+    ):
+        self.max_iters = max_iters
+        self.pop_size = pop_size
+        self.search_over = False
+
+        self.Omega_1 = 0.8
+        self.Omega_2 = 0.2
+        self.C1_origin = 0.8
+        self.C2_origin = 0.2
+        self.V_max = 3.0
+
+    def _generate_population(self, x_orig, neighbors_list, neighbors_len):
+        h_score, w_list = self._gen_h_score(x_orig, neighbors_len, neighbors_list)
+        return [self._mutate(x_orig, h_score, w_list) for _ in range(self.pop_size)]
+
+    def _mutate(self, x_cur, w_select_probs, w_list):
+        rand_idx = np.random.choice(len(w_select_probs), 1, p=w_select_probs)[0]
+        return x_cur.replace_word_at_index(rand_idx, w_list[rand_idx])
+
+    def _gen_h_score(self, x, neighbors_len, neighbors_list):
+
+        w_list = []
+        prob_list = []
+        for i, orig_w in enumerate(x.words):
+            if neighbors_len[i] == 0:
+                w_list.append(orig_w)
+                prob_list.append(0)
+                continue
+            p, w = self._gen_most_change(x, i, neighbors_list[i])
+            w_list.append(w)
+            prob_list.append(p)
+
+        prob_list = self._norm(prob_list)
+
+        h_score = np.array(prob_list)
+        return h_score, w_list
+
+    def _norm(self, n):
+
+        tn = []
+        for i in n:
+            if i <= 0:
+                tn.append(0)
+            else:
+                tn.append(i)
+        s = np.sum(tn)
+        if s == 0:
+            for i in range(len(tn)):
+                tn[i] = 1
+            return [t / len(tn) for t in tn]
+        new_n = [t / s for t in tn]
+
+        return new_n
+
+    # for un-targeted attacking
+    def _gen_most_change(self, x_cur, pos, replace_list):
+        orig_result, self.search_over = self.get_goal_results([x_cur])
+        if self.search_over:
+            return 0, x_cur.words[pos]
+        new_x_list = [x_cur.replace_word_at_index(pos, w) for w in replace_list]
+        # new_x_list = self.get_transformations(
+        #     x_cur,
+        #     original_text=self.original_attacked_text,
+        #     indices_to_modify=[pos],
+        # )
+        new_x_results, self.search_over = self.get_goal_results(new_x_list)
+        new_x_scores = np.array([r.score for r in new_x_results])
+        new_x_scores = (
+            new_x_scores - orig_result[0].score
+        )  # minimize the score of ground truth
+        if len(new_x_scores):
+            return (
+                np.max(new_x_scores),
+                new_x_list[np.argsort(new_x_scores)[-1]].words[pos],
+            )
+        else:
+            return 0, x_cur.words[pos]
+
+    def _get_neighbors_list(self, attacked_text):
+        """
+        Generates the list of candidate substitute words for each word.
+        Args:
+            attacked_text: The original text
+        Returns:
+            A list whose i-th element is an array of candidate substitutes for word i
+        """
+        words = attacked_text.words
+        neighbors_list = [[] for _ in range(len(words))]
+        transformations = self.get_transformations(
+            attacked_text, original_text=self.original_attacked_text
+        )
+        for transformed_text in transformations:
+            try:
+                diff_idx = attacked_text.first_word_diff_index(transformed_text)
+                neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
+            except:
+                assert len(attacked_text.words) == len(transformed_text.words)
+                assert all(
+                    [
+                        w1 == w2
+                        for w1, w2 in zip(attacked_text.words, transformed_text.words)
+                    ]
+                )
+        neighbors_list = [np.array(x) for x in neighbors_list]
+        return neighbors_list
+
+    def _equal(self, a, b):
+        if a == b:
+            return -self.V_max
+        else:
+            return self.V_max
+
+    def _turn(self, x1, x2, prob, x_len):
+        indices_to_replace = []
+        words_to_replace = []
+        x2_words = x2.words
+        for i in range(x_len):
+            if np.random.uniform() < prob[i]:
+                indices_to_replace.append(i)
+                words_to_replace.append(x2_words[i])
+        new_text = x1.replace_words_at_indices(indices_to_replace, words_to_replace)
+        return new_text
+
+    def _count_change_ratio(self, x1, x2, x_len):
+        change_ratio = float(
+            np.sum(np.array(x1.words) != np.array(x2.words))
+        ) / float(x_len)
+        return change_ratio
+
+    def _sigmoid(self, n):
+        return 1 / (1 + np.exp(-n))
+
+    def _perform_search(self, initial_result):
+        self.original_attacked_text = initial_result.attacked_text
+        x_len = len(self.original_attacked_text.words)
+        self.correct_output = initial_result.output
+
+        # get word substitute candidates and generate population
+        neighbors_list = self._get_neighbors_list(self.original_attacked_text)
+        neighbors_len = [len(x) for x in neighbors_list]
+        pop = self._generate_population(
+            self.original_attacked_text, neighbors_list, neighbors_len
+        )
+
+        # test population against target model
+        pop_results, self.search_over = self.get_goal_results(pop)
+        if self.search_over:
+            return max(pop_results, key=lambda x: x.score)
+        pop_scores = np.array([r.score for r in pop_results])
+
+        # rank the scores from low to high and check if there is a successful attack
+        part_elites = deepcopy(pop)
+        part_elites_scores = pop_scores
+        top_attack = np.argmax(pop_scores)
+        all_elite = pop[top_attack]
+        all_elite_score = pop_scores[top_attack]
+        if pop_results[top_attack].goal_status == GoalFunctionResultStatus.SUCCEEDED:
+            return pop_results[top_attack]
+
+        # set up hyper-parameters
+        V = np.random.uniform(-self.V_max, self.V_max, self.pop_size)
+        V_P = [[V[t] for _ in range(x_len)] for t in range(self.pop_size)]
+
+        # start iterations
+        for i in range(self.max_iters):
+
+            Omega = (self.Omega_1 - self.Omega_2) * (
+                self.max_iters - i
+            ) / self.max_iters + self.Omega_2
+            C1 = self.C1_origin - i / self.max_iters * (self.C1_origin - self.C2_origin)
+            C2 = self.C2_origin + i / self.max_iters * (self.C1_origin - self.C2_origin)
+            P1 = C1
+            P2 = C2
+
+            all_elite_words = all_elite.words
+            for id in range(self.pop_size):
+
+                # calculate the probability of turning each word
+                pop_words = pop[id].words
+                part_elites_words = part_elites[id].words
+                for dim in range(x_len):
+                    V_P[id][dim] = Omega * V_P[id][dim] + (1 - Omega) * (
+                        self._equal(pop_words[dim], part_elites_words[dim])
+                        + self._equal(pop_words[dim], all_elite_words[dim])
+                    )
+                turn_prob = [self._sigmoid(V_P[id][d]) for d in range(x_len)]
+
+                if np.random.uniform() < P1:
+                    pop[id] = self._turn(part_elites[id], pop[id], turn_prob, x_len)
+                if np.random.uniform() < P2:
+                    pop[id] = self._turn(all_elite, pop[id], turn_prob, x_len)
+
+            # check if there is any successful attack in the current population
+            pop_results, self.search_over = self.get_goal_results(pop)
+            if self.search_over:
+                return max(pop_results, key=lambda x: x.score)
+            pop_scores = np.array([r.score for r in pop_results])
+            top_attack = np.argmax(pop_scores)
+            if (
+                pop_results[top_attack].goal_status
+                ==
+                == GoalFunctionResultStatus.SUCCEEDED
+            ):
+                return pop_results[top_attack]
+
+            # mutation based on the current change rate
+            new_pop = []
+            for x in pop:
+                change_ratio = self._count_change_ratio(
+                    x, self.original_attacked_text, x_len
+                )
+                p_change = (
+                    1 - 2 * change_ratio
+                )  # taken from the original source code
+                if np.random.uniform() < p_change:
+                    new_h, new_w_list = self._gen_h_score(
+                        x, neighbors_len, neighbors_list
+                    )
+                    new_pop.append(self._mutate(x, new_h, new_w_list))
+                else:
+                    new_pop.append(x)
+            pop = new_pop
+
+            # check if there is any successful attack in the current population
+            pop_results, self.search_over = self.get_goal_results(pop)
+            if self.search_over:
+                return max(pop_results, key=lambda x: x.score)
+            pop_scores = np.array([r.score for r in pop_results])
+            top_attack = np.argmax(pop_scores)
+            if (
+                pop_results[top_attack].goal_status
+                == GoalFunctionResultStatus.SUCCEEDED
+            ):
+                return pop_results[top_attack]
+
+            # update the elites if the score has increased
+            for k in range(self.pop_size):
+                if pop_scores[k] > part_elites_scores[k]:
+                    part_elites[k] = pop[k]
+                    part_elites_scores[k] = pop_scores[k]
+            if pop_scores[top_attack] > all_elite_score:
+                all_elite = pop[top_attack]
+                all_elite_score = pop_scores[top_attack]
+
+        return initial_result
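The PSO update above packs several hyper-parameters into a few lines. As a minimal, self-contained sketch (illustrative only, not part of the patch), the helpers below mirror the arithmetic in `_perform_search`; `sigmoid` and `omega` are local names introduced here for clarity:

```python
import numpy as np

def sigmoid(n):
    # Same mapping `_sigmoid` uses to turn a velocity into a probability.
    return 1 / (1 + np.exp(-n))

def omega(i, max_iters, omega_1=0.8, omega_2=0.2):
    # Inertia weight decays linearly from omega_1 to omega_2 over the run.
    return (omega_1 - omega_2) * (max_iters - i) / max_iters + omega_2

# With V_max = 3, the probability of "turning" a word toward an elite
# particle is bounded by sigmoid(+/-3), i.e. roughly (0.047, 0.953),
# matching the quote in the class docstring.
V_max = 3.0
for v in (-V_max, 0.0, V_max):
    print(f"velocity {v:+.1f} -> turn probability {sigmoid(v):.3f}")
```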
diff --git a/textattack/search_methods/__init__.py b/textattack/search_methods/__init__.py
index 69574cdd..fdae08f8 100644
--- a/textattack/search_methods/__init__.py
+++ b/textattack/search_methods/__init__.py
@@ -3,3 +3,4 @@ from .beam_search import BeamSearch
 from .greedy_search import GreedySearch
 from .greedy_word_swap_wir import GreedyWordSwapWIR
 from .genetic_algorithm import GeneticAlgorithm
+from .PSO_algorithm import PSOAlgorithm
diff --git a/textattack/search_methods/beam_search.py b/textattack/search_methods/beam_search.py
index 0e1dc769..7d5b1f6a 100644
--- a/textattack/search_methods/beam_search.py
+++ b/textattack/search_methods/beam_search.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.search_methods import SearchMethod
 
 
@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
     def _perform_search(self, initial_result):
         beam = [initial_result.attacked_text]
         best_result = initial_result
-        while not best_result.succeeded:
+        while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
             potential_next_beam = []
             for text in beam:
                 transformations = self.get_transformations(
@@ -32,9 +33,7 @@ class BeamSearch(SearchMethod):
             if len(potential_next_beam) == 0:
                 # If we did not find any possible perturbations, give up.
                 return best_result
-            results, search_over = self.get_goal_results(
-                potential_next_beam, initial_result.output
-            )
+            results, search_over = self.get_goal_results(potential_next_beam)
             scores = np.array([r.score for r in results])
             best_result = results[scores.argmax()]
             if search_over:
diff --git a/textattack/search_methods/genetic_algorithm.py b/textattack/search_methods/genetic_algorithm.py
index 7e16803d..c3fda273 100644
--- a/textattack/search_methods/genetic_algorithm.py
+++ b/textattack/search_methods/genetic_algorithm.py
@@ -10,6 +10,7 @@ from copy import deepcopy
 import numpy as np
 import torch
 
+from textattack.goal_function_results import GoalFunctionResultStatus
 from textattack.search_methods import SearchMethod
 from textattack.shared.validators import transformation_consists_of_word_swaps
 
@@ -19,167 +20,209 @@ class GeneticAlgorithm(SearchMethod):
     Attacks a model with word substiutitions using a genetic algorithm.
 
     Args:
-        pop_size (:obj:`int`, optional): The population size. Defauls to 20.
-        max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 50.
+        pop_size (int): The population size. Defaults to 20.
+        max_iters (int): The maximum number of iterations to use. Defaults to 50.
+        temp (float): Temperature for the softmax function used to normalize the probability distribution when sampling parents.
+            Higher temperature increases the sensitivity to lower-probability candidates.
+        give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
+        max_crossover_retries (int): Maximum number of crossover retries if the resulting child fails to pass the constraints.
+            Setting it to 0 means we immediately take one of the parents at random as the child.
     """
 
     def __init__(
-        self, pop_size=20, max_iters=50, temp=0.3, give_up_if_no_improvement=False
+        self,
+        pop_size=20,
+        max_iters=50,
+        temp=0.3,
+        give_up_if_no_improvement=False,
+        max_crossover_retries=20,
     ):
         self.max_iters = max_iters
         self.pop_size = pop_size
         self.temp = temp
         self.give_up_if_no_improvement = give_up_if_no_improvement
-        self.search_over = False
+        self.max_crossover_retries = max_crossover_retries
 
-    def _replace_at_index(self, pop_member, idx):
+        # internal flag to indicate if search should end immediately
+        self._search_over = False
+
+    def _perturb(self, pop_member, original_result):
         """
-        Select the best replacement for word at position (idx)
-        in (pop_member) to maximize score.
-
+        Perturbs ``pop_member`` in place by replacing one of its words that has not yet been modified.
         Args:
-            pop_member: The population member being perturbed.
-            idx: The index at which to replace a word.
-
-        Returns:
-            Whether a replacement which increased the score was found.
+            pop_member (PopulationMember): The population member being perturbed.
+            original_result (GoalFunctionResult): Result of the original sample being attacked.
+
+        Returns: None
         """
-        transformations = self.get_transformations(
-            pop_member.attacked_text,
-            original_text=self.original_attacked_text,
-            indices_to_modify=[idx],
-        )
-        if not len(transformations):
-            return False
-        orig_result, self.search_over = self.get_goal_results(
-            [pop_member.attacked_text], self.correct_output
-        )
-        if self.search_over:
-            return False
-        new_x_results, self.search_over = self.get_goal_results(
-            transformations, self.correct_output
-        )
-        new_x_scores = torch.Tensor([r.score for r in new_x_results])
-        new_x_scores = new_x_scores - orig_result[0].score
-        if len(new_x_scores) and new_x_scores.max() > 0:
-            pop_member.attacked_text = transformations[new_x_scores.argmax()]
-            return True
-        return False
-
-    def _perturb(self, pop_member):
-        """
-        Replaces a word in pop_member that has not been modified.
-        Args:
-            pop_member: The population member being perturbed.
- """ - x_len = pop_member.neighbors_len.shape[0] - neighbors_len = deepcopy(pop_member.neighbors_len) - non_zero_indices = np.sum(np.sign(pop_member.neighbors_len)) + num_words = pop_member.num_candidates_per_word.shape[0] + num_candidates_per_word = np.copy(pop_member.num_candidates_per_word) + non_zero_indices = np.count_nonzero(num_candidates_per_word) if non_zero_indices == 0: return iterations = 0 - while iterations < non_zero_indices and not self.search_over: - w_select_probs = neighbors_len / np.sum(neighbors_len) - rand_idx = np.random.choice(x_len, 1, p=w_select_probs)[0] - if self._replace_at_index(pop_member, rand_idx): - pop_member.neighbors_len[rand_idx] = 0 + while iterations < non_zero_indices: + w_select_probs = num_candidates_per_word / np.sum(num_candidates_per_word) + rand_idx = np.random.choice(num_words, 1, p=w_select_probs)[0] + + transformations = self.get_transformations( + pop_member.attacked_text, + original_text=original_result.attacked_text, + indices_to_modify=[rand_idx], + ) + + if not len(transformations): + iterations += 1 + continue + + new_results, self._search_over = self.get_goal_results(transformations) + + if self._search_over: break - neighbors_len[rand_idx] = 0 + + diff_scores = ( + torch.Tensor([r.score for r in new_results]) - pop_member.result.score + ) + if len(diff_scores) and diff_scores.max() > 0: + idx = diff_scores.argmax() + pop_member.attacked_text = transformations[idx] + pop_member.num_candidates_per_word[rand_idx] = 0 + pop_member.results = new_results[idx] + break + + num_candidates_per_word[rand_idx] = 0 iterations += 1 - def _generate_population(self, neighbors_len, initial_result): - """ - Generates a population of texts each with one word replaced - Args: - neighbors_len: A list of the number of candidate neighbors for each word. - initial_result: The result to instantiate the population with - Returns: - The population. - """ - pop = [] - for _ in range(self.pop_size): - pop_member = PopulationMember( - self.original_attacked_text, deepcopy(neighbors_len), initial_result - ) - self._perturb(pop_member) - pop.append(pop_member) - return pop - - def _crossover(self, pop_member1, pop_member2): + def _crossover(self, pop_member1, pop_member2, original_result): """ Generates a crossover between pop_member1 and pop_member2. + If the child fails to satisfy the constraits, we re-try crossover for a fix number of times, + before taking one of the parents at random as the resulting child. Args: - pop_member1: The first population member. - pop_member2: The second population member. + pop_member1 (PopulationMember): The first population member. + pop_member2 (PopulationMember): The second population member. Returns: A population member containing the crossover. 
""" - indices_to_replace = [] - words_to_replace = [] x1_text = pop_member1.attacked_text - x2_words = pop_member2.attacked_text.words - new_neighbors_len = deepcopy(pop_member1.neighbors_len) - for i in range(len(x1_text.words)): - if np.random.uniform() < 0.5: - indices_to_replace.append(i) - words_to_replace.append(x2_words[i]) - new_neighbors_len[i] = pop_member2.neighbors_len[i] - new_text = x1_text.replace_words_at_indices( - indices_to_replace, words_to_replace - ) - return PopulationMember(new_text, deepcopy(new_neighbors_len)) + x2_text = pop_member2.attacked_text + x2_words = x2_text.words - def _get_neighbors_len(self, attacked_text): + num_tries = 0 + passed_constraints = False + while num_tries < self.max_crossover_retries + 1: + indices_to_replace = [] + words_to_replace = [] + num_candidates_per_word = np.copy(pop_member1.num_candidates_per_word) + for i in range(len(x1_text.words)): + if np.random.uniform() < 0.5: + indices_to_replace.append(i) + words_to_replace.append(x2_words[i]) + num_candidates_per_word[i] = pop_member2.num_candidates_per_word[i] + new_text = x1_text.replace_words_at_indices( + indices_to_replace, words_to_replace + ) + if "last_transformation" in x1_text.attack_attrs: + new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[ + "last_transformation" + ] + filtered = self.filter_transformations( + [new_text], x1_text, original_text=original_result.attacked_text + ) + elif "last_transformation" in x2_text.attack_attrs: + new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[ + "last_transformation" + ] + filtered = self.filter_transformations( + [new_text], x1_text, original_text=original_result.attacked_text + ) + else: + # In this case, neither x_1 nor x_2 has been transformed, + # meaning that new_text == original_text + filtered = [new_text] + + if filtered: + new_text = filtered[0] + passed_constraints = True + break + + num_tries += 1 + + if not passed_constraints: + # If we cannot find a child that passes the constraints, + # we just randomly pick one of the parents to be the child for the next iteration. + new_text = ( + pop_member1.attacked_text + if np.random.uniform() < 0.5 + else pop_member2.attacked_text + ) + + new_results, self._search_over = self.get_goal_results([new_text]) + + return PopulationMember(new_text, num_candidates_per_word, new_results[0]) + + def _initialize_population(self, initial_result): """ - Generates this neighbors_len list + Initialize a population of texts each with one word replaced Args: - attacked_text: The original text + initial_result (GoalFunctionResult): The result to instantiate the population with Returns: - A list of number of candidate neighbors for each word + The population. 
""" - words = attacked_text.words - neighbors_list = [[] for _ in range(len(words))] + words = initial_result.attacked_text.words + num_candidates_per_word = np.zeros(len(words)) transformations = self.get_transformations( - attacked_text, original_text=self.original_attacked_text + initial_result.attacked_text, original_text=initial_result.attacked_text ) for transformed_text in transformations: - diff_idx = attacked_text.first_word_diff_index(transformed_text) - neighbors_list[diff_idx].append(transformed_text.words[diff_idx]) - neighbors_list = [np.array(x) for x in neighbors_list] - neighbors_len = np.array([len(x) for x in neighbors_list]) - return neighbors_len + diff_idx = initial_result.attacked_text.first_word_diff_index( + transformed_text + ) + num_candidates_per_word[diff_idx] += 1 + + # Just b/c there are no candidates now doesn't mean we never want to select the word for perturbation + # Therefore, we give small non-zero probability for words with no candidates + # Epsilon is some small number to approximately assign 1% probability + num_total_candidates = np.sum(num_candidates_per_word) + epsilon = max(1, int(num_total_candidates * 0.01)) + for i in range(len(num_candidates_per_word)): + if num_candidates_per_word[i] == 0: + num_candidates_per_word[i] = epsilon + + population = [] + for _ in range(self.pop_size): + pop_member = PopulationMember( + initial_result.attacked_text, + np.copy(num_candidates_per_word), + initial_result, + ) + # Perturb `pop_member` in-place + self._perturb(pop_member, initial_result) + population.append(pop_member) + return population def _perform_search(self, initial_result): - self.original_attacked_text = initial_result.attacked_text - self.correct_output = initial_result.output - neighbors_len = self._get_neighbors_len(self.original_attacked_text) - pop = self._generate_population(neighbors_len, initial_result) - cur_score = initial_result.score + self._search_over = False + population = self._initialize_population(initial_result) + current_score = initial_result.score for i in range(self.max_iters): - pop_results, self.search_over = self.get_goal_results( - [pm.attacked_text for pm in pop], self.correct_output - ) - if self.search_over: - if not len(pop_results): - return pop[0].result - return max(pop_results, key=lambda x: x.score) - for idx, result in enumerate(pop_results): - pop[idx].result = pop_results[idx] - pop = sorted(pop, key=lambda x: -x.result.score) + population = sorted(population, key=lambda x: x.result.score, reverse=True) + if ( + self._search_over + or population[0].result.goal_status + == GoalFunctionResultStatus.SUCCEEDED + ): + break - pop_scores = torch.Tensor([r.score for r in pop_results]) - logits = ((-pop_scores) / self.temp).exp() - select_probs = (logits / logits.sum()).cpu().numpy() - - if pop[0].result.succeeded: - return pop[0].result - - if pop[0].result.score > cur_score: - cur_score = pop[0].result.score + if population[0].result.score > current_score: + current_score = population[0].result.score elif self.give_up_if_no_improvement: break - elite = [pop[0]] + pop_scores = torch.Tensor([pm.result.score for pm in population]) + logits = ((-pop_scores) / self.temp).exp() + select_probs = (logits / logits.sum()).cpu().numpy() + parent1_idx = np.random.choice( self.pop_size, size=self.pop_size - 1, p=select_probs ) @@ -187,16 +230,27 @@ class GeneticAlgorithm(SearchMethod): self.pop_size, size=self.pop_size - 1, p=select_probs ) - children = [ - self._crossover(pop[parent1_idx[idx]], pop[parent2_idx[idx]]) - for 
idx in range(self.pop_size - 1)
-        ]
-        for c in children:
-            self._perturb(c)
+        children = []
+        for idx in range(self.pop_size - 1):
+            child = self._crossover(
+                population[parent1_idx[idx]],
+                population[parent2_idx[idx]],
+                initial_result,
+            )
+            if self._search_over:
+                break
 
-        pop = elite + children
+            self._perturb(child, initial_result)
+            children.append(child)
 
-        return pop[0].result
+            # We need two `search_over` checks b/c the value might change in
+            # both the `_crossover` and `_perturb` methods.
+            if self._search_over:
+                break
+
+        population = [population[0]] + children
+
+        return population[0].result
 
     def check_transformation_compatibility(self, transformation):
         """
@@ -214,10 +268,10 @@ class PopulationMember:
 
     Args:
         attacked_text: The ``AttackedText`` of the population member.
-        neighbors_len: A list of the number of candidate neighbors list for each word.
+        num_candidates_per_word (numpy.array): The number of candidate replacement words for each word in the text.
     """
 
-    def __init__(self, attacked_text, neighbors_len, result=None):
+    def __init__(self, attacked_text, num_candidates_per_word, result):
         self.attacked_text = attacked_text
-        self.neighbors_len = neighbors_len
+        self.num_candidates_per_word = num_candidates_per_word
         self.result = result
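The parent-sampling step above negates the scores before exponentiating, which is easy to misread in the diff. Here is a small numpy sketch of the same computation (the scores are invented, and the patch itself uses `torch.Tensor`, but the arithmetic is identical):

```python
import numpy as np

# Hypothetical population scores from the goal function (not real data).
pop_scores = np.array([0.10, 0.40, 0.35, 0.90])
temp = 0.3

# Mirrors `logits = ((-pop_scores) / self.temp).exp()` in `_perform_search`,
# including the negation, followed by normalization into probabilities.
logits = np.exp(-pop_scores / temp)
select_probs = logits / logits.sum()

# Sample pop_size - 1 parents, as the diff does for parent1_idx / parent2_idx.
parent1_idx = np.random.choice(
    len(pop_scores), size=len(pop_scores) - 1, p=select_probs
)
print(select_probs.round(3), parent1_idx)
```

Lowering `temp` sharpens the distribution; raising it flattens sampling toward uniform, which matches the docstring's note about sensitivity.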
""" - leave_one_results, search_over = self.get_goal_results( - texts, initial_result.output - ) + leave_one_results, search_over = self.get_goal_results(texts) leave_one_scores = np.array([result.score for result in leave_one_results]) return leave_one_scores, search_over @@ -98,10 +102,7 @@ class GreedyWordSwapWIR(SearchMethod): search_over = False if self.wir_method != "random": - if self.ascending: - index_order = (leave_one_scores).argsort() - else: - index_order = (-leave_one_scores).argsort() + index_order = (-leave_one_scores).argsort() i = 0 results = None @@ -114,9 +115,7 @@ class GreedyWordSwapWIR(SearchMethod): i += 1 if len(transformed_text_candidates) == 0: continue - results, search_over = self.get_goal_results( - transformed_text_candidates, initial_result.output - ) + results, search_over = self.get_goal_results(transformed_text_candidates) results = sorted(results, key=lambda x: -x.score) # Skip swaps which don't improve the score if results[0].score > cur_result.score: @@ -124,12 +123,12 @@ class GreedyWordSwapWIR(SearchMethod): else: continue # If we succeeded, return the index with best similarity. - if cur_result.succeeded: + if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED: best_result = cur_result # @TODO: Use vectorwise operations max_similarity = -float("inf") for result in results: - if not result.succeeded: + if result.goal_status != GoalFunctionResultStatus.SUCCEEDED: break candidate = result.attacked_text try: @@ -149,9 +148,9 @@ class GreedyWordSwapWIR(SearchMethod): def check_transformation_compatibility(self, transformation): """ - Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swaps transformations. + Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swap and deletion transformations. 
""" - return transformation_consists_of_word_swaps(transformation) + return transformation_consists_of_word_swaps_and_deletions(transformation) def extra_repr_keys(self): return ["wir_method"] diff --git a/textattack/search_methods/search_method.py b/textattack/search_methods/search_method.py index a05e9c21..0167bbd8 100644 --- a/textattack/search_methods/search_method.py +++ b/textattack/search_methods/search_method.py @@ -22,6 +22,10 @@ class SearchMethod(ABC): raise AttributeError( "Search Method must have access to get_goal_results method" ) + if not hasattr(self, "filter_transformations"): + raise AttributeError( + "Search Method must have access to filter_transformations method" + ) return self._perform_search(initial_result) @abstractmethod diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py index 4137b021..1c599a38 100644 --- a/textattack/shared/attack.py +++ b/textattack/shared/attack.py @@ -7,9 +7,11 @@ import numpy as np import textattack from textattack.attack_results import ( FailedAttackResult, + MaximizedAttackResult, SkippedAttackResult, SuccessfulAttackResult, ) +from textattack.goal_function_results import GoalFunctionResultStatus from textattack.shared import AttackedText, utils @@ -56,7 +58,7 @@ class Attack: self.transformation ): raise ValueError( - "SearchMethod {self.search_method} incompatible with transformation {self.transformation}" + f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}" ) self.constraints = [] @@ -74,7 +76,12 @@ class Attack: # Give search method access to functions for getting transformations and evaluating them self.search_method.get_transformations = self.get_transformations - self.search_method.get_goal_results = self.goal_function.get_results + # The search method only needs access to the first argument. The second is only used + # by the attack class when checking whether to skip the sample + self.search_method.get_goal_results = lambda attacked_text_list: self.goal_function.get_results( + attacked_text_list + ) + self.search_method.filter_transformations = self.filter_transformations def get_transformations(self, current_text, original_text=None, **kwargs): """ @@ -102,7 +109,7 @@ class Attack: **kwargs, ) ) - return self._filter_transformations( + return self.filter_transformations( transformed_texts, current_text, original_text ) @@ -138,7 +145,7 @@ class Attack: self.constraints_cache[(current_text, filtered_text)] = True return filtered_texts - def _filter_transformations( + def filter_transformations( self, transformed_texts, current_text, original_text=None ): """ @@ -180,17 +187,18 @@ class Attack: initial_result: The initial ``GoalFunctionResult`` from which to perturb. Returns: - Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``. + A ``SuccessfulAttackResult``, ``FailedAttackResult``, + or ``MaximizedAttackResult``. 
""" final_result = self.search_method(initial_result) - if final_result.succeeded: - return SuccessfulAttackResult( - initial_result, final_result, self.goal_function.num_queries - ) + if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED: + return SuccessfulAttackResult(initial_result, final_result,) + elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING: + return FailedAttackResult(initial_result, final_result,) + elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING: + return MaximizedAttackResult(initial_result, final_result,) else: - return FailedAttackResult( - initial_result, final_result, self.goal_function.num_queries - ) + raise ValueError(f"Unrecognized goal status {final_result.goal_status}") def _get_examples_from_dataset(self, dataset, indices=None): """ @@ -222,14 +230,9 @@ class Attack: attacked_text = AttackedText( text, attack_attrs={"label_names": label_names} ) - self.goal_function.num_queries = 0 - goal_function_result, _ = self.goal_function.get_result( + goal_function_result, _ = self.goal_function.init_attack_example( attacked_text, ground_truth_output ) - if goal_function_result.succeeded: - # Store the true output on the goal function so that the - # SkippedAttackResult has the correct output, not the incorrect. - goal_function_result.output = ground_truth_output yield goal_function_result except IndexError: @@ -250,7 +253,7 @@ class Attack: examples = self._get_examples_from_dataset(dataset, indices=indices) for goal_function_result in examples: - if goal_function_result.succeeded: + if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED: yield SkippedAttackResult(goal_function_result) else: result = self.attack_one(goal_function_result) diff --git a/textattack/shared/attacked_text.py b/textattack/shared/attacked_text.py index 1c5b3649..c8a824c4 100644 --- a/textattack/shared/attacked_text.py +++ b/textattack/shared/attacked_text.py @@ -42,9 +42,11 @@ class AttackedText: raise TypeError( f"Invalid text_input type {type(text_input)} (required str or OrderedDict)" ) + # Find words in input lazily. + self._words = None + self._words_per_input = None # Format text inputs. self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()]) - self.words = words_from_text(self.text) if attack_attrs is None: self.attack_attrs = dict() elif isinstance(attack_attrs, dict): @@ -53,7 +55,7 @@ class AttackedText: raise TypeError(f"Invalid type for attack_attrs: {type(attack_attrs)}") # Indices of words from the *original* text. Allows us to map # indices between original text and this text, and vice-versa. - self.attack_attrs.setdefault("original_index_map", np.arange(len(self.words))) + self.attack_attrs.setdefault("original_index_map", np.arange(self.num_words)) # A list of all indices in *this* text that have been modified. self.attack_attrs.setdefault("modified_indices", set()) @@ -97,7 +99,7 @@ class AttackedText: def text_window_around_index(self, index, window_size): """ The text window of ``window_size`` words centered around ``index``. """ - length = len(self.words) + length = self.num_words half_size = (window_size - 1) / 2.0 if index - half_size < 0: start = 0 @@ -177,7 +179,7 @@ class AttackedText: """ Takes indices of words from original string and converts them to indices of the same words in the current string. 
- Uses information from ``self.attack_attrs['original_index_map'], + Uses information from ``self.attack_attrs['original_index_map']``, which maps word indices from the original to perturbed text. """ if len(self.attack_attrs["original_index_map"]) == 0: @@ -344,6 +346,29 @@ class AttackedText: """ The tuple of inputs to be passed to the tokenizer. """ return tuple(self._text_input.values()) + @property + def column_labels(self): + """ Returns the labels for this text's columns. For single-sequence + inputs, this simply returns ['text']. + """ + return list(self._text_input.keys()) + + @property + def words_per_input(self): + """ Returns a list of lists of words corresponding to each input. + """ + if not self._words_per_input: + self._words_per_input = [ + words_from_text(_input) for _input in self._text_input.values() + ] + return self._words_per_input + + @property + def words(self): + if not self._words: + self._words = words_from_text(self.text) + return self._words + @property def text(self): """ Represents full text input. Multiply inputs are joined with a line @@ -351,6 +376,11 @@ class AttackedText: """ return "\n".join(self._text_input.values()) + @property + def num_words(self): + """ Returns the number of words in the sequence. """ + return len(self.words) + def printable_text(self, key_color="bold", key_color_method=None): """ Represents full text input. Adds field descriptions. diff --git a/textattack/shared/checkpoint.py b/textattack/shared/checkpoint.py index f2d68a42..338c3dc7 100644 --- a/textattack/shared/checkpoint.py +++ b/textattack/shared/checkpoint.py @@ -6,6 +6,7 @@ import time from textattack.attack_results import ( FailedAttackResult, + MaximizedAttackResult, SkippedAttackResult, SuccessfulAttackResult, ) @@ -101,6 +102,11 @@ class Checkpoint: f"(Number of failed attacks): {self.num_failed_attacks}", 2 ) ) + breakdown_lines.append( + utils.add_indent( + f"(Number of maximized attacks): {self.num_maximized_attacks}", 2 + ) + ) breakdown_lines.append( utils.add_indent( f"(Number of skipped attacks): {self.num_skipped_attacks}", 2 @@ -140,6 +146,12 @@ class Checkpoint: isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results ) + @property + def num_maximized_attacks(self): + return sum( + isinstance(r, MaximizedAttackResult) for r in self.log_manager.results + ) + @property def num_remaining_attacks(self): if self.args.attack_n: diff --git a/textattack/shared/utils/install.py b/textattack/shared/utils/install.py index 966d0ff6..3ef17311 100644 --- a/textattack/shared/utils/install.py +++ b/textattack/shared/utils/install.py @@ -11,6 +11,9 @@ import requests import torch import tqdm +# Hide an error message from `tokenizers` if this process is forked. +os.environ["TOKENIZERS_PARALLELISM"] = "True" + def path_in_cache(file_path): try: diff --git a/textattack/shared/utils/misc.py b/textattack/shared/utils/misc.py index 7fb19a86..415f8265 100644 --- a/textattack/shared/utils/misc.py +++ b/textattack/shared/utils/misc.py @@ -24,7 +24,7 @@ def html_style_from_dict(style_dict): def html_table_from_rows(rows, title=None, header=None, style_dict=None): # Stylize the container div. if style_dict: - table_html = "
".format(style_from_dict(style_dict)) + table_html = "
".format(html_style_from_dict(style_dict)) else: table_html = "
" # Print the title string. @@ -78,6 +78,8 @@ def load_textattack_model_from_path(model_name, model_path): model = textattack.models.helpers.BERTForClassification( model_path=model_path, num_labels=num_labels ) + elif model_name.startswith("t5"): + model = textattack.models.helpers.T5ForTextToText(model_path) else: raise ValueError(f"Unknown textattack model {model_path}") return model diff --git a/textattack/shared/utils/strings.py b/textattack/shared/utils/strings.py index 359fedcd..9466641a 100644 --- a/textattack/shared/utils/strings.py +++ b/textattack/shared/utils/strings.py @@ -1,11 +1,17 @@ def has_letter(word): - """ Returns true if `word` contains at least one character in [A-Za-z]. """ + """ Returns true if `word` contains at least one character in [A-Za-z]. + """ + # TODO implement w regex for c in word: if c.isalpha(): return True return False +def is_one_word(word): + return len(words_from_text(word)) == 1 + + def add_indent(s_, numSpaces): s = s_.split("\n") # don't do anything for single-line stuff @@ -21,10 +27,15 @@ def add_indent(s_, numSpaces): def words_from_text(s, words_to_ignore=[]): """ Lowercases a string, removes all non-alphanumeric characters, and splits into words. """ + # TODO implement w regex words = [] word = "" for c in " ".join(s.split()): - if c.isalpha(): + if c.isalnum(): + word += c + elif c in "'-" and len(word) > 0: + # Allow apostrophes and hyphens as long as they don't begin the + # word. word += c elif word: if word not in words_to_ignore: diff --git a/textattack/shared/validators.py b/textattack/shared/validators.py index 0347be1e..612c3ff7 100644 --- a/textattack/shared/validators.py +++ b/textattack/shared/validators.py @@ -8,7 +8,7 @@ from . import logger # A list of goal functions and the corresponding available models. MODELS_BY_GOAL_FUNCTIONS = { - (TargetedClassification, UntargetedClassification): [ + (TargetedClassification, UntargetedClassification, InputReduction): [ r"^textattack.models.classification.*", r"^textattack.models.entailment.*", r"^transformers.modeling_\w*\.\w*ForSequenceClassification$", @@ -108,3 +108,14 @@ def transformation_consists_of_word_swaps(transformation): from textattack.transformations import WordSwap, WordSwapGradientBased return transformation_consists_of(transformation, [WordSwap, WordSwapGradientBased]) + + +def transformation_consists_of_word_swaps_and_deletions(transformation): + """ + Determines if ``transformation`` is a word swap or consists of only word swaps and deletions. + """ + from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased + + return transformation_consists_of( + transformation, [WordDeletion, WordSwap, WordSwapGradientBased] + ) diff --git a/textattack/shared/word_embedding.py b/textattack/shared/word_embedding.py index 2c34a392..cdd553a4 100644 --- a/textattack/shared/word_embedding.py +++ b/textattack/shared/word_embedding.py @@ -24,7 +24,7 @@ class WordEmbedding: mse_dist_file = "mse_dist.p" cos_sim_file = "cos_sim.p" else: - raise ValueError(f"Could not find word embedding {word_embedding}") + raise ValueError(f"Could not find word embedding {embedding_type}") # Download embeddings if they're not cached. 
diff --git a/textattack/transformations/__init__.py b/textattack/transformations/__init__.py
index e9c20c9c..368c2aa2 100644
--- a/textattack/transformations/__init__.py
+++ b/textattack/transformations/__init__.py
@@ -5,6 +5,7 @@ from .word_swap import WordSwap
 # Black-box transformations
 from .word_deletion import WordDeletion
 from .word_swap_embedding import WordSwapEmbedding
+from .word_swap_hownet import WordSwapHowNet
 from .word_swap_homoglyph_swap import WordSwapHomoglyphSwap
 from .word_swap_neighboring_character_swap import WordSwapNeighboringCharacterSwap
 from .word_swap_random_character_deletion import WordSwapRandomCharacterDeletion
diff --git a/textattack/transformations/composite_transformation.py b/textattack/transformations/composite_transformation.py
index f0030627..fda8a089 100644
--- a/textattack/transformations/composite_transformation.py
+++ b/textattack/transformations/composite_transformation.py
@@ -21,6 +21,14 @@ class CompositeTransformation(Transformation):
             raise ValueError("transformations cannot be empty")
         self.transformations = transformations
 
+    def _get_transformations(self, *_):
+        """ Placeholder method that throws an error if a user tries to
+        treat the CompositeTransformation as a 'normal' transformation.
+        """
+        raise RuntimeError(
+            "CompositeTransformation does not support _get_transformations()."
+        )
+
     def __call__(self, *args, **kwargs):
         new_attacked_texts = set()
         for transformation in self.transformations:
diff --git a/textattack/transformations/transformation.py b/textattack/transformations/transformation.py
index 087708aa..b756a268 100644
--- a/textattack/transformations/transformation.py
+++ b/textattack/transformations/transformation.py
@@ -1,7 +1,9 @@
+from abc import ABC, abstractmethod
+
 from textattack.shared.utils import default_class_repr
 
 
-class Transformation:
+class Transformation(ABC):
     """
     An abstract class for transforming a sequence of text to produce
     a potential adversarial example.
@@ -44,6 +46,7 @@ class Transformation:
         text.attack_attrs["last_transformation"] = self
         return transformed_texts
 
+    @abstractmethod
     def _get_transformations(self, current_text, indices_to_modify):
         """
         Returns a list of all possible transformations for ``current_text``, only modifying
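With `Transformation` now an abstract base class, a subclass that forgets to implement `_get_transformations` fails loudly at instantiation instead of silently at attack time. A toy subclass satisfying the new contract (hypothetical example, not part of the patch; `replace_word_at_index` is the `AttackedText` helper used throughout this diff):

```python
from textattack.transformations import Transformation

class WordSwapBanana(Transformation):
    """Hypothetical transformation that swaps every candidate word for 'banana'."""

    def _get_transformations(self, current_text, indices_to_modify):
        # Return one transformed AttackedText per index we are allowed to modify.
        return [
            current_text.replace_word_at_index(i, "banana")
            for i in indices_to_modify
        ]

# A subclass without _get_transformations now raises at instantiation:
# TypeError: Can't instantiate abstract class ... with abstract method
# _get_transformations
```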
diff --git a/textattack/transformations/word_swap_hownet.py b/textattack/transformations/word_swap_hownet.py
new file mode 100644
index 00000000..7a0b4bdc
--- /dev/null
+++ b/textattack/transformations/word_swap_hownet.py
@@ -0,0 +1,106 @@
+import pickle
+
+from flair.data import Sentence
+from flair.models import SequenceTagger
+
+from textattack.shared import utils
+from textattack.transformations.word_swap import WordSwap
+
+
+class WordSwapHowNet(WordSwap):
+    """ Transforms an input by replacing its words with synonyms from a stored
+    synonym bank generated by OpenHowNet. """
+
+    PATH = "transformations/hownet"
+
+    def __init__(self, max_candidates=-1, **kwargs):
+        super().__init__(**kwargs)
+        self.max_candidates = max_candidates
+
+        # Download the synonym candidate bank if it's not cached.
+        cache_path = utils.download_if_needed(
+            "{}/{}".format(WordSwapHowNet.PATH, "word_candidates_sense.pkl")
+        )
+
+        # Actually load the files from disk.
+        with open(cache_path, "rb") as fp:
+            self.candidates_bank = pickle.load(fp)
+
+        self._flair_pos_tagger = SequenceTagger.load("pos-fast")
+        self.pos_dict = {"JJ": "adj", "NN": "noun", "RB": "adv", "VB": "verb"}
+
+    def _get_replacement_words(self, word, word_pos):
+        """ Returns a list of candidate words to replace `word`, drawn from the
+        HowNet synonym bank and filtered by part of speech.
+        """
+        word_pos = self.pos_dict.get(word_pos, None)
+        if word_pos is None:
+            return []
+
+        try:
+            candidate_words = self.candidates_bank[word.lower()][word_pos]
+            if self.max_candidates > 0:
+                candidate_words = candidate_words[: self.max_candidates]
+            return [
+                recover_word_case(candidate_word, word)
+                for candidate_word in candidate_words
+            ]
+        except KeyError:
+            # This word is not in our synonym bank, so return an empty list.
+            return []
+
+    def _get_transformations(self, current_text, indices_to_modify):
+        words = current_text.words
+        words_str = " ".join(words)
+        word_list, pos_list = zip_flair_result(
+            self._flair_pos_tagger.predict(words_str)[0]
+        )
+        assert len(words) == len(
+            word_list
+        ), "Part-of-speech tagger returned incorrect number of tags"
+        transformed_texts = []
+
+        for i in indices_to_modify:
+            word_to_replace = words[i]
+            word_to_replace_pos = pos_list[i][:2]  # get the root POS
+            replacement_words = self._get_replacement_words(
+                word_to_replace, word_to_replace_pos
+            )
+            transformed_texts_idx = []
+            for r in replacement_words:
+                transformed_texts_idx.append(current_text.replace_word_at_index(i, r))
+            transformed_texts.extend(transformed_texts_idx)
+
+        return transformed_texts
+
+    def extra_repr_keys(self):
+        return ["max_candidates"]
+
+
+def recover_word_case(word, reference_word):
+    """ Makes the case of `word` like the case of `reference_word`. Supports
+    lowercase, UPPERCASE, and Capitalized.
""" + if reference_word.islower(): + return word.lower() + elif reference_word.isupper() and len(reference_word) > 1: + return word.upper() + elif reference_word[0].isupper() and reference_word[1:].islower(): + return word.capitalize() + else: + # if other, just do not alter the word's case + return word + + +def zip_flair_result(pred): + """Parse the output from the FLAIR POS tagger""" + if not isinstance(pred, Sentence): + raise TypeError(f"Result from Flair POS tagger must be a `Sentence` object.") + + tokens = pred.tokens + word_list = [] + pos_list = [] + for token in tokens: + word_list.append(token.text) + pos_list.append(token.annotation_layers["pos"][0]._value) + + return word_list, pos_list diff --git a/textattack/transformations/word_swap_masked_lm.py b/textattack/transformations/word_swap_masked_lm.py index 5e8fc2b6..d2c91047 100644 --- a/textattack/transformations/word_swap_masked_lm.py +++ b/textattack/transformations/word_swap_masked_lm.py @@ -58,7 +58,8 @@ class WordSwapMaskedLM(WordSwap): encoding = self._lm_tokenizer.encode_plus( text, max_length=self.max_length, - pad_to_max_length=True, + truncation=True, + padding="max_length", return_tensors="pt", ) return {k: v.to(utils.device) for k, v in encoding.items()} @@ -93,7 +94,7 @@ class WordSwapMaskedLM(WordSwap): replacement_words = [] for id in top_ids: token = self._lm_tokenizer.convert_ids_to_tokens(id) - if check_if_word(token): + if utils.is_one_word(token): replacement_words.append(token) return replacement_words @@ -140,7 +141,7 @@ class WordSwapMaskedLM(WordSwap): replacement_words = [] for id in top_preds: token = self._lm_tokenizer.convert_ids_to_tokens(id) - if check_if_word(token): + if utils.is_one_word(token): replacement_words.append(token) return replacement_words else: @@ -162,7 +163,7 @@ class WordSwapMaskedLM(WordSwap): word = "".join( self._lm_tokenizer.convert_ids_to_tokens(word_tensor) ).replace("##", "") - if check_if_word(word): + if utils.is_one_word(word): combination_results.append((word, perplexity)) # Sort to get top-K results sorted(combination_results, key=lambda x: x[1]) @@ -228,10 +229,3 @@ def recover_word_case(word, reference_word): else: # if other, just do not alter the word's case return word - - -def check_if_word(word): - for c in word: - if not c.isalpha(): - return False - return True diff --git a/textattack/transformations/word_swap_wordnet.py b/textattack/transformations/word_swap_wordnet.py index 38ddcadb..b5bdd17c 100644 --- a/textattack/transformations/word_swap_wordnet.py +++ b/textattack/transformations/word_swap_wordnet.py @@ -1,5 +1,6 @@ from nltk.corpus import wordnet +import textattack from textattack.transformations.word_swap import WordSwap @@ -13,14 +14,12 @@ class WordSwapWordNet(WordSwap): synonyms = set() for syn in wordnet.synsets(word): for l in syn.lemmas(): - if l.name() != word and check_if_one_word(l.name()): + syn_word = l.name() + if ( + (syn_word != word) + and ("_" not in syn_word) + and (textattack.shared.utils.is_one_word(syn_word)) + ): # WordNet can suggest phrases that are joined by '_' but we ignore phrases. - synonyms.add(l.name()) + synonyms.add(syn_word) return list(synonyms) - - -def check_if_one_word(word): - for c in word: - if not c.isalpha(): - return False - return True