mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

[CODE] Fix metrics, add tests

sanchit97 committed 2021-09-10 02:33:50 -04:00
commit 32c3e43adc
parent 10ee24b8da
8 changed files with 209 additions and 43 deletions

View File: tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt

@@ -0,0 +1,74 @@
/.*/Attack(
(search_method): BeamSearch(
(beam_width): 10
)
(goal_function): UntargetedClassification
(transformation): WordSwapGradientBased(
(top_n): 1
)
(constraints):
(0): MaxWordsPerturbed(
(max_num_words): 2
(compare_against_original): True
)
(1): WordEmbeddingDistance(
(embedding): WordEmbedding
(min_cos_sim): 0.8
(cased): False
(include_unknown_words): True
(compare_against_original): True
)
(2): PartOfSpeech(
(tagger_type): nltk
(tagset): universal
(allow_verb_noun_swap): True
(compare_against_original): True
)
(3): RepeatModification
(4): StopwordModification
(is_black_box): False
)
--------------------------------------------- Result 1 ---------------------------------------------
[[Positive (96%)]] --> [[Negative (77%)]]
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[supplies]] with tremendous skill .
the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[stagnated]] with tremendous skill .
--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (57%)]] --> [[[SKIPPED]]]
red dragon " never cuts corners .
--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (51%)]] --> [[[FAILED]]]
fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense .
--------------------------------------------- Result 4 ---------------------------------------------
[[Positive (89%)]] --> [[[FAILED]]]
throws in enough clever and unexpected twists to make the formula feel fresh .
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 1 |
| Number of failed attacks: | 2 |
| Number of skipped attacks: | 1 |
| Original accuracy: | 75.0% |
| Accuracy under attack: | 50.0% |
| Attack success rate: | 33.33% |
| Average perturbed word %: | 5.56% |
| Average num. words per input: | 15.5 |
| Avg num queries: | 1.33 |
| Average Original Perplexity: | 291.47 |
| Average Attack Perplexity: | 320.33 |
| Average Attack USE Score: | 0.91 |
+-------------------------------+--------+
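As a sanity check on the table above (plain arithmetic, not TextAttack code): 1 success, 2 failures and 1 skip over 4 examples give exactly these percentages.

successful, failed, skipped = 1, 2, 1
total = successful + failed + skipped

original_accuracy = (total - skipped) / total * 100              # 75.0  (skipped = model was already wrong)
accuracy_under_attack = failed / total * 100                     # 50.0  (model stays correct only where the attack failed)
attack_success_rate = successful / (successful + failed) * 100   # 33.33 (skipped examples excluded)
print(original_accuracy, accuracy_under_attack, round(attack_success_rate, 2))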

View File: tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt

@@ -0,0 +1,68 @@
/.*/Attack(
(search_method): GreedyWordSwapWIR(
(wir_method): unk
)
(goal_function): UntargetedClassification
(transformation): CompositeTransformation(
(0): WordSwapNeighboringCharacterSwap(
(random_one): True
)
(1): WordSwapRandomCharacterSubstitution(
(random_one): True
)
(2): WordSwapRandomCharacterDeletion(
(random_one): True
)
(3): WordSwapRandomCharacterInsertion(
(random_one): True
)
)
(constraints):
(0): LevenshteinEditDistance(
(max_edit_distance): 30
(compare_against_original): True
)
(1): RepeatModification
(2): StopwordModification
(is_black_box): True
)
--------------------------------------------- Result 1 ---------------------------------------------
[[Negative (100%)]] --> [[Positive (71%)]]
[[hide]] [[new]] secretions from the parental units
[[Ehide]] [[enw]] secretions from the parental units
--------------------------------------------- Result 2 ---------------------------------------------
[[Negative (100%)]] --> [[[FAILED]]]
contains no wit , only labored gags
--------------------------------------------- Result 3 ---------------------------------------------
[[Positive (100%)]] --> [[Negative (96%)]]
that [[loves]] its characters and communicates [[something]] [[rather]] [[beautiful]] about human nature
that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beautifdul]] about human nature
+-------------------------------+---------+
| Attack Results | |
+-------------------------------+---------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 30.95% |
| Average num. words per input: | 8.33 |
| Avg num queries: | 22.67 |
| Average Original Perplexity: | 1126.57 |
| Average Attack Perplexity: | 2823.86 |
| Average Attack USE Score: | 0.76 |
+-------------------------------+---------+
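This second output corresponds to the new attack_from_transformers test added in the test file below. A rough Python-API equivalent of that CLI command (a sketch; HuggingFaceModelWrapper, HuggingFaceDataset, AttackArgs and Attacker are the standard TextAttack 0.3-era entry points, and enable_advance_metrics is assumed to mirror the --enable-advance-metrics flag):

import transformers
import textattack

# Sketch of the deepwordbug run shown above, via the Python API.
name = "distilbert-base-uncased-finetuned-sst-2-english"
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)

# enable_advance_metrics here is an assumption mirroring the CLI flag.
attack_args = textattack.AttackArgs(num_examples=3, enable_advance_metrics=True)
results = textattack.Attacker(attack, dataset, attack_args).attack_dataset()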

View File: tests/test_command_line/test_attack.py

@@ -48,6 +48,20 @@ attack_test_params = [
         "tests/sample_outputs/run_attack_transformers_datasets.txt",
     ),
+    #
+    # test loading an attack from the transformers model hub and calculating perplexity and USE
+    #
+    (
+        "attack_from_transformers",
+        (
+            "textattack attack --model-from-huggingface "
+            "distilbert-base-uncased-finetuned-sst-2-english "
+            "--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
+            "--enable-advance-metrics"
+            ""
+        ),
+        "tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt",
+    ),
     #
     # test running an attack by loading a model and dataset from file
     #
     (
@@ -72,6 +86,17 @@ attack_test_params = [
         "tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
     ),
+    #
+    # test hotflip on 4 samples from LSTM MR and calculating perplexity and USE
+    #
+    (
+        "run_attack_hotflip_lstm_mr_4",
+        (
+            "textattack attack --model lstm-mr --recipe hotflip "
+            "--num-examples 4 --num-examples-offset 3 --enable-advance-metrics "
+        ),
+        "tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt",
+    ),
     #
     # test: run_attack deepwordbug attack on 10 samples from LSTM MR
     #
     (
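The new parametrized cases should be selectable in isolation with pytest's -k filter (hypothetical invocation; it assumes the tuple names above end up as the pytest test ids):

import pytest

# Hypothetical: run only the two new cases by (assumed) test id.
pytest.main(
    [
        "tests/test_command_line/test_attack.py",
        "-k",
        "attack_from_transformers or run_attack_hotflip_lstm_mr_4",
    ]
)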

View File: textattack/loggers/attack_log_manager.py

@@ -8,10 +8,7 @@ from textattack.metrics.attack_metrics import (
     AttackSuccessRate,
     WordsPerturbed,
 )
-from textattack.metrics.quality_metrics import (
-    Perplexity,
-    USEMetric
-)
+from textattack.metrics.quality_metrics import Perplexity, USEMetric

 from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
@@ -123,17 +120,23 @@ class AttackLogManager:
         if self.enable_advance_metrics:
             perplexity_stats = Perplexity().calculate(self.results)
-            use_stats = USEMetric(**{"large":False}).calculate(self.results)
-            print(use_stats)
+            use_stats = USEMetric().calculate(self.results)

             summary_table_rows.append(
                 [
-                    "Avg Original Perplexity:",
+                    "Average Original Perplexity:",
                     perplexity_stats["avg_original_perplexity"],
                 ]
             )
             summary_table_rows.append(
-                ["Avg Attack USE Score:", use_stats["avg_attack_use_score"]]
+                [
+                    "Average Attack Perplexity:",
+                    perplexity_stats["avg_attack_perplexity"],
+                ]
+            )
+            summary_table_rows.append(
+                ["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
             )

         self.log_summary_rows(
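Standalone, the two quality-metric calls the logger now makes look like this (a sketch; `results` is any list of AttackResult objects from a finished run, and the dict keys are the ones read above):

from textattack.metrics.quality_metrics import Perplexity, USEMetric

# `results` would come from e.g. Attacker.attack_dataset().
perplexity_stats = Perplexity().calculate(results)
use_stats = USEMetric().calculate(results)

print(perplexity_stats["avg_original_perplexity"])  # e.g. 291.47 in the sample output above
print(perplexity_stats["avg_attack_perplexity"])    # e.g. 320.33
print(use_stats["avg_attack_use_score"])            # e.g. 0.91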

View File: textattack/metrics/__init__.py

@@ -8,4 +8,4 @@ from .attack_metrics import WordsPerturbed
 from .attack_metrics import AttackQueries
 from .quality_metrics import Perplexity
-from .quality_metrics import USEMetric
\ No newline at end of file
+from .quality_metrics import USEMetric

View File: textattack/metrics/attack_metrics/attack_success_rate.py

@@ -20,7 +20,7 @@ class AttackSuccessRate(Metric):
def calculate(self, results):
self.results = results
self.total_attacks = len(self.results)
for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
self.failed_attacks += 1

View File: textattack/metrics/quality_metrics/perplexity.py

@@ -49,46 +49,42 @@ class Perplexity(Metric):
         ppl_orig = self.calc_ppl(self.original_candidates)
         ppl_attack = self.calc_ppl(self.successful_candidates)

-        self.all_metrics["avg_original_perplexity"] = round(
-            ppl_orig[0], 2
-        )
+        self.all_metrics["avg_original_perplexity"] = round(ppl_orig[0], 2)
         self.all_metrics["original_perplexity_list"] = ppl_orig[1]

-        self.all_metrics["avg_attack_perplexity"] = round(
-            ppl_attack[0], 2
-        )
+        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack[0], 2)
         self.all_metrics["attack_perplexity_list"] = ppl_attack[1]

         return self.all_metrics

     def calc_ppl(self, texts):
         ppl_vals = []

         with torch.no_grad():
             for text in texts:
                 eval_loss = []
                 input_ids = torch.tensor(
-                    self.ppl_tokenizer.encode(
-                        text, add_special_tokens=True
-                    )
+                    self.ppl_tokenizer.encode(text, add_special_tokens=True)
                 ).unsqueeze(0)
                 # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
                 for i in range(0, input_ids.size(1), self.stride):
                     begin_loc = max(i + self.stride - self.max_length, 0)
                     end_loc = min(i + self.stride, input_ids.size(1))
                     trg_len = end_loc - i
-                    input_ids_t = input_ids[:,begin_loc:end_loc].to(textattack.shared.utils.device)
+                    input_ids_t = input_ids[:, begin_loc:end_loc].to(
+                        textattack.shared.utils.device
+                    )
                     target_ids = input_ids_t.clone()
-                    target_ids[:,:-trg_len] = -100
+                    target_ids[:, :-trg_len] = -100

                     outputs = self.ppl_model(input_ids_t, labels=target_ids)
                     log_likelihood = outputs[0] * trg_len
                     eval_loss.append(log_likelihood)

-                ppl_vals.append(torch.exp(torch.stack(eval_loss).sum() / end_loc).item())
+                ppl_vals.append(
+                    torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
+                )

-        return sum(ppl_vals)/len(ppl_vals), ppl_vals
+        return sum(ppl_vals) / len(ppl_vals), ppl_vals
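The loop above is the strided evaluation described at huggingface.co/transformers/perplexity.html. A self-contained version outside TextAttack (a sketch; GPT-2 as the scoring model is an assumption here, the window logic is taken from the diff):

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()  # assumed scoring LM
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
max_length = model.config.n_positions  # 1024 for GPT-2
stride = 512

def perplexity(text):
    input_ids = torch.tensor(
        tokenizer.encode(text, add_special_tokens=True)
    ).unsqueeze(0)
    nlls = []
    with torch.no_grad():
        for i in range(0, input_ids.size(1), stride):
            begin_loc = max(i + stride - max_length, 0)
            end_loc = min(i + stride, input_ids.size(1))
            trg_len = end_loc - i  # tokens actually scored in this window
            window = input_ids[:, begin_loc:end_loc]
            target_ids = window.clone()
            target_ids[:, :-trg_len] = -100  # context-only positions are ignored by the loss
            loss = model(window, labels=target_ids)[0]
            nlls.append(loss * trg_len)  # un-average the cross-entropy
    # exp(total NLL / token count) = perplexity
    return torch.exp(torch.stack(nlls).sum() / end_loc).item()

print(perplexity("the story gives ample opportunity for large-scale action and suspense ."))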

View File: textattack/metrics/quality_metrics/use.py

@@ -3,11 +3,13 @@ from textattack.metrics import Metric
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder


 class USEMetric(Metric):
-    """Constraint using similarity between sentence encodings of x and x_adv
-    where the text embeddings are created using the Universal Sentence
-    Encoder."""
+    """Calculates average USE similarity on all successful attacks.
+
+    Args:
+        results (:obj:`list` of :class:`~textattack.goal_function_results.GoalFunctionResult`):
+            Attack results for each instance in dataset
+    """

     def __init__(self, **kwargs):
         self.use_obj = UniversalSentenceEncoder()
@@ -16,7 +18,6 @@ class USEMetric(Metric):
         self.successful_candidates = []
         self.all_metrics = {}

-
     def calculate(self, results):
         self.results = results
@@ -26,20 +27,19 @@ class USEMetric(Metric):
             elif isinstance(result, SkippedAttackResult):
                 continue
             else:
-                self.original_candidates.append(
-                    result.original_result.attacked_text
-                )
-                self.successful_candidates.append(
-                    result.perturbed_result.attacked_text
-                )
+                self.original_candidates.append(result.original_result.attacked_text)
+                self.successful_candidates.append(result.perturbed_result.attacked_text)

         use_scores = []
         for c in range(len(self.original_candidates)):
-            use_scores.append(self.use_obj._sim_score(self.original_candidates[c],self.successful_candidates[c]).item())
-        print(use_scores)
+            use_scores.append(
+                self.use_obj._sim_score(
+                    self.original_candidates[c], self.successful_candidates[c]
+                ).item()
+            )

-        self.all_metrics['avg_attack_use_score'] = round(sum(use_scores)/len(use_scores),2)
+        self.all_metrics["avg_attack_use_score"] = round(
+            sum(use_scores) / len(use_scores), 2
+        )

-        return self.all_metrics
\ No newline at end of file
+        return self.all_metrics
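What USEMetric scores, pair by pair, is the USE constraint's similarity between original and perturbed text. A minimal illustration (hedged: _sim_score is internal API, reused here only because the metric itself calls it), using the first successful pair from the deepwordbug output above:

from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.shared import AttackedText

use = UniversalSentenceEncoder()

# First successful pair from the deepwordbug sample output above.
original = AttackedText("hide new secretions from the parental units")
perturbed = AttackedText("Ehide enw secretions from the parental units")

# _sim_score returns a torch scalar; USEMetric averages these over all
# successful attacks and rounds the mean to two decimals.
print(use._sim_score(original, perturbed).item())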