Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
@@ -47,3 +47,34 @@ You can also install other miscellaneous optional dependencies by running

To install both groups of packages, run

```bash
pip install textattack[tensorflow,optional]
```

## FAQ on installation

For many dependency issues, the following command is the first thing you could try:

```bash
pip install --force-reinstall textattack
```

or

```bash
pip install textattack[tensorflow,optional]
```

In addition, we highly recommend using a virtual environment for TextAttack;
see [information here](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#removing-an-environment). Here is one conda example:

```bash
conda create -n textattackenv python=3.7
conda activate textattackenv
conda env list
```

If you want to use the most up-to-date version of TextAttack (normally with newer bug fixes), you can run the following:

```bash
git clone https://github.com/QData/TextAttack.git
cd TextAttack
pip install .[dev]
```
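A quick way to confirm the install works (a minimal check, not part of the original README; `textattack.name` is the package attribute set in `textattack/__init__.py` later in this diff) is:

```bash
python -c "import textattack; print(textattack.name)"
```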
@@ -7,6 +7,7 @@ textattack.constraints.grammaticality.language\_models package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.constraints.grammaticality package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.constraints package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.constraints.semantics package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.constraints.semantics.sentence\_encoders package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.datasets package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.goal\_functions package
   :show-inheritance:


.. toctree::
   :maxdepth: 6

docs/apidoc/textattack.metrics.attack_metrics.rst (new file, 26 lines)
@@ -0,0 +1,26 @@
textattack.metrics.attack\_metrics package
==========================================

.. automodule:: textattack.metrics.attack_metrics
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.metrics.attack_metrics.attack_queries
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.metrics.attack_metrics.attack_success_rate
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.metrics.attack_metrics.words_perturbed
   :members:
   :undoc-members:
   :show-inheritance:
docs/apidoc/textattack.metrics.quality_metrics.rst (new file, 20 lines)
@@ -0,0 +1,20 @@
textattack.metrics.quality\_metrics package
===========================================

.. automodule:: textattack.metrics.quality_metrics
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.metrics.quality_metrics.perplexity
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.metrics.quality_metrics.use
   :members:
   :undoc-members:
   :show-inheritance:
docs/apidoc/textattack.metrics.rst (new file, 22 lines)
@@ -0,0 +1,22 @@
textattack.metrics package
==========================

.. automodule:: textattack.metrics
   :members:
   :undoc-members:
   :show-inheritance:


.. toctree::
   :maxdepth: 6

   textattack.metrics.attack_metrics
   textattack.metrics.quality_metrics


.. automodule:: textattack.metrics.metric
   :members:
   :undoc-members:
   :show-inheritance:
@@ -7,6 +7,7 @@ textattack.models package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -19,6 +19,7 @@ textattack package
   textattack.goal_function_results
   textattack.goal_functions
   textattack.loggers
   textattack.metrics
   textattack.models
   textattack.search_methods
   textattack.shared


@@ -7,6 +7,7 @@ textattack.shared package
   :show-inheritance:


.. toctree::
   :maxdepth: 6


@@ -7,6 +7,7 @@ textattack.transformations package
   :show-inheritance:


.. toctree::
   :maxdepth: 6

@@ -0,0 +1,74 @@
/.*/Attack(
  (search_method): BeamSearch(
    (beam_width): 10
  )
  (goal_function): UntargetedClassification
  (transformation): WordSwapGradientBased(
    (top_n): 1
  )
  (constraints):
    (0): MaxWordsPerturbed(
        (max_num_words): 2
        (compare_against_original): True
      )
    (1): WordEmbeddingDistance(
        (embedding): WordEmbedding
        (min_cos_sim): 0.8
        (cased): False
        (include_unknown_words): True
        (compare_against_original): True
      )
    (2): PartOfSpeech(
        (tagger_type): nltk
        (tagset): universal
        (allow_verb_noun_swap): True
        (compare_against_original): True
      )
    (3): RepeatModification
    (4): StopwordModification
  (is_black_box): False
)

--------------------------------------------- Result 1 ---------------------------------------------
[[Positive (96%)]] --> [[Negative (77%)]]

the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[supplies]] with tremendous skill .

the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[stagnated]] with tremendous skill .


--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (57%)]] --> [[[SKIPPED]]]

red dragon " never cuts corners .


--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (51%)]] --> [[[FAILED]]]

fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense .


--------------------------------------------- Result 4 ---------------------------------------------
[[Positive (89%)]] --> [[[FAILED]]]

throws in enough clever and unexpected twists to make the formula feel fresh .


+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 1      |
| Number of failed attacks:     | 2      |
| Number of skipped attacks:    | 1      |
| Original accuracy:            | 75.0%  |
| Accuracy under attack:        | 50.0%  |
| Attack success rate:          | 33.33% |
| Average perturbed word %:     | 5.56%  |
| Average num. words per input: | 15.5   |
| Avg num queries:              | 1.33   |
| Average Original Perplexity:  | 291.47 |
| Average Attack Perplexity:    | 320.33 |
| Average Attack USE Score:     | 0.91   |
+-------------------------------+--------+
@@ -0,0 +1,68 @@
/.*/Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method): unk
  )
  (goal_function): UntargetedClassification
  (transformation): CompositeTransformation(
    (0): WordSwapNeighboringCharacterSwap(
        (random_one): True
      )
    (1): WordSwapRandomCharacterSubstitution(
        (random_one): True
      )
    (2): WordSwapRandomCharacterDeletion(
        (random_one): True
      )
    (3): WordSwapRandomCharacterInsertion(
        (random_one): True
      )
    )
  (constraints):
    (0): LevenshteinEditDistance(
        (max_edit_distance): 30
        (compare_against_original): True
      )
    (1): RepeatModification
    (2): StopwordModification
  (is_black_box): True
)

--------------------------------------------- Result 1 ---------------------------------------------
[[Negative (100%)]] --> [[Positive (71%)]]

[[hide]] [[new]] secretions from the parental units

[[Ehide]] [[enw]] secretions from the parental units


--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (100%)]] --> [[[FAILED]]]

contains no wit , only labored gags


--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (100%)]] --> [[Negative (96%)]]

that [[loves]] its characters and communicates [[something]] [[rather]] [[beautiful]] about human nature

that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beautifdul]] about human nature


+-------------------------------+---------+
| Attack Results                |         |
+-------------------------------+---------+
| Number of successful attacks: | 2       |
| Number of failed attacks:     | 1       |
| Number of skipped attacks:    | 0       |
| Original accuracy:            | 100.0%  |
| Accuracy under attack:        | 33.33%  |
| Attack success rate:          | 66.67%  |
| Average perturbed word %:     | 30.95%  |
| Average num. words per input: | 8.33    |
| Avg num queries:              | 22.67   |
| Average Original Perplexity:  | 1126.57 |
| Average Attack Perplexity:    | 2823/.*/|
| Average Attack USE Score:     | 0.76    |
+-------------------------------+---------+
@@ -48,6 +48,20 @@ attack_test_params = [
        "tests/sample_outputs/run_attack_transformers_datasets.txt",
    ),
    #
    # test loading an attack from the transformers model hub and calculate perplexity and use
    #
    (
        "attack_from_transformers_adv_metrics",
        (
            "textattack attack --model-from-huggingface "
            "distilbert-base-uncased-finetuned-sst-2-english "
            "--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
            "--enable-advance-metrics"
            ""
        ),
        "tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt",
    ),
    #
    # test running an attack by loading a model and dataset from file
    #
    (
@@ -72,6 +86,17 @@ attack_test_params = [
        "tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
    ),
    #
    # test hotflip on 10 samples from LSTM MR and calculate perplexity and use
    #
    (
        "run_attack_hotflip_lstm_mr_4_adv_metrics",
        (
            "textattack attack --model lstm-mr --recipe hotflip "
            "--num-examples 4 --num-examples-offset 3 --enable-advance-metrics "
        ),
        "tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt",
    ),
    #
    # test: run_attack deepwordbug attack on 10 samples from LSTM MR
    #
    (
@@ -8,7 +8,6 @@ TextAttack makes experimenting with the robustness of NLP models seamless, fast,

TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
"""
from .attack_args import AttackArgs, CommandLineAttackArgs
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
@@ -17,6 +16,7 @@ from .training_args import TrainingArgs, CommandLineTrainingArgs
from .attack import Attack
from .attacker import Attacker
from .trainer import Trainer
from .metrics import Metric

from . import (
    attack_recipes,
@@ -28,10 +28,12 @@ from . import (
    goal_function_results,
    goal_functions,
    loggers,
    metrics,
    models,
    search_methods,
    shared,
    transformations,
)


name = "textattack"
@@ -174,6 +174,8 @@ class AttackArgs:
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
    """

    num_examples: int = 10
@@ -194,6 +196,7 @@ class AttackArgs:
    log_to_wandb: str = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False

    def __post_init__(self):
        if self.num_successful_examples:
@@ -351,6 +354,12 @@ class AttackArgs:
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
        )

        return parser
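For reference, a concrete invocation of this flag (taken verbatim from the CLI tests earlier in this diff) looks like:

```bash
textattack attack --model lstm-mr --recipe hotflip \
    --num-examples 4 --num-examples-offset 3 --enable-advance-metrics
```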
@@ -219,6 +219,10 @@ class Attacker:
        # Enable summary stdout
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()
@@ -390,6 +394,10 @@ class Attacker:
        # Enable summary stdout.
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()
@@ -39,6 +39,8 @@ class EvalModelCommand(TextAttackCommand):
    def test_model_on_dataset(self, args):
        model = ModelArgs._create_model_from_args(args)
        dataset = DatasetArgs._create_dataset_from_args(args)
        if args.num_examples == -1:
            args.num_examples = len(dataset)

        preds = []
        ground_truth_outputs = []

@@ -275,7 +275,9 @@ class DatasetArgs:
            dataset_args = (dataset_args,)
        if args.dataset_split:
            if len(dataset_args) > 1:
                dataset_args[2] = args.dataset_split
                dataset_args = (
                    dataset_args[:1] + (args.dataset_split,) + dataset_args[2:]
                )
                dataset = textattack.datasets.HuggingFaceDataset(
                    *dataset_args, shuffle=False
                )
@@ -3,9 +3,12 @@ Managing Attack Logs.
========================
"""

import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity, USEMetric

from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger

@@ -16,6 +19,7 @@ class AttackLogManager:
    def __init__(self):
        self.loggers = []
        self.results = []
        self.enable_advance_metrics = False

    def enable_stdout(self):
        self.loggers.append(FileLogger(stdout=True))
@@ -72,103 +76,77 @@ class AttackLogManager:
        total_attacks = len(self.results)
        if total_attacks == 0:
            return
        # Count things about attacks.
        all_num_words = np.zeros(len(self.results))
        perturbed_word_percentages = np.zeros(len(self.results))
        num_words_changed_until_success = np.zeros(
            2 ** 16
        )  # @ TODO: be smarter about this
        failed_attacks = 0
        skipped_attacks = 0
        successful_attacks = 0
        max_words_changed = 0
        for i, result in enumerate(self.results):
            all_num_words[i] = len(result.original_result.attacked_text.words)
            if isinstance(result, FailedAttackResult):
                failed_attacks += 1
                continue
            elif isinstance(result, SkippedAttackResult):
                skipped_attacks += 1
                continue
            else:
                successful_attacks += 1
            num_words_changed = result.original_result.attacked_text.words_diff_num(
                result.perturbed_result.attacked_text
            )
            # num_words_changed = len(
            #     result.original_result.attacked_text.all_words_diff(
            #         result.perturbed_result.attacked_text
            #     )
            # )
            num_words_changed_until_success[num_words_changed - 1] += 1
            max_words_changed = max(
                max_words_changed or num_words_changed, num_words_changed
            )
            if len(result.original_result.attacked_text.words) > 0:
                perturbed_word_percentage = (
                    num_words_changed
                    * 100.0
                    / len(result.original_result.attacked_text.words)
                )
            else:
                perturbed_word_percentage = 0
            perturbed_word_percentages[i] = perturbed_word_percentage

        # Original classifier success rate on these samples.
        original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
        original_accuracy = str(round(original_accuracy, 2)) + "%"

        # New classifier success rate on these samples.
        accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
        accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"

        # Attack success rate.
        if successful_attacks + failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                successful_attacks * 100.0 / (successful_attacks + failed_attacks)
            )
        attack_success_rate = str(round(attack_success_rate, 2)) + "%"

        perturbed_word_percentages = perturbed_word_percentages[
            perturbed_word_percentages > 0
        ]
        average_perc_words_perturbed = perturbed_word_percentages.mean()
        average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"

        average_num_words = all_num_words.mean()
        average_num_words = str(round(average_num_words, 2))
        # Default metrics - calculated on every attack
        attack_success_stats = AttackSuccessRate().calculate(self.results)
        words_perturbed_stats = WordsPerturbed().calculate(self.results)
        attack_query_stats = AttackQueries().calculate(self.results)

        # @TODO generate this table based on user input - each column in specific class
        # Example to demonstrate:
        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
        summary_table_rows = [
            ["Number of successful attacks:", str(successful_attacks)],
            ["Number of failed attacks:", str(failed_attacks)],
            ["Number of skipped attacks:", str(skipped_attacks)],
            ["Original accuracy:", original_accuracy],
            ["Accuracy under attack:", accuracy_under_attack],
            ["Attack success rate:", attack_success_rate],
            ["Average perturbed word %:", average_perc_words_perturbed],
            ["Average num. words per input:", average_num_words],
            [
                "Number of successful attacks:",
                attack_success_stats["successful_attacks"],
            ],
            ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
            ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
            [
                "Original accuracy:",
                str(attack_success_stats["original_accuracy"]) + "%",
            ],
            [
                "Accuracy under attack:",
                str(attack_success_stats["attack_accuracy_perc"]) + "%",
            ],
            [
                "Attack success rate:",
                str(attack_success_stats["attack_success_rate"]) + "%",
            ],
            [
                "Average perturbed word %:",
                str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
            ],
            [
                "Average num. words per input:",
                words_perturbed_stats["avg_word_perturbed"],
            ],
        ]

        num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        summary_table_rows.append(
            ["Avg num queries:", attack_query_stats["avg_num_queries"]]
        )
        avg_num_queries = num_queries.mean()
        avg_num_queries = str(round(avg_num_queries, 2))
        summary_table_rows.append(["Avg num queries:", avg_num_queries])

        if self.enable_advance_metrics:
            perplexity_stats = Perplexity().calculate(self.results)
            use_stats = USEMetric().calculate(self.results)

            summary_table_rows.append(
                [
                    "Average Original Perplexity:",
                    perplexity_stats["avg_original_perplexity"],
                ]
            )

            summary_table_rows.append(
                [
                    "Average Attack Perplexity:",
                    perplexity_stats["avg_attack_perplexity"],
                ]
            )
            summary_table_rows.append(
                ["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
            )

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )
        # Show histogram of words changed.
        numbins = max(max_words_changed, 10)
        numbins = max(words_perturbed_stats["max_words_changed"], 10)
        for logger in self.loggers:
            logger.log_hist(
                num_words_changed_until_success[:numbins],
                words_perturbed_stats["num_words_changed_until_success"][:numbins],
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
textattack/metrics/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""
"""

from .metric import Metric

from .attack_metrics import AttackSuccessRate
from .attack_metrics import WordsPerturbed
from .attack_metrics import AttackQueries

from .quality_metrics import Perplexity
from .quality_metrics import USEMetric
textattack/metrics/attack_metrics/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""

attack_metrics:
======================

TextAttack provides users common metrics on attacks' quality.

"""

from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed
textattack/metrics/attack_metrics/attack_queries.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""

Metrics on AttackQueries
=========================

"""

import numpy as np

from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric


class AttackQueries(Metric):
    def __init__(self):
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to number of queries in an attack

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """

        self.results = results
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()

        return self.all_metrics

    def avg_num_queries(self):
        avg_num_queries = self.num_queries.mean()
        avg_num_queries = round(avg_num_queries, 2)
        return avg_num_queries
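A minimal usage sketch (not part of the diff; it assumes `results` is a list of `AttackResult` objects, e.g. `attacker.attack_log_manager.results` after a finished run):

```python
from textattack.metrics.attack_metrics import AttackQueries

# calculate() returns a plain dict keyed as defined above.
stats = AttackQueries().calculate(results)
print(stats["avg_num_queries"])  # e.g. 1.33 in the first sample output above
```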
textattack/metrics/attack_metrics/attack_success_rate.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""

Metrics on AttackSuccessRate
=============================

"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric


class AttackSuccessRate(Metric):
    def __init__(self):
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0

        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to number of successful, failed and skipped results in an attack

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """
        self.results = results
        self.total_attacks = len(self.results)

        for i, result in enumerate(self.results):
            if isinstance(result, FailedAttackResult):
                self.failed_attacks += 1
                continue
            elif isinstance(result, SkippedAttackResult):
                self.skipped_attacks += 1
                continue
            else:
                self.successful_attacks += 1

        # Calculated numbers
        self.all_metrics["successful_attacks"] = self.successful_attacks
        self.all_metrics["failed_attacks"] = self.failed_attacks
        self.all_metrics["skipped_attacks"] = self.skipped_attacks

        # Percentages wrt the calculations
        self.all_metrics["original_accuracy"] = self.original_accuracy_perc()
        self.all_metrics["attack_accuracy_perc"] = self.attack_accuracy_perc()
        self.all_metrics["attack_success_rate"] = self.attack_success_rate_perc()

        return self.all_metrics

    def original_accuracy_perc(self):
        original_accuracy = (
            (self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks)
        )
        original_accuracy = round(original_accuracy, 2)
        return original_accuracy

    def attack_accuracy_perc(self):
        accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks)
        accuracy_under_attack = round(accuracy_under_attack, 2)
        return accuracy_under_attack

    def attack_success_rate_perc(self):
        if self.successful_attacks + self.failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                self.successful_attacks
                * 100.0
                / (self.successful_attacks + self.failed_attacks)
            )
        attack_success_rate = round(attack_success_rate, 2)
        return attack_success_rate
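As a sanity check (a sketch, not part of the diff), plugging in the counts from the first sample output above (1 successful, 2 failed, and 1 skipped attack out of 4) reproduces the percentages in that summary table:

```python
successful, failed, skipped, total = 1, 2, 1, 4

print((total - skipped) * 100.0 / total)  # 75.0  -> "Original accuracy"
print(failed * 100.0 / total)             # 50.0  -> "Accuracy under attack"
print(round(successful * 100.0 / (successful + failed), 2))  # 33.33 -> "Attack success rate"
```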
textattack/metrics/attack_metrics/words_perturbed.py (new file, 85 lines)
@@ -0,0 +1,85 @@
"""

Metrics on perturbed words
=============================

"""

import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric


class WordsPerturbed(Metric):
    def __init__(self):
        self.total_attacks = 0
        self.all_num_words = None
        self.perturbed_word_percentages = None
        self.num_words_changed_until_success = 0
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to perturbed words in an attack

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """

        self.results = results
        self.total_attacks = len(self.results)
        self.all_num_words = np.zeros(len(self.results))
        self.perturbed_word_percentages = np.zeros(len(self.results))
        self.num_words_changed_until_success = np.zeros(2 ** 16)
        self.max_words_changed = 0

        for i, result in enumerate(self.results):
            self.all_num_words[i] = len(result.original_result.attacked_text.words)

            if isinstance(result, FailedAttackResult) or isinstance(
                result, SkippedAttackResult
            ):
                continue

            num_words_changed = len(
                result.original_result.attacked_text.all_words_diff(
                    result.perturbed_result.attacked_text
                )
            )
            self.num_words_changed_until_success[num_words_changed - 1] += 1
            self.max_words_changed = max(
                self.max_words_changed or num_words_changed, num_words_changed
            )
            if len(result.original_result.attacked_text.words) > 0:
                perturbed_word_percentage = (
                    num_words_changed
                    * 100.0
                    / len(result.original_result.attacked_text.words)
                )
            else:
                perturbed_word_percentage = 0

            self.perturbed_word_percentages[i] = perturbed_word_percentage

        self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num()
        self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc()
        self.all_metrics["max_words_changed"] = self.max_words_changed
        self.all_metrics[
            "num_words_changed_until_success"
        ] = self.num_words_changed_until_success

        return self.all_metrics

    def avg_number_word_perturbed_num(self):
        average_num_words = self.all_num_words.mean()
        average_num_words = round(average_num_words, 2)
        return average_num_words

    def avg_perturbation_perc(self):
        self.perturbed_word_percentages = self.perturbed_word_percentages[
            self.perturbed_word_percentages > 0
        ]
        average_perc_words_perturbed = self.perturbed_word_percentages.mean()
        average_perc_words_perturbed = round(average_perc_words_perturbed, 2)
        return average_perc_words_perturbed
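Tying this back to the first sample output above (a rough check, not part of the diff, assuming the sentence tokenizes to 18 words): Result 1 perturbs a single word ("supplies" to "stagnated"), and since avg_perturbation_perc averages only the nonzero, i.e. successful, entries, that single result yields the table's figure:

```python
print(round(1 * 100.0 / 18, 2))  # 5.56, matching "Average perturbed word %: 5.56%"
```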
textattack/metrics/metric.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""
Metric Class
========================

"""

from abc import ABC, abstractmethod


class Metric(ABC):
    """A metric for evaluating Adversarial Attack candidates."""

    @abstractmethod
    def __init__(self, **kwargs):
        """Creates pre-built :class:`~textattack.Metric` that correspond to
        evaluation metrics for adversarial examples.
        """
        raise NotImplementedError()

    @abstractmethod
    def calculate(self, results):
        """Abstract function for computing any values which are to be calculated as a whole during initialization
        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """

        raise NotImplementedError
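Because Metric only requires __init__ and calculate, adding a new metric is mostly boilerplate. A minimal sketch (hypothetical class, not part of the diff) following the same conventions as the metrics above:

```python
from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric


class SkippedFraction(Metric):
    """Hypothetical example metric: fraction of results that were skipped."""

    def __init__(self):
        self.all_metrics = {}

    def calculate(self, results):
        # Same contract as the built-in metrics: take AttackResult objects,
        # return a plain dict of computed values.
        skipped = sum(isinstance(r, SkippedAttackResult) for r in results)
        self.all_metrics["skipped_fraction"] = round(skipped / len(results), 2)
        return self.all_metrics
```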
textattack/metrics/quality_metrics/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""

Metrics on Quality
======================

TextAttack provides users common metrics on text examples' quality.


"""

from .perplexity import Perplexity
from .use import USEMetric
textattack/metrics/quality_metrics/perplexity.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""

Perplexity Metric:
======================

"""

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
import textattack.shared.utils


class Perplexity(Metric):
    def __init__(self):
        self.all_metrics = {}
        self.original_candidates = []
        self.successful_candidates = []
        self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.ppl_model.to(textattack.shared.utils.device)
        self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.ppl_model.eval()
        self.max_length = self.ppl_model.config.n_positions
        self.stride = 512

    def calculate(self, results):
        """Calculates average perplexity on all successful attacks using a pre-trained small GPT-2 model

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """
        self.results = results
        self.original_candidates_ppl = []
        self.successful_candidates_ppl = []

        for i, result in enumerate(self.results):
            if isinstance(result, FailedAttackResult):
                continue
            elif isinstance(result, SkippedAttackResult):
                continue
            else:
                self.original_candidates.append(
                    result.original_result.attacked_text.text.lower()
                )
                self.successful_candidates.append(
                    result.perturbed_result.attacked_text.text.lower()
                )

        ppl_orig = self.calc_ppl(self.original_candidates)
        ppl_attack = self.calc_ppl(self.successful_candidates)

        self.all_metrics["avg_original_perplexity"] = round(ppl_orig[0], 2)
        self.all_metrics["original_perplexity_list"] = ppl_orig[1]

        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack[0], 2)
        self.all_metrics["attack_perplexity_list"] = ppl_attack[1]

        return self.all_metrics

    def calc_ppl(self, texts):

        ppl_vals = []

        with torch.no_grad():
            for text in texts:
                eval_loss = []
                input_ids = torch.tensor(
                    self.ppl_tokenizer.encode(text, add_special_tokens=True)
                ).unsqueeze(0)
                # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
                for i in range(0, input_ids.size(1), self.stride):
                    begin_loc = max(i + self.stride - self.max_length, 0)
                    end_loc = min(i + self.stride, input_ids.size(1))
                    trg_len = end_loc - i
                    input_ids_t = input_ids[:, begin_loc:end_loc].to(
                        textattack.shared.utils.device
                    )
                    target_ids = input_ids_t.clone()
                    target_ids[:, :-trg_len] = -100

                    outputs = self.ppl_model(input_ids_t, labels=target_ids)
                    log_likelihood = outputs[0] * trg_len

                    eval_loss.append(log_likelihood)

                ppl_vals.append(
                    torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
                )

        return sum(ppl_vals) / len(ppl_vals), ppl_vals
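Put differently, for each text the strided loop sums per-window losses weighted by the number of predicted tokens, and the per-text perplexity is exp(total loss / token count). A minimal usage sketch (same assumption as above that `results` holds `AttackResult` objects):

```python
from textattack.metrics.quality_metrics import Perplexity

# Keys below are exactly those populated in calculate() above.
ppl_stats = Perplexity().calculate(results)
print(ppl_stats["avg_original_perplexity"], ppl_stats["avg_attack_perplexity"])
```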
textattack/metrics/quality_metrics/use.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.metrics import Metric


class USEMetric(Metric):
    def __init__(self, **kwargs):
        self.use_obj = UniversalSentenceEncoder()
        self.use_obj.model = UniversalSentenceEncoder()
        self.original_candidates = []
        self.successful_candidates = []
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """
        self.results = results

        for i, result in enumerate(self.results):
            if isinstance(result, FailedAttackResult):
                continue
            elif isinstance(result, SkippedAttackResult):
                continue
            else:
                self.original_candidates.append(result.original_result.attacked_text)
                self.successful_candidates.append(result.perturbed_result.attacked_text)

        use_scores = []
        for c in range(len(self.original_candidates)):
            use_scores.append(
                self.use_obj._sim_score(
                    self.original_candidates[c], self.successful_candidates[c]
                ).item()
            )

        self.all_metrics["avg_attack_use_score"] = round(
            sum(use_scores) / len(use_scores), 2
        )

        return self.all_metrics
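And a matching usage sketch for the USE metric (same `results` assumption as above):

```python
from textattack.metrics.quality_metrics import USEMetric

use_stats = USEMetric().calculate(results)
print(use_stats["avg_attack_use_score"])  # e.g. 0.91 in the first sample output
```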