mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
@@ -69,7 +69,7 @@ or a specific command using, for example,

textattack attack --help
```

The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The[documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint..
The [`examples/`](examples/) folder includes scripts showing common TextAttack usage for training models, running attacks, and augmenting a CSV file. The [documentation website](https://textattack.readthedocs.io/en/latest) contains walkthroughs explaining basic usage of TextAttack, including building a custom transformation and a custom constraint.

### Running Attacks
@@ -127,7 +127,7 @@ Attacks on sequence-to-sequence models:

#### Recipe Usage Examples

Here are some exampes of testing attacks from the literature from the command-line:
Here are some examples of testing attacks from the literature from the command-line:

*TextFooler against BERT fine-tuned on SST-2:*
```bash
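# The command itself is truncated in this diff view. A hedged sketch of a
# TextFooler run against TextAttack's pre-trained BERT SST-2 model; the model
# name `bert-base-uncased-sst2` is an assumption about the pre-trained model naming.
textattack attack --recipe textfooler --model bert-base-uncased-sst2 --num-examples 100
```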
docs/_static/css/custom.css (vendored, new file, 26 lines)
@@ -0,0 +1,26 @@
div.wy-side-nav-search .version {
    color: #404040;
    font-weight: bold;
}

nav.wy-nav-top {
    background: #AA2396;
}

div.wy-nav-content {
    max-width: 1000px;
}

span.caption-text {
    color: #cc4878;
}

/* Change header fonts to Cambria */
.rst-content .toctree-wrapper>p.caption, h1, h2, h3, h4, h5, h6, legend {
    font-family: 'Cambria', serif;
}

/* Change non-header default fonts to Helvetica */
/** {
    font-family: 'Helvetica', sans-serif;
}*/
BIN docs/_static/imgs/intro/ae_papers.png (vendored, new file; binary not shown; 16 KiB)
BIN docs/_static/imgs/intro/mr_aes.png (vendored, new file; binary not shown; 65 KiB)
BIN docs/_static/imgs/intro/mr_aes_table.png (vendored, new file; binary not shown; 26 KiB)
BIN docs/_static/imgs/intro/pig_airliner.png (vendored, new file; binary not shown; 211 KiB)
BIN docs/_static/imgs/intro/textattack_components.png (vendored, new file; binary not shown; 123 KiB)
docs/conf.py (18 lines changed)
@@ -22,7 +22,7 @@ copyright = "2020, UVA QData Lab"
author = "UVA QData Lab"

# The full version, including alpha/beta/rc tags
release = "0.2.8"
release = "0.2.9"

# Set master doc to `index.rst`.
master_doc = "index"
@@ -54,6 +54,14 @@ exclude_patterns = ["_build", "**.ipynb_checkpoints"]
# Mock expensive textattack imports. Docs imports are in `docs/requirements.txt`.
autodoc_mock_imports = []

# Output file base name for HTML help builder.
htmlhelp_basename = "textattack_doc"
html_theme_options = {
    "logo_only": False,
    "style_nav_header_background": "transparent",
    "analytics_id": "UA-88637452-2",
}

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
@@ -61,10 +69,10 @@ autodoc_mock_imports = []
#
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []
html_static_path = ["_static"]
html_css_files = [
    "css/custom.css",
]

# Path to favicon.
html_favicon = "favicon.png"
@@ -56,12 +56,28 @@ TextAttack has some other features that make it a pleasure to use:
   :hidden:
   :caption: Getting Started

   Installation <quickstart/installation>
   Command-Line Usage <quickstart/command_line_usage>
   What is an adversarial attack in NLP? <quickstart/what_is_an_adversarial_attack.md>

.. toctree::
   :maxdepth: 1
   :hidden:
   :caption: Tutorials

   Tutorial 0: TextAttack End-To-End (Train, Eval, Attack) <examples/0_End_to_End.ipynb>
   Tutorial 1: Transformations <examples/1_Introduction_and_Transformations.ipynb>
   Tutorial 2: Constraints <examples/2_Constraints.ipynb>

.. toctree::
   :maxdepth: 1
   :hidden:
   :caption: Models

   datasets_models/models
   Example: Attacking TensorFlow models <datasets_models/Example_0_tensorflow>
   Example: Attacking scikit-learn models <datasets_models/Example_1_sklearn.ipynb>
   Example: Attacking AllenNLP models <datasets_models/Example_2_allennlp.ipynb>

.. toctree::
   :maxdepth: 1
@@ -73,7 +89,7 @@ TextAttack has some other features that make it a pleasure to use:
.. toctree::
   :maxdepth: 3
   :hidden:
   :caption: NLP Attacks
   :caption: Attacks

   attacks/attack
   attacks/attack_result
@@ -91,16 +107,6 @@ TextAttack has some other features that make it a pleasure to use:

   augmentation/augmenter

.. toctree::
   :maxdepth: 1
   :hidden:
   :caption: Models and Tokenizers

   datasets_models/models
   Example: Attacking TensorFlow models <datasets_models/Example_0_tensorflow>
   Example: Attacking scikit-learn models <datasets_models/Example_1_sklearn.ipynb>
   Example: Attacking AllenNLP models <datasets_models/Example_2_allennlp.ipynb>

.. toctree::
   :maxdepth: 3
   :hidden:
@@ -180,14 +180,12 @@
    "\n",
    "# Create the recipe: PWWS uses a WordNet transformation.\n",
    "recipe = PWWSRen2019.build(model_wrapper)\n",
    "# WordNet defaults to english. Set the default language to French ('fra')\n",
    "recipe.transformation.language = 'fra'\n",
    "#\n",
    "# See \n",
    "# \"Building a free French wordnet from multilingual resources\", \n",
    "# WordNet defaults to english. Set the default language to French ('fra')\n",
    "#\n",
    "# See \"Building a free French wordnet from multilingual resources\", \n",
    "# E. L. R. A. (ELRA) (ed.), \n",
    "# Proceedings of the Sixth International Language Resources and Evaluation (LREC’08).\n",
    "\n",
    "recipe.transformation.language = 'fra'\n",
    "\n",
    "dataset = HuggingFaceNlpDataset('allocine', split='test')\n",
docs/quickstart/what_is_an_adversarial_attack.md (new file, 102 lines)
@@ -0,0 +1,102 @@
# What is an adversarial attack in NLP?

*This documentation page was adapted from [a blog post we wrote about adversarial examples in NLP](https://towardsdatascience.com/what-are-adversarial-examples-in-nlp-f928c574478e).*

This page is intended to clear up the terminology around the term ‘adversarial attack’ in natural language processing. We give an introduction to NLP adversarial attacks, clear up some of the scholarly jargon, and provide a high-level overview of the uses of TextAttack.

This article discusses the concept of adversarial examples as applied to NLP (natural language processing). The terminology can be confusing at times, so we’ll begin with an overview of the language used to talk about adversarial examples and adversarial attacks. Then, we’ll talk about TextAttack, an open-source Python library for adversarial examples, data augmentation, and adversarial training in NLP that’s changing the way people research the robustness of NLP models. We’ll conclude with some thoughts on the future of this area of research.

An adversarial example is an input designed to fool a machine learning model [1]. In TextAttack, we are concerned with adversarial perturbations: changes to benign inputs that cause them to be misclassified by models. ‘Adversarial perturbation’ is more specific than just ‘adversarial example’, as the class of all adversarial examples also includes inputs designed from scratch to fool machine learning models. TextAttack attacks generate a specific kind of adversarial example: adversarial perturbations.

As alluded to above, an adversarial attack on a machine learning model is a process for generating adversarial perturbations. TextAttack attacks iterate through a dataset (a list of inputs to a model) and, for each correctly predicted sample, search for an adversarial perturbation (we’ll talk more about this later). If an example is incorrectly predicted to begin with, it is not attacked, since the input already fools the model. TextAttack breaks the attack process up into stages and provides a [system of interchangeable components](/examples/1_Introduction_and_Transformations.ipynb) for managing each stage of the attack.
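To make this concrete, here is a minimal, hedged sketch of running a pre-built recipe over a dataset with the Python API. The recipe, dataset, and wrapper classes mirror the usage that appears elsewhere in this diff (`PWWSRen2019.build(model_wrapper)`, `HuggingFaceNlpDataset`); the hub model name and the `attack_dataset` iteration are assumptions about the 0.2.x API rather than a prescribed invocation.

```python
import transformers

from textattack.attack_recipes import TextFoolerJin2019
from textattack.datasets import HuggingFaceNlpDataset
from textattack.models.wrappers import HuggingFaceModelWrapper

# Wrap a fine-tuned classifier so TextAttack can query it for predictions.
# (The hub model name below is a placeholder for any sequence classifier.)
model_name = "textattack/bert-base-uncased-rotten-tomatoes"
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

# A recipe bundles a goal function, constraints, a transformation, and a search
# method; each piece can be swapped out independently.
attack = TextFoolerJin2019.build(model_wrapper)

# Iterate over a dataset: correctly predicted inputs are attacked, the rest skipped.
dataset = HuggingFaceNlpDataset("rotten_tomatoes", split="test")
for i, result in enumerate(attack.attack_dataset(dataset)):
    print(result.__str__(color_method="ansi"))
    if i >= 9:  # look at the first ten results only
        break
```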
Adversarial robustness is a measurement of a model’s susceptibility to adversarial examples. TextAttack often measures robustness using attack success rate, the percentage of attack attempts that produce successful adversarial examples, or after-attack accuracy, the percentage of inputs that are both correctly classified and unsuccessfully attacked.

To improve our numeracy when talking about adversarial attacks, let’s take a look at a concrete example of some attack results:



*These results come from using TextAttack to run the DeepWordBug attack on an LSTM trained on the Rotten Tomatoes Movie Review sentiment classification dataset, using 200 total examples.*

This attack was run on 200 examples. Out of those 200, the model initially predicted 43 of them incorrectly; this leads to an accuracy of 157/200 or 78.5%. TextAttack ran the adversarial attack process on the remaining 157 examples to try to find a valid adversarial perturbation for each one. Out of those 157, 29 attacks failed, leading to a success rate of 128/157 or 81.5%. Another way to articulate this is that the model correctly predicted and resisted attacks for 29 out of 200 total samples, leading to an accuracy under attack (or “after-attack accuracy”) of 29/200 or 14.5%.

TextAttack also logged some other helpful statistics for this attack. Among the successful attacks, on average, the attack changed 15.5% of words to alter the prediction and made 32.7 queries to find a successful perturbation. Across all 200 inputs, the average number of words was 18.97.
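As a quick numeracy check, the headline numbers above follow directly from three counts: total examples, initially misclassified examples, and failed attacks. A minimal sketch of the arithmetic:

```python
total = 200            # examples in the run
wrong_initially = 43   # examples the model got wrong before any attack
failed_attacks = 29    # correctly predicted examples the attack could not flip

attempted = total - wrong_initially      # 157 examples actually attacked
successes = attempted - failed_attacks   # 128 successful perturbations

original_accuracy = attempted / total            # 157/200 = 78.5%
attack_success_rate = successes / attempted      # 128/157 = 81.5%
after_attack_accuracy = failed_attacks / total   # 29/200  = 14.5%
print(f"{original_accuracy:.1%} {attack_success_rate:.1%} {after_attack_accuracy:.1%}")
```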
Now that we have covered some terminology, let’s look at some concrete examples of proposed adversarial attacks. We will give some background on adversarial attacks in other domains and then examples of different attacks in NLP.

## Background

Research in 2013 [2] showed that neural networks are vulnerable to adversarial examples. These original adversarial attacks apply a small, well-chosen perturbation to an image to fool an image classifier. In the example below, the classifier correctly predicts the original image to be a pig. After a small perturbation, however, the classifier predicts the pig to be an airliner (with extremely high confidence!).



*An adversarial example for an ImageNet classifier. Superimposing a tiny (but deliberate) amount of noise causes the model to classify this pig as an airliner.*


These adversarial examples expose a serious security flaw in deep neural networks, and they therefore pose a problem for downstream systems that include neural networks, including text-to-speech systems and self-driving cars. Adversarial examples are also useful outside of security: researchers have used them to improve and interpret deep learning models.

As you might imagine, adversarial examples in deep neural networks have caught the attention of many researchers around the world, and this 2013 paper spawned an explosion of research into the topic.



<br>
*The number of papers related to ‘adversarial examples’ on arxiv.org between 2014 and 2020. [Graph from https://nicholas.carlini.com/writing/2019/all-adversarial-example-papers.html]*


Many new, more sophisticated adversarial attacks have been proposed, along with “defenses”: procedures for training neural networks that are resistant (“robust”) to adversarial attacks. Training deep neural networks that are highly accurate while remaining robust to adversarial attacks remains an open problem [3].

Naturally, many have wondered what adversarial examples for NLP models might look like. No natural analogy to the adversarial examples in computer vision (like the pig-to-airliner bamboozle above) exists for NLP. After all, two sequences of text cannot be truly indistinguishable without being the same. (In the above example, the pig-classified input and its airliner-classified perturbation are literally indistinguishable to the human eye.)
## Adversarial Examples in NLP



*Two different ideas of adversarial examples in NLP. These results were generated using TextAttack on an LSTM trained on the Rotten Tomatoes Movie Review sentiment classification dataset. These are real adversarial examples, generated using the DeepWordBug and TextFooler attacks. To generate them yourself, after installing TextAttack, run `textattack attack --model lstm-mr --num-examples 1 --recipe RECIPE --num-examples-offset 19`, where RECIPE is `deepwordbug` or `textfooler`.*

Because two text sequences are never indistinguishable, researchers have proposed various alternative definitions of adversarial examples in NLP. We find it useful to group adversarial attacks based on their chosen definitions of adversarial examples.

Although attacks in NLP cannot find an adversarial perturbation that is literally indistinguishable from the original input, they can find a perturbation that is very similar. Our mental model groups NLP adversarial attacks into two categories, based on their notions of ‘similarity’:


**Visual similarity.** Some NLP attacks consider an adversarial example to be a text sequence that looks very similar to the original input -- perhaps just a few character changes away -- but receives a different prediction from the model. Some of these adversarial attacks try to change as few characters as possible to change the model’s prediction; others try to introduce realistic ‘typos’ similar to those that humans would make.

Some researchers have raised concerns that these attacks can be defended against quite effectively, either by using a rule-based spellchecker or a sequence-to-sequence model trained to correct adversarial typos.

TextAttack attack recipes that fall under this category: deepwordbug, hotflip, pruthi, textbugger\*, morpheus
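A hedged Python counterpart to the CLI command in the caption above, for one recipe in this family (`model_wrapper` is assumed to be a wrapped classifier as in the earlier sketch; the class name follows TextAttack's `attack_recipes` naming):

```python
from textattack.attack_recipes import DeepWordBugGao2018

# Character-level recipe: small, typo-like edits that flip the model's prediction.
attack = DeepWordBugGao2018.build(model_wrapper)
```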
**Semantic similarity.** Other NLP attacks consider an adversarial example valid if it is semantically indistinguishable from the original input. In other words, if the perturbation is a paraphrase of the original input, but the input and the perturbation receive different predictions, then the perturbation is a valid adversarial example.

Some NLP models are trained to measure semantic similarity. Adversarial attacks based on the notion of semantic indistinguishability typically use another NLP model to enforce that perturbations are grammatically valid and semantically similar to the original input.

TextAttack attack recipes that fall under this category: alzantot, bae, bert-attack, faster-alzantot, iga, kuleshov, pso, pwws, textbugger\*, textfooler

\*The textbugger attack generates perturbations using both typo-like character edits and synonym substitutions. It could be considered to use both definitions of indistinguishability.
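How that "other NLP model" is wired in can be sketched with TextAttack's constraint classes. A minimal, hedged example: the 0.8 threshold is an arbitrary illustration rather than any recipe's default, the class paths are assumed from this release's constraints module, and `PartOfSpeech` is the constraint touched elsewhere in this very diff.

```python
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder

# Reject a candidate perturbation unless its Universal Sentence Encoder embedding
# stays close enough (similarity above the threshold) to the original text.
use_constraint = UniversalSentenceEncoder(threshold=0.8)

# Only allow word swaps that preserve the (flair-tagged) part of speech.
pos_constraint = PartOfSpeech(tagger_type="flair", tagset="universal")

constraints = [use_constraint, pos_constraint]
```

Recipes in this category bundle constraints like these together with a synonym-substitution transformation.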
## Generating adversarial examples with TextAttack

TextAttack supports adversarial attacks based on both definitions of indistinguishability. Both types of attacks are useful for training more robust NLP models. Our goal is to enable research into adversarial examples in NLP by providing a set of intuitive, reusable components for building as many attacks from the literature as possible.

We define the adversarial attack process using four components: a goal function, constraints, a transformation, and a search method. (We’ll go into this in detail in a future post!) These components allow us to reuse many pieces between attacks from different research papers. They also make it easy to develop methods for NLP data augmentation.
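To make the four components concrete, here is a hedged sketch of assembling them by hand. The class names and module paths are assumptions about this release's public API, `model_wrapper` is the wrapped classifier from the earlier sketch, and the particular component choices roughly mirror a TextFooler-style attack rather than any single published recipe.

```python
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attack import Attack
from textattack.transformations import WordSwapEmbedding

# Goal function: succeed when the classifier's predicted label flips.
goal_function = UntargetedClassification(model_wrapper)
# Constraints: don't perturb the same word twice, and leave stopwords alone.
constraints = [RepeatModification(), StopwordModification()]
# Transformation: swap words for nearest neighbors in a counter-fitted embedding.
transformation = WordSwapEmbedding(max_candidates=20)
# Search method: greedy word swaps ordered by word importance ranking.
search_method = GreedyWordSwapWIR()

attack = Attack(goal_function, constraints, transformation, search_method)
```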
TextAttack also includes code for loading popular NLP datasets and training models on them. By integrating this training code with adversarial attacks and data augmentation techniques, TextAttack provides an environment for researchers to test adversarial training in many different scenarios.
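Because the same transformations and constraints drive augmentation, producing new training examples takes only a couple of lines. A hedged sketch using one of the built-in augmenters follows; the parameter names are assumed from the Augmenter API, and the input sentence is arbitrary.

```python
from textattack.augmentation import EmbeddingAugmenter

# Swap a fraction of words for embedding-space neighbors to create new
# training examples from an existing one.
augmenter = EmbeddingAugmenter(pct_words_to_swap=0.2, transformations_per_example=4)
print(augmenter.augment("TextAttack makes it easy to augment text datasets."))
```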
The following figure shows an overview of the main functionality of TextAttack:
<br>


## The future of adversarial attacks in NLP

We are excited to see the impact that TextAttack has on the NLP research community! One area where we would like to see more research is the combination of components from different papers. TextAttack makes it easy to run ablation studies that compare the effect of swapping out, say, the search method from paper A for the search method from paper B, without making any other changes. (And these tests can be run across dozens of pre-trained models and datasets with no downloads!)

We hope that use of TextAttack leads to more diversity in adversarial attacks. One thing that all current adversarial attacks have in common is that they make substitutions at the word or character level. We hope that future adversarial attacks in NLP can broaden in scope to explore phrase-level replacements as well as full-sentence paraphrases. Additionally, the adversarial attack literature has focused largely on English; we look forward to seeing adversarial attacks applied to more languages.

To get started with TextAttack, check out one of our [introductory tutorials](/examples/0_End_to_End.ipynb).


[1] “Attacking Machine Learning with Adversarial Examples”, Goodfellow, 2013. [https://openai.com/blog/adversarial-example-research/]

[2] “Intriguing properties of neural networks”, Szegedy, 2013. [https://arxiv.org/abs/1312.6199]

[3] “Robustness May Be at Odds with Accuracy”, Tsipras, 2018. [https://arxiv.org/abs/1805.12152]
@@ -1,6 +1,6 @@
bert-score
bert-score>=0.3.5
editdistance
flair==0.5.1
flair==0.6.0.post1
filelock
language_tool_python
lemminflect
setup.py (4 lines changed)
@@ -11,9 +11,9 @@ extras = {}
extras["docs"] = ["recommonmark", "nbsphinx", "sphinx-autobuild", "sphinx-rtd-theme"]
# Packages required for formatting code & running tests.
extras["test"] = [
    "black",
    "black==20.8b1",
    "docformatter",
    "isort==5.0.3",
    "isort==5.4.2",
    "flake8",
    "pytest",
    "pytest-xdist",
@@ -12,6 +12,34 @@ from textattack.shared.validators import transformation_consists_of_word_swaps
flair.device = textattack.shared.utils.device


def load_flair_upos_fast():
    """Loads flair 'upos-fast' SequenceTagger.

    This is a temporary workaround for flair v0.6. Will be fixed when
    flair pushes the bug fix.
    """
    import pathlib
    import warnings

    from flair import file_utils
    import torch

    hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"
    upos_path = "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"])
    model_path = file_utils.cached_path(upos_path, cache_dir=pathlib.Path("models"))
    model_file = SequenceTagger._fetch_model(model_path)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
        # see https://github.com/zalandoresearch/flair/issues/351
        f = file_utils.load_big_file(str(model_file))
        state = torch.load(f, map_location="cpu")
    model = SequenceTagger._init_model_with_state_dict(state)
    model.eval()
    model.to(textattack.shared.utils.device)
    return model


class PartOfSpeech(Constraint):
    """Constraints word swaps to only swap words with the same part of speech.
    Uses the NLTK universal part-of-speech tagger by default. An implementation
@@ -43,7 +71,7 @@ class PartOfSpeech(Constraint):
        self._pos_tag_cache = lru.LRU(2 ** 14)
        if tagger_type == "flair":
            if tagset == "universal":
                self._flair_pos_tagger = SequenceTagger.load("upos-fast")
                self._flair_pos_tagger = load_flair_upos_fast()
            else:
                self._flair_pos_tagger = SequenceTagger.load("pos-fast")

@@ -69,14 +97,16 @@ class PartOfSpeech(Constraint):
            if self.tagger_type == "flair":
                context_key_sentence = Sentence(context_key)
                self._flair_pos_tagger.predict(context_key_sentence)
                word_list, pos_list = zip_flair_result(context_key_sentence)
                word_list, pos_list = textattack.shared.utils.zip_flair_result(
                    context_key_sentence
                )

            self._pos_tag_cache[context_key] = (word_list, pos_list)

        # idx of `word` in `context_words`
        idx = len(before_ctx)
        assert word_list[idx] == word, "POS list not matched with original word list."
        return pos_list[idx]
        assert word in word_list, "POS list not matched with original word list."
        word_idx = word_list.index(word)
        return pos_list[word_idx]

    def _check_constraint(self, transformed_text, reference_text):
        try:
@@ -109,17 +139,3 @@ class PartOfSpeech(Constraint):
            "tagset",
            "allow_verb_noun_swap",
        ] + super().extra_repr_keys()


def zip_flair_result(pred):
    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = []
    pos_list = []
    for token in tokens:
        word_list.append(token.text)
        pos_list.append(token.annotation_layers["pos"][0]._value)

    return word_list, pos_list
@@ -1,12 +1,16 @@
from collections import OrderedDict
import math

import flair
from flair.data import Sentence
import numpy as np
import torch

import textattack

from .utils import words_from_text
from .utils import device, words_from_text

flair.device = device


class AttackedText:
@@ -40,9 +44,10 @@ class AttackedText:
            raise TypeError(
                f"Invalid text_input type {type(text_input)} (required str or OrderedDict)"
            )
        # Find words in input lazily.
        # Process input lazily.
        self._words = None
        self._words_per_input = None
        self._pos_tags = None
        # Format text inputs.
        self._text_input = OrderedDict([(k, v) for k, v in self._text_input.items()])
        if attack_attrs is None:
@@ -113,6 +118,34 @@ class AttackedText:
        text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
        return self.text[text_idx_start:text_idx_end]

    def pos_of_word_index(self, desired_word_idx):
        """Returns the part-of-speech of the word at index `word_idx`.

        Uses FLAIR part-of-speech tagger.
        """
        if not self._pos_tags:
            sentence = Sentence(self.text)
            textattack.shared.utils.flair_tag(sentence)
            self._pos_tags = sentence
        flair_word_list, flair_pos_list = textattack.shared.utils.zip_flair_result(
            self._pos_tags
        )

        for word_idx, word in enumerate(self.words):
            assert (
                word in flair_word_list
            ), "word absent in flair returned part-of-speech tags"
            word_idx_in_flair_tags = flair_word_list.index(word)
            if word_idx == desired_word_idx:
                return flair_pos_list[word_idx_in_flair_tags]
            else:
                flair_word_list = flair_word_list[word_idx_in_flair_tags + 1 :]
                flair_pos_list = flair_pos_list[word_idx_in_flair_tags + 1 :]

        raise ValueError(
            f"Did not find word from index {desired_word_idx} in flair POS tag"
        )

    def _text_index_of_word_index(self, i):
        """Returns the index of word ``i`` in self.text."""
        pre_words = self.words[: i + 1]
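A hedged usage sketch for the new `AttackedText.pos_of_word_index` helper added above (the sentence is arbitrary; the flair `pos-fast` tagger is downloaded on first use):

```python
from textattack.shared import AttackedText

text = AttackedText("The quick brown fox jumps over the lazy dog .")
# Tags the whole sentence once with flair, caches it, then looks up word index 1 ("quick").
print(text.pos_of_word_index(1))  # expected: an adjective tag such as "JJ"
```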
@@ -163,3 +163,34 @@ def color_text(text, color=None, method=None):
        return color + text + ANSI_ESCAPE_CODES.STOP
    elif method == "file":
        return "[[" + text + "]]"


_flair_pos_tagger = None


def flair_tag(sentence):
    """Tags a `Sentence` object using `flair` part-of-speech tagger."""
    global _flair_pos_tagger
    if not _flair_pos_tagger:
        from flair.models import SequenceTagger

        _flair_pos_tagger = SequenceTagger.load("pos-fast")
    _flair_pos_tagger.predict(sentence)


def zip_flair_result(pred):
    """Takes a sentence tagging from `flair` and returns two lists, of words
    and their corresponding parts-of-speech."""
    from flair.data import Sentence

    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = []
    pos_list = []
    for token in tokens:
        word_list.append(token.text)
        pos_list.append(token.annotation_layers["pos"][0]._value)

    return word_list, pos_list
@@ -1,14 +1,8 @@
import pickle

import flair
from flair.data import Sentence
from flair.models import SequenceTagger

from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

flair.device = utils.device


class WordSwapHowNet(WordSwap):
    """Transforms an input by replacing its words with synonyms in the stored
@@ -29,7 +23,6 @@ class WordSwapHowNet(WordSwap):
        with open(cache_path, "rb") as fp:
            self.candidates_bank = pickle.load(fp)

        self._flair_pos_tagger = SequenceTagger.load("pos-fast")
        self.pos_dict = {"JJ": "adj", "NN": "noun", "RB": "adv", "VB": "verb"}

    def _get_replacement_words(self, word, word_pos):
@@ -55,20 +48,10 @@ class WordSwapHowNet(WordSwap):
            return []

    def _get_transformations(self, current_text, indices_to_modify):
        words = current_text.words
        sentence = Sentence(" ".join(words))
        # in-place POS tagging
        self._flair_pos_tagger.predict(sentence)
        word_list, pos_list = zip_flair_result(sentence)

        assert len(words) == len(
            word_list
        ), "Part-of-speech tagger returned incorrect number of tags"
        transformed_texts = []

        for i in indices_to_modify:
            word_to_replace = words[i]
            word_to_replace_pos = pos_list[i][:2]  # get the root POS
            word_to_replace = current_text.words[i]
            word_to_replace_pos = current_text.pos_of_word_index(i)
            replacement_words = self._get_replacement_words(
                word_to_replace, word_to_replace_pos
            )
@@ -100,18 +83,3 @@ def recover_word_case(word, reference_word):
    else:
        # if other, just do not alter the word's case
        return word


def zip_flair_result(pred):
    """Parse the output from the FLAIR POS tagger."""
    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = []
    pos_list = []
    for token in tokens:
        word_list.append(token.text)
        pos_list.append(token.annotation_layers["pos"][0]._value)

    return word_list, pos_list
@@ -1,13 +1,7 @@
import flair
from flair.data import Sentence
from flair.models import SequenceTagger
import lemminflect

from textattack.shared import utils
from textattack.transformations.word_swap import WordSwap

flair.device = utils.device


class WordSwapInflections(WordSwap):
    """Transforms an input by replacing its words with their inflections.
@@ -22,8 +16,6 @@ class WordSwapInflections(WordSwap):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._flair_pos_tagger = SequenceTagger.load("pos-fast")
        self._flair_to_lemminflect_pos_map = {"NN": "NOUN", "VB": "VERB", "JJ": "ADJ"}

    def _get_replacement_words(self, word, word_part_of_speech):
@@ -35,23 +27,16 @@ class WordSwapInflections(WordSwap):
        # to available inflections. First, map part-of-speech from flair
        # POS tag to lemminflect.
        lemminflect_pos = self._flair_to_lemminflect_pos_map[word_part_of_speech]
        return replacement_inflections_dict.get(lemminflect_pos, None)
        replacement_words = replacement_inflections_dict.get(lemminflect_pos, list())
        replacement_words = [r for r in replacement_words if r != word]
        return replacement_words

    def _get_transformations(self, current_text, indices_to_modify):
        words = current_text.words
        sentence = Sentence(" ".join(words))
        self._flair_pos_tagger.predict(sentence)
        word_list, pos_list = zip_flair_result(sentence)

        assert len(words) == len(
            word_list
        ), "Part-of-speech tagger returned incorrect number of tags"

        transformed_texts = []

        for i in indices_to_modify:
            word_to_replace = words[i]
            word_to_replace_pos = pos_list[i][:2]  # get the root POS
            word_to_replace = current_text.words[i]
            word_to_replace_pos = current_text.pos_of_word_index(i)
            replacement_words = (
                self._get_replacement_words(word_to_replace, word_to_replace_pos) or []
            )
@@ -59,18 +44,3 @@ class WordSwapInflections(WordSwap):
            transformed_texts.append(current_text.replace_word_at_index(i, r))

        return transformed_texts


def zip_flair_result(pred):
    """Parse the output from the FLAIR POS tagger."""
    if not isinstance(pred, Sentence):
        raise TypeError("Result from Flair POS tagger must be a `Sentence` object.")

    tokens = pred.tokens
    word_list = []
    pos_list = []
    for token in tokens:
        word_list.append(token.text)
        pos_list.append(token.annotation_layers["pos"][0]._value)

    return word_list, pos_list