1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #541 from QData/master

Update
This commit is contained in:
Hanyu-Liu-123
2021-10-08 01:07:34 -04:00
committed by GitHub
33 changed files with 768 additions and 90 deletions

View File

@@ -47,3 +47,34 @@ You can also install other miscallenous optional dependencies by running
To install both groups of packages, run
pip install textattack[tensorflow,optional]
## FAQ on installation
For many dependency-related installation issues, the following command is the first thing to try:
```bash
pip install --force-reinstall textattack
```
OR
```bash
pip install textattack[tensorflow,optional]
```
In addition, we highly recommend using a virtual environment for TextAttack;
see [information here](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#removing-an-environment). Here is a conda example:
```bash
conda create -n textattackenv python=3.7
conda activate textattackenv
conda env list
```
If you want to use the most up-to-date version of TextAttack (which normally includes newer bug fixes), you can run the following:
```bash
git clone https://github.com/QData/TextAttack.git
cd TextAttack
pip install .[dev]
```

View File

@@ -7,6 +7,7 @@ textattack.constraints.grammaticality.language\_models package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.grammaticality package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.semantics package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.constraints.semantics.sentence\_encoders package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.datasets package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.goal\_functions package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -0,0 +1,26 @@
textattack.metrics.attack\_metrics package
==========================================
.. automodule:: textattack.metrics.attack_metrics
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.attack_queries
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.attack_success_rate
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.attack_metrics.words_perturbed
:members:
:undoc-members:
:show-inheritance:

View File

@@ -0,0 +1,20 @@
textattack.metrics.quality\_metrics package
===========================================
.. automodule:: textattack.metrics.quality_metrics
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.quality_metrics.perplexity
:members:
:undoc-members:
:show-inheritance:
.. automodule:: textattack.metrics.quality_metrics.use
:members:
:undoc-members:
:show-inheritance:

View File

@@ -0,0 +1,22 @@
textattack.metrics package
==========================
.. automodule:: textattack.metrics
:members:
:undoc-members:
:show-inheritance:
.. toctree::
:maxdepth: 6
textattack.metrics.attack_metrics
textattack.metrics.quality_metrics
.. automodule:: textattack.metrics.metric
:members:
:undoc-members:
:show-inheritance:

View File

@@ -7,6 +7,7 @@ textattack.models package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -19,6 +19,7 @@ textattack package
textattack.goal_function_results
textattack.goal_functions
textattack.loggers
textattack.metrics
textattack.models
textattack.search_methods
textattack.shared

View File

@@ -7,6 +7,7 @@ textattack.shared package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -7,6 +7,7 @@ textattack.transformations package
:show-inheritance:
.. toctree::
:maxdepth: 6

View File

@@ -0,0 +1,74 @@
/.*/Attack(
(search_method): BeamSearch(
(beam_width): 10
)
(goal_function): UntargetedClassification
(transformation): WordSwapGradientBased(
(top_n): 1
)
(constraints):
(0): MaxWordsPerturbed(
(max_num_words): 2
(compare_against_original): True
)
(1): WordEmbeddingDistance(
(embedding): WordEmbedding
(min_cos_sim): 0.8
(cased): False
(include_unknown_words): True
(compare_against_original): True
)
(2): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
(compare_against_original): True
)
(3): RepeatModification
(4): StopwordModification
(is_black_box): False
)
--------------------------------------------- Result 1 ---------------------------------------------
[[Positive (96%)]] --> [[Negative (77%)]]
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[supplies]] with tremendous skill .
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[stagnated]] with tremendous skill .
--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (57%)]] --> [[[SKIPPED]]]
red dragon " never cuts corners .
--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (51%)]] --> [[[FAILED]]]
fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense .
--------------------------------------------- Result 4 ---------------------------------------------
[[Positive (89%)]] --> [[[FAILED]]]
throws in enough clever and unexpected twists to make the formula feel fresh .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 1 |
| Number of failed attacks: | 2 |
| Number of skipped attacks: | 1 |
| Original accuracy: | 75.0% |
| Accuracy under attack: | 50.0% |
| Attack success rate: | 33.33% |
| Average perturbed word %: | 5.56% |
| Average num. words per input: | 15.5 |
| Avg num queries: | 1.33 |
| Average Original Perplexity: | 291.47 |
| Average Attack Perplexity: | 320.33 |
| Average Attack USE Score: | 0.91 |
+-------------------------------+--------+

View File

@@ -0,0 +1,68 @@
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): CompositeTransformation(
(0): WordSwapNeighboringCharacterSwap(
(random_one): True
)
(1): WordSwapRandomCharacterSubstitution(
(random_one): True
)
(2): WordSwapRandomCharacterDeletion(
(random_one): True
)
(3): WordSwapRandomCharacterInsertion(
(random_one): True
)
)
(constraints):
(0): LevenshteinEditDistance(
(max_edit_distance): 30
(compare_against_original): True
)
(1): RepeatModification
(2): StopwordModification
(is_black_box): True
)
--------------------------------------------- Result 1 ---------------------------------------------
[[Negative (100%)]] --> [[Positive (71%)]]
[[hide]] [[new]] secretions from the parental units
[[Ehide]] [[enw]] secretions from the parental units
--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (100%)]] --> [[[FAILED]]]
contains no wit , only labored gags
--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (100%)]] --> [[Negative (96%)]]
that [[loves]] its characters and communicates [[something]] [[rather]] [[beautiful]] about human nature
that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beautifdul]] about human nature
+-------------------------------+---------+
| Attack Results | |
+-------------------------------+---------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 30.95% |
| Average num. words per input: | 8.33 |
| Avg num queries: | 22.67 |
| Average Original Perplexity: | 1126.57 |
| Average Attack Perplexity: | 2823/.*/|
| Average Attack USE Score: | 0.76 |
+-------------------------------+---------+

View File

@@ -48,6 +48,20 @@ attack_test_params = [
"tests/sample_outputs/run_attack_transformers_datasets.txt",
),
#
# test loading an attack from the transformers model hub and calculating perplexity and USE
#
(
"attack_from_transformers_adv_metrics",
(
"textattack attack --model-from-huggingface "
"distilbert-base-uncased-finetuned-sst-2-english "
"--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
"--enable-advance-metrics"
""
),
"tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt",
),
#
# test running an attack by loading a model and dataset from file
#
(
@@ -72,6 +86,17 @@ attack_test_params = [
"tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
),
#
# test hotflip on 4 samples from LSTM MR and calculate perplexity and USE
#
(
"run_attack_hotflip_lstm_mr_4_adv_metrics",
(
"textattack attack --model lstm-mr --recipe hotflip "
"--num-examples 4 --num-examples-offset 3 --enable-advance-metrics "
),
"tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt",
),
#
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
#
(

View File

@@ -8,7 +8,6 @@ TextAttack makes experimenting with the robustness of NLP models seamless, fast,
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
"""
from .attack_args import AttackArgs, CommandLineAttackArgs
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
@@ -17,6 +16,7 @@ from .training_args import TrainingArgs, CommandLineTrainingArgs
from .attack import Attack
from .attacker import Attacker
from .trainer import Trainer
from .metrics import Metric
from . import (
attack_recipes,
@@ -28,10 +28,12 @@ from . import (
goal_function_results,
goal_functions,
loggers,
metrics,
models,
search_methods,
shared,
transformations,
)
name = "textattack"

View File

@@ -174,6 +174,8 @@ class AttackArgs:
Disable displaying individual attack results to stdout.
silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
Enable calculation and display of optional advanced post-hoc metrics like perplexity, USE distance, etc.
"""
num_examples: int = 10
@@ -194,6 +196,7 @@ class AttackArgs:
log_to_wandb: str = None
disable_stdout: bool = False
silent: bool = False
enable_advance_metrics: bool = False
def __post_init__(self):
if self.num_successful_examples:
@@ -351,6 +354,12 @@ class AttackArgs:
default=default_obj.silent,
help="Disable all logging",
)
parser.add_argument(
"--enable-advance-metrics",
action="store_true",
default=default_obj.enable_advance_metrics,
help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
)
return parser

View File

@@ -219,6 +219,10 @@ class Attacker:
# Enable summary stdout
if not self.attack_args.silent and self.attack_args.disable_stdout:
self.attack_log_manager.enable_stdout()
if self.attack_args.enable_advance_metrics:
self.attack_log_manager.enable_advance_metrics = True
self.attack_log_manager.log_summary()
self.attack_log_manager.flush()
print()
@@ -390,6 +394,10 @@ class Attacker:
# Enable summary stdout.
if not self.attack_args.silent and self.attack_args.disable_stdout:
self.attack_log_manager.enable_stdout()
if self.attack_args.enable_advance_metrics:
self.attack_log_manager.enable_advance_metrics = True
self.attack_log_manager.log_summary()
self.attack_log_manager.flush()
print()

View File

@@ -39,6 +39,8 @@ class EvalModelCommand(TextAttackCommand):
def test_model_on_dataset(self, args):
model = ModelArgs._create_model_from_args(args)
dataset = DatasetArgs._create_dataset_from_args(args)
if args.num_examples == -1:
args.num_examples = len(dataset)
preds = []
ground_truth_outputs = []

View File

@@ -275,7 +275,9 @@ class DatasetArgs:
dataset_args = (dataset_args,)
if args.dataset_split:
if len(dataset_args) > 1:
dataset_args[2] = args.dataset_split
dataset_args = (
dataset_args[:1] + (args.dataset_split,) + dataset_args[2:]
)
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=False
)

View File

@@ -3,9 +3,12 @@ Managing Attack Logs.
========================
"""
import numpy as np
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics.attack_metrics import (
AttackQueries,
AttackSuccessRate,
WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity, USEMetric
from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
@@ -16,6 +19,7 @@ class AttackLogManager:
def __init__(self):
self.loggers = []
self.results = []
self.enable_advance_metrics = False
def enable_stdout(self):
self.loggers.append(FileLogger(stdout=True))
@@ -72,103 +76,77 @@ class AttackLogManager:
total_attacks = len(self.results)
if total_attacks == 0:
return
# Count things about attacks.
all_num_words = np.zeros(len(self.results))
perturbed_word_percentages = np.zeros(len(self.results))
num_words_changed_until_success = np.zeros(
2 ** 16
) # @ TODO: be smarter about this
failed_attacks = 0
skipped_attacks = 0
successful_attacks = 0
max_words_changed = 0
for i, result in enumerate(self.results):
all_num_words[i] = len(result.original_result.attacked_text.words)
if isinstance(result, FailedAttackResult):
failed_attacks += 1
continue
elif isinstance(result, SkippedAttackResult):
skipped_attacks += 1
continue
else:
successful_attacks += 1
num_words_changed = result.original_result.attacked_text.words_diff_num(
result.perturbed_result.attacked_text
)
# num_words_changed = len(
# result.original_result.attacked_text.all_words_diff(
# result.perturbed_result.attacked_text
# )
# )
num_words_changed_until_success[num_words_changed - 1] += 1
max_words_changed = max(
max_words_changed or num_words_changed, num_words_changed
)
if len(result.original_result.attacked_text.words) > 0:
perturbed_word_percentage = (
num_words_changed
* 100.0
/ len(result.original_result.attacked_text.words)
)
else:
perturbed_word_percentage = 0
perturbed_word_percentages[i] = perturbed_word_percentage
# Original classifier success rate on these samples.
original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
original_accuracy = str(round(original_accuracy, 2)) + "%"
# New classifier success rate on these samples.
accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"
# Attack success rate.
if successful_attacks + failed_attacks == 0:
attack_success_rate = 0
else:
attack_success_rate = (
successful_attacks * 100.0 / (successful_attacks + failed_attacks)
)
attack_success_rate = str(round(attack_success_rate, 2)) + "%"
perturbed_word_percentages = perturbed_word_percentages[
perturbed_word_percentages > 0
]
average_perc_words_perturbed = perturbed_word_percentages.mean()
average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"
average_num_words = all_num_words.mean()
average_num_words = str(round(average_num_words, 2))
# Default metrics - calculated on every attack
attack_success_stats = AttackSuccessRate().calculate(self.results)
words_perturbed_stats = WordsPerturbed().calculate(self.results)
attack_query_stats = AttackQueries().calculate(self.results)
# @TODO generate this table based on user input - each column in specific class
# Example to demonstrate:
# summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
summary_table_rows = [
["Number of successful attacks:", str(successful_attacks)],
["Number of failed attacks:", str(failed_attacks)],
["Number of skipped attacks:", str(skipped_attacks)],
["Original accuracy:", original_accuracy],
["Accuracy under attack:", accuracy_under_attack],
["Attack success rate:", attack_success_rate],
["Average perturbed word %:", average_perc_words_perturbed],
["Average num. words per input:", average_num_words],
[
"Number of successful attacks:",
attack_success_stats["successful_attacks"],
],
["Number of failed attacks:", attack_success_stats["failed_attacks"]],
["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
[
"Original accuracy:",
str(attack_success_stats["original_accuracy"]) + "%",
],
[
"Accuracy under attack:",
str(attack_success_stats["attack_accuracy_perc"]) + "%",
],
[
"Attack success rate:",
str(attack_success_stats["attack_success_rate"]) + "%",
],
[
"Average perturbed word %:",
str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
],
[
"Average num. words per input:",
words_perturbed_stats["avg_word_perturbed"],
],
]
num_queries = np.array(
[
r.num_queries
for r in self.results
if not isinstance(r, SkippedAttackResult)
]
summary_table_rows.append(
["Avg num queries:", attack_query_stats["avg_num_queries"]]
)
avg_num_queries = num_queries.mean()
avg_num_queries = str(round(avg_num_queries, 2))
summary_table_rows.append(["Avg num queries:", avg_num_queries])
if self.enable_advance_metrics:
perplexity_stats = Perplexity().calculate(self.results)
use_stats = USEMetric().calculate(self.results)
summary_table_rows.append(
[
"Average Original Perplexity:",
perplexity_stats["avg_original_perplexity"],
]
)
summary_table_rows.append(
[
"Average Attack Perplexity:",
perplexity_stats["avg_attack_perplexity"],
]
)
summary_table_rows.append(
["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
)
self.log_summary_rows(
summary_table_rows, "Attack Results", "attack_results_summary"
)
# Show histogram of words changed.
numbins = max(max_words_changed, 10)
numbins = max(words_perturbed_stats["max_words_changed"], 10)
for logger in self.loggers:
logger.log_hist(
num_words_changed_until_success[:numbins],
words_perturbed_stats["num_words_changed_until_success"][:numbins],
numbins=numbins,
title="Num Words Perturbed",
window_id="num_words_perturbed",

View File

@@ -0,0 +1,11 @@
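"""
Metrics package: exposes the ``Metric`` base class together with the attack metrics and quality metrics.
"""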
from .metric import Metric
from .attack_metrics import AttackSuccessRate
from .attack_metrics import WordsPerturbed
from .attack_metrics import AttackQueries
from .quality_metrics import Perplexity
from .quality_metrics import USEMetric

View File

@@ -0,0 +1,12 @@
"""
attack_metrics:
======================
TextAttack provides users with common metrics on the quality of attacks.
"""
from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed

View File

@@ -0,0 +1,41 @@
"""
Metrics on AttackQueries
=========================
"""
import numpy as np
from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric
class AttackQueries(Metric):
    """Query-count statistics computed over a collection of attack results."""

    def __init__(self):
        # Metric name -> value; filled in by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Compute every metric related to the number of queries in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the ``"avg_num_queries"`` entry set.
        """
        self.results = results

        # Skipped results are left out of the query statistics.
        query_counts = []
        for res in self.results:
            if isinstance(res, SkippedAttackResult):
                continue
            query_counts.append(res.num_queries)
        self.num_queries = np.array(query_counts)

        self.all_metrics["avg_num_queries"] = self.avg_num_queries()
        return self.all_metrics

    def avg_num_queries(self):
        """Return the mean of ``self.num_queries`` rounded to two decimals."""
        return round(self.num_queries.mean(), 2)

View File

@@ -0,0 +1,74 @@
"""
Metrics on AttackSuccessRate
=============================
"""
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
class AttackSuccessRate(Metric):
    """Counts successful, failed and skipped attacks and derives rates from them."""

    def __init__(self):
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0
        # Defined up front so the percentage helpers are callable before
        # ``calculate`` has run.
        self.total_attacks = 0
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates all metrics related to the number of successful, failed
        and skipped results in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the keys ``successful_attacks``,
            ``failed_attacks``, ``skipped_attacks``, ``original_accuracy``,
            ``attack_accuracy_perc`` and ``attack_success_rate``.
        """
        self.results = results
        self.total_attacks = len(self.results)

        # Start from zero on every call; otherwise calling ``calculate`` a
        # second time on the same instance would add to the previous tallies.
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0

        for result in self.results:
            if isinstance(result, FailedAttackResult):
                self.failed_attacks += 1
            elif isinstance(result, SkippedAttackResult):
                self.skipped_attacks += 1
            else:
                # Anything that is neither failed nor skipped counts as a success.
                self.successful_attacks += 1

        # Calculated numbers
        self.all_metrics["successful_attacks"] = self.successful_attacks
        self.all_metrics["failed_attacks"] = self.failed_attacks
        self.all_metrics["skipped_attacks"] = self.skipped_attacks

        # Percentages wrt the calculations
        self.all_metrics["original_accuracy"] = self.original_accuracy_perc()
        self.all_metrics["attack_accuracy_perc"] = self.attack_accuracy_perc()
        self.all_metrics["attack_success_rate"] = self.attack_success_rate_perc()

        return self.all_metrics

    def original_accuracy_perc(self):
        """Return the percentage of results that were not skipped, rounded to
        two decimals (0 when there are no results)."""
        if self.total_attacks == 0:
            return 0
        original_accuracy = (
            (self.total_attacks - self.skipped_attacks) * 100.0 / self.total_attacks
        )
        return round(original_accuracy, 2)

    def attack_accuracy_perc(self):
        """Return the percentage of results that are failed attacks, rounded to
        two decimals (0 when there are no results)."""
        if self.total_attacks == 0:
            return 0
        accuracy_under_attack = self.failed_attacks * 100.0 / self.total_attacks
        return round(accuracy_under_attack, 2)

    def attack_success_rate_perc(self):
        """Return successes as a percentage of successes plus failures, rounded
        to two decimals (0 when there were no attempted attacks)."""
        attempted = self.successful_attacks + self.failed_attacks
        if attempted == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = self.successful_attacks * 100.0 / attempted
        return round(attack_success_rate, 2)

View File

@@ -0,0 +1,85 @@
"""
Metrics on perturbed words
=============================
"""
import numpy as np
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
class WordsPerturbed(Metric):
    """Word-level perturbation statistics over a collection of attack results."""

    def __init__(self):
        self.total_attacks = 0
        self.all_num_words = None
        self.perturbed_word_percentages = None
        self.num_words_changed_until_success = 0
        self.all_metrics = {}

    def calculate(self, results):
        """Compute every metric related to perturbed words in an attack.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the keys ``avg_word_perturbed``,
            ``avg_word_perturbed_perc``, ``max_words_changed`` and
            ``num_words_changed_until_success``.
        """
        self.results = results
        result_count = len(self.results)
        self.total_attacks = result_count
        self.all_num_words = np.zeros(result_count)
        self.perturbed_word_percentages = np.zeros(result_count)
        # Histogram buckets: slot ``k - 1`` counts attacks that changed ``k`` words.
        self.num_words_changed_until_success = np.zeros(2 ** 16)
        self.max_words_changed = 0

        for idx, result in enumerate(self.results):
            original_text = result.original_result.attacked_text
            original_word_count = len(original_text.words)
            self.all_num_words[idx] = original_word_count

            # Only results that are neither failed nor skipped contribute
            # perturbation statistics.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue

            changed_words = original_text.all_words_diff(
                result.perturbed_result.attacked_text
            )
            num_words_changed = len(changed_words)
            self.num_words_changed_until_success[num_words_changed - 1] += 1
            self.max_words_changed = max(self.max_words_changed, num_words_changed)

            if original_word_count > 0:
                self.perturbed_word_percentages[idx] = (
                    num_words_changed * 100.0 / original_word_count
                )
            else:
                self.perturbed_word_percentages[idx] = 0

        self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num()
        self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc()
        self.all_metrics["max_words_changed"] = self.max_words_changed
        self.all_metrics[
            "num_words_changed_until_success"
        ] = self.num_words_changed_until_success

        return self.all_metrics

    def avg_number_word_perturbed_num(self):
        """Return the mean of ``self.all_num_words`` (the word count of each
        original input) rounded to two decimals."""
        return round(self.all_num_words.mean(), 2)

    def avg_perturbation_perc(self):
        """Return the mean of the non-zero perturbed-word percentages, rounded
        to two decimals.

        The filtered array is written back to
        ``self.perturbed_word_percentages``.
        """
        nonzero_percentages = self.perturbed_word_percentages[
            self.perturbed_word_percentages > 0
        ]
        self.perturbed_word_percentages = nonzero_percentages
        return round(nonzero_percentages.mean(), 2)

View File

@@ -0,0 +1,28 @@
"""
Metric Class
========================
"""
from abc import ABC, abstractmethod
class Metric(ABC):
    """Abstract base class for metrics that evaluate adversarial attack candidates."""

    @abstractmethod
    def __init__(self, **kwargs):
        """Set up a pre-built :class:`~textattack.Metric` corresponding to an
        evaluation metric for adversarial examples.

        Subclasses must override this; the base implementation always raises.
        """
        raise NotImplementedError

    @abstractmethod
    def calculate(self, results):
        """Abstract hook that computes the metric's values over a whole set of
        results.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,12 @@
"""
Metrics on Quality
======================
TextAttack provides users with common metrics on the quality of text examples.
"""
from .perplexity import Perplexity
from .use import USEMetric

View File

@@ -0,0 +1,93 @@
"""
Perplexity Metric:
======================
"""
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
import textattack.shared.utils
class Perplexity(Metric):
    """Average perplexity of original and perturbed texts of successful
    attacks, scored with the pre-trained ``gpt2`` language model."""

    def __init__(self):
        self.all_metrics = {}
        self.original_candidates = []
        self.successful_candidates = []
        self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.ppl_model.to(textattack.shared.utils.device)
        self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.ppl_model.eval()
        # Largest window fed to the model at once, and the step between windows.
        self.max_length = self.ppl_model.config.n_positions
        self.stride = 512

    def calculate(self, results):
        """Calculates average perplexity on all successful attacks using a
        pre-trained small GPT-2 model.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with the keys
            ``avg_original_perplexity``, ``original_perplexity_list``,
            ``avg_attack_perplexity`` and ``attack_perplexity_list``. The
            averages are ``0.0`` and the lists empty when no attack succeeded.
        """
        self.results = results
        self.original_candidates_ppl = []
        self.successful_candidates_ppl = []
        # Rebuild the candidate lists on every call so repeated calls on the
        # same instance do not score texts from an earlier call again.
        self.original_candidates = []
        self.successful_candidates = []

        for result in self.results:
            # Failed and skipped results are left out of the perplexity scores.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(
                result.original_result.attacked_text.text.lower()
            )
            self.successful_candidates.append(
                result.perturbed_result.attacked_text.text.lower()
            )

        ppl_orig = self.calc_ppl(self.original_candidates)
        ppl_attack = self.calc_ppl(self.successful_candidates)

        self.all_metrics["avg_original_perplexity"] = round(ppl_orig[0], 2)
        self.all_metrics["original_perplexity_list"] = ppl_orig[1]

        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack[0], 2)
        self.all_metrics["attack_perplexity_list"] = ppl_attack[1]

        return self.all_metrics

    def calc_ppl(self, texts):
        """Score each text with the language model.

        Args:
            texts (list[str]): Texts to score.

        Returns:
            tuple: ``(average perplexity, list of per-text perplexities)``.
            ``(0.0, [])`` is returned when ``texts`` is empty, instead of
            dividing by zero.
        """
        ppl_vals = []

        with torch.no_grad():
            for text in texts:
                eval_loss = []
                input_ids = torch.tensor(
                    self.ppl_tokenizer.encode(text, add_special_tokens=True)
                ).unsqueeze(0)
                # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
                for i in range(0, input_ids.size(1), self.stride):
                    begin_loc = max(i + self.stride - self.max_length, 0)
                    end_loc = min(i + self.stride, input_ids.size(1))
                    trg_len = end_loc - i
                    input_ids_t = input_ids[:, begin_loc:end_loc].to(
                        textattack.shared.utils.device
                    )
                    target_ids = input_ids_t.clone()
                    # -100 marks the leading context positions so that only the
                    # last ``trg_len`` positions of the window count as targets.
                    target_ids[:, :-trg_len] = -100

                    outputs = self.ppl_model(input_ids_t, labels=target_ids)
                    log_likelihood = outputs[0] * trg_len

                    eval_loss.append(log_likelihood)

                # After the final window ``end_loc`` equals the token count.
                ppl_vals.append(
                    torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
                )

        if not ppl_vals:
            return 0.0, ppl_vals
        return sum(ppl_vals) / len(ppl_vals), ppl_vals

View File

@@ -0,0 +1,44 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.metrics import Metric
class USEMetric(Metric):
    """Average ``UniversalSentenceEncoder`` similarity score between original
    and perturbed texts of successful attacks."""

    def __init__(self, **kwargs):
        self.use_obj = UniversalSentenceEncoder()
        # NOTE(review): the encoder's ``model`` attribute is deliberately set to
        # a second encoder instance; presumably ``_sim_score`` goes through
        # ``self.model`` -- confirm against the sentence-encoder base class
        # before changing this line.
        self.use_obj.model = UniversalSentenceEncoder()
        self.original_candidates = []
        self.successful_candidates = []
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``self.all_metrics`` with ``"avg_attack_use_score"`` set to
            the mean score rounded to two decimals, or ``0.0`` when no attack
            succeeded.
        """
        self.results = results
        # Rebuild the candidate lists on every call so repeated calls on the
        # same instance do not score pairs from an earlier call again.
        self.original_candidates = []
        self.successful_candidates = []

        for result in self.results:
            # Failed and skipped results are left out of the USE score.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(result.original_result.attacked_text)
            self.successful_candidates.append(result.perturbed_result.attacked_text)

        use_scores = [
            self.use_obj._sim_score(original, perturbed).item()
            for original, perturbed in zip(
                self.original_candidates, self.successful_candidates
            )
        ]

        # Without any successful attack there is nothing to average; report 0.0
        # instead of dividing by zero.
        if use_scores:
            avg_use_score = round(sum(use_scores) / len(use_scores), 2)
        else:
            avg_use_score = 0.0
        self.all_metrics["avg_attack_use_score"] = avg_use_score

        return self.all_metrics