Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

allow maximization goal functions

Author: uvafan
Date: 2020-06-23 23:33:48 -04:00
parent fe109267a1
commit 0fcfb51b7f
19 changed files with 115 additions and 78 deletions
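In effect, a goal function constructed with maximizable=True never reports success; the search keeps raising the goal score until the query budget or the transformations run out, and Attack.attack_one() wraps the best result found in the new MaximizedAttackResult. A rough usage sketch follows; the model, dataset, constraints, transformation, and search method are placeholders, and the import paths and Attack/attack_dataset call shapes are inferred from the surrounding API rather than taken from this diff, so treat it as an illustration only.

import textattack
from textattack.goal_functions import UntargetedClassification

# Placeholders: a real model/tokenizer, dataset, transformation, constraints,
# and search method would come from the usual TextAttack setup (not shown).
goal_function = UntargetedClassification(model, maximizable=True, query_budget=500)
attack = textattack.shared.Attack(goal_function, constraints, transformation, search_method)

for result in attack.attack_dataset(dataset):
    # With maximizable=True the goal status stays MAXIMIZING, so instead of
    # Successful/Failed results the attack yields MaximizedAttackResult objects
    # holding the best-scoring perturbation found within the query budget.
    if isinstance(result, textattack.attack_results.MaximizedAttackResult):
        print(result.perturbed_result.score, result.num_queries)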

View File

@@ -1,3 +1,4 @@
from .maximized_attack_result import MaximizedAttackResult
from .failed_attack_result import FailedAttackResult
from .skipped_attack_result import SkippedAttackResult
from .successful_attack_result import SuccessfulAttackResult

View File

@@ -13,7 +13,7 @@ class AttackResult:
perturbed text. May or may not have been successful.
"""
def __init__(self, original_result, perturbed_result, num_queries=0):
def __init__(self, original_result, perturbed_result):
if original_result is None:
raise ValueError("Attack original result cannot be None")
elif not isinstance(original_result, GoalFunctionResult):
@@ -27,7 +27,7 @@ class AttackResult:
self.original_result = original_result
self.perturbed_result = perturbed_result
self.num_queries = num_queries
self.num_queries = perturbed_result.num_queries
# We don't want the AttackedText attributes sticking around clogging up
# space on our devices. Delete them here, if they're still present,

View File

@@ -6,9 +6,9 @@ from .attack_result import AttackResult
class FailedAttackResult(AttackResult):
"""The result of a failed attack."""
def __init__(self, original_result, perturbed_result=None, num_queries=0):
def __init__(self, original_result, perturbed_result=None):
perturbed_result = perturbed_result or original_result
super().__init__(original_result, perturbed_result, num_queries)
super().__init__(original_result, perturbed_result)
def str_lines(self, color_method=None):
lines = (

View File

@@ -0,0 +1,5 @@
from .attack_result import AttackResult
class MaximizedAttackResult(AttackResult):
""" The result of a successful attack. """

View File

@@ -129,7 +129,7 @@ def run(args):
pbar.update()
num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult:
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1

View File

@@ -114,7 +114,7 @@ def run(args):
num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult:
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1

View File

@@ -1,4 +1,4 @@
from .goal_function_result import GoalFunctionResult
from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus
from .classification_goal_function_result import ClassificationGoalFunctionResult
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult

View File

@@ -1,5 +1,9 @@
import torch
class GoalFunctionResultStatus:
SUCCEEDED = 0
SEARCHING = 1 # In process of searching for a success
MAXIMIZING = 2
class GoalFunctionResult:
"""
@@ -8,16 +12,20 @@ class GoalFunctionResult:
Args:
attacked_text: The sequence that was evaluated.
output: The display-friendly output.
succeeded: Whether the goal has been achieved.
goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
score: A score representing how close the model is to achieving its goal.
num_queries: How many model queries have been used
ground_truth_output: The ground truth output
"""
def __init__(self, attacked_text, raw_output, output, succeeded, score):
def __init__(self, attacked_text, raw_output, output, goal_status, score, num_queries, ground_truth_output):
self.attacked_text = attacked_text
self.raw_output = raw_output
self.output = output
self.score = score
self.succeeded = succeeded
self.goal_status = goal_status
self.num_queries = num_queries
self.ground_truth_output = ground_truth_output
if isinstance(self.raw_output, torch.Tensor):
self.raw_output = self.raw_output.cpu()
@@ -25,9 +33,6 @@ class GoalFunctionResult:
if isinstance(self.score, torch.Tensor):
self.score = self.score.item()
if isinstance(self.succeeded, torch.Tensor):
self.succeeded = self.succeeded.item()
def get_text_color_input(self):
""" A string representing the color this result's changed
portion should be if it represents the original input.
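With succeeded replaced by goal_status, downstream code switches from a boolean test to a status comparison; a minimal sketch of the idiom (mirroring the search-method hunks further down, where result is any GoalFunctionResult returned by GoalFunction.get_results):

from textattack.goal_function_results import GoalFunctionResultStatus

if result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
    pass  # goal reached; Attack.attack_one() wraps this in a SuccessfulAttackResult
elif result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
    pass  # maximizable goal function; keep searching for a higher result.score
else:  # GoalFunctionResultStatus.SEARCHING
    pass  # goal not yet reached; a FailedAttackResult if the search gives up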

View File

@@ -7,16 +7,16 @@ class TargetedClassification(ClassificationGoalFunction):
score of the target label until it is the predicted label.
"""
def __init__(self, model, target_class=0):
super().__init__(model)
def __init__(self, *args, target_class=0, **kwargs):
super().__init__(*args, **kwargs)
self.target_class = target_class
def _is_goal_complete(self, model_output, ground_truth_output):
def _is_goal_complete(self, model_output):
return (
self.target_class == model_output.argmax()
) or ground_truth_output == self.target_class
) or self.ground_truth_output == self.target_class
def _get_score(self, model_output, _):
def _get_score(self, model_output):
if self.target_class < 0 or self.target_class >= len(model_output):
raise ValueError(
f"target class set to {self.target_class} with {len(model_output)} classes."

View File

@@ -16,23 +16,23 @@ class UntargetedClassification(ClassificationGoalFunction):
self.target_max_score = target_max_score
super().__init__(*args, **kwargs)
def _is_goal_complete(self, model_output, ground_truth_output):
def _is_goal_complete(self, model_output):
if self.target_max_score:
return model_output[ground_truth_output] < self.target_max_score
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
return abs(ground_truth_output - model_output.item()) >= (
return model_output[self.ground_truth_output] < self.target_max_score
elif (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
return abs(self.ground_truth_output - model_output.item()) >= (
self.target_max_score or 0.5
)
else:
return model_output.argmax() != ground_truth_output
return model_output.argmax() != self.ground_truth_output
def _get_score(self, model_output, ground_truth_output):
def _get_score(self, model_output):
# If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task.
if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
return abs(model_output.item() - ground_truth_output)
if (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
return abs(model_output.item() - self.ground_truth_output)
else:
return 1 - model_output[ground_truth_output]
return 1 - model_output[self.ground_truth_output]
def _get_displayed_output(self, raw_output):
return int(raw_output.argmax())
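One consequence of dropping the ground_truth_output argument is that custom goal functions now read it from self.ground_truth_output, which init_attack_example sets on the base class shown below. Here is a hypothetical subclass sketched against the new signatures; the class name, margin parameter, and import path are illustrative assumptions, not part of this commit:

from textattack.goal_functions import ClassificationGoalFunction

class MarginUntargetedClassification(ClassificationGoalFunction):
    """Hypothetical example: succeed once the true class's score drops below `margin`."""

    def __init__(self, *args, margin=0.25, **kwargs):
        super().__init__(*args, **kwargs)
        self.margin = margin

    def _is_goal_complete(self, model_output):
        # The ground truth now comes from the goal function itself.
        return model_output[self.ground_truth_output] < self.margin

    def _get_score(self, model_output):
        # Higher is better: how far the true-class score has been pushed down.
        return 1 - model_output[self.ground_truth_output]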

View File

@@ -4,26 +4,29 @@ import lru
import numpy as np
import torch
from textattack.goal_function_results.goal_function_result import GoalFunctionResultStatus
from textattack.shared import utils, validators
from textattack.shared.utils import batch_model_predict, default_class_repr
class GoalFunction:
"""
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args:
model: The PyTorch or TensorFlow model used for evaluation.
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
of success or failure.
query_budget: The maximum number of model queries allowed.
"""
def __init__(
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
self, model, maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf")
):
validators.validate_model_goal_function_compatibility(
self.__class__, model.__class__
)
self.model = model
self.maximizable = maximizable
self.tokenizer = tokenizer
if not self.tokenizer:
if hasattr(self.model, "tokenizer"):
@@ -33,20 +36,16 @@ class GoalFunction:
if not hasattr(self.tokenizer, "encode"):
raise TypeError("Tokenizer must contain `encode()` method")
self.use_cache = use_cache
self.num_queries = 0
self.query_budget = query_budget
if self.use_cache:
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
else:
self._call_model_cache = None
def should_skip(self, attacked_text, ground_truth_output):
"""
Returns whether or not the goal has already been completed for ``attacked_text``\,
due to misprediction by the model.
"""
model_outputs = self._call_model([attacked_text])
return self._is_goal_complete(model_outputs[0], ground_truth_output)
def init_attack_example(self, attacked_text, ground_truth_output):
self.initial_attacked_text = attacked_text
self.ground_truth_output = ground_truth_output
self.num_queries = 0
def get_output(self, attacked_text):
"""
@@ -54,16 +53,16 @@ class GoalFunction:
"""
return self._get_displayed_output(self._call_model([attacked_text])[0])
def get_result(self, attacked_text, ground_truth_output):
def get_result(self, attacked_text):
"""
A helper method that queries `self.get_results` with a single
``AttackedText`` object.
"""
results, search_over = self.get_results([attacked_text], ground_truth_output)
results, search_over = self.get_results([attacked_text])
result = results[0] if len(results) else None
return result, search_over
def get_results(self, attacked_text_list, ground_truth_output):
def get_results(self, attacked_text_list):
"""
For each attacked_text object in attacked_text_list, returns a result
consisting of whether or not the goal has been achieved, the output for
@@ -78,23 +77,32 @@ class GoalFunction:
model_outputs = self._call_model(attacked_text_list)
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
displayed_output = self._get_displayed_output(raw_output)
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
goal_function_score = self._get_score(raw_output, ground_truth_output)
goal_status = self._get_goal_status(raw_output)
goal_function_score = self._get_score(raw_output)
results.append(
self._goal_function_result_type()(
attacked_text,
raw_output,
displayed_output,
succeeded,
goal_status,
goal_function_score,
self.num_queries,
self.ground_truth_output,
)
)
return results, self.num_queries == self.query_budget
def _is_goal_complete(self, model_output, ground_truth_output):
def _get_goal_status(self, model_output):
if self.maximizable:
return GoalFunctionResultStatus.MAXIMIZING
if self._is_goal_complete(model_output):
return GoalFunctionResultStatus.SUCCEEDED
return GoalFunctionResultStatus.SEARCHING
def _is_goal_complete(self, model_output):
raise NotImplementedError()
def _get_score(self, model_output, ground_truth_output):
def _get_score(self, model_output):
raise NotImplementedError()
def _get_displayed_output(self, raw_output):
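Taken together, the per-example flow implied by this refactor looks roughly like the sketch below; the driver loop and the examples variable are assumed, while init_attack_example and the argument-free get_result come from this diff.

from textattack.shared import AttackedText

# `goal_function` and the (text, label) pairs are assumed to exist already.
for text, ground_truth_output in examples:
    attacked_text = AttackedText(text)
    # Stores the ground truth on the goal function and resets num_queries.
    goal_function.init_attack_example(attacked_text, ground_truth_output)
    # No ground-truth argument any more; search_over flags an exhausted query budget.
    result, search_over = goal_function.get_result(attacked_text)
    print(result.goal_status, result.score, result.num_queries)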

View File

@@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction):
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
"""
def _is_goal_complete(self, model_output, ground_truth_output):
return self._get_score(model_output, ground_truth_output) == 1.0
def _is_goal_complete(self, model_output):
return self._get_score(model_output, self.ground_truth_output) == 1.0
def _get_score(self, model_output, ground_truth_output):
num_words_diff = word_difference_score(model_output, ground_truth_output)
def _get_score(self, model_output):
num_words_diff = word_difference_score(model_output, self.ground_truth_output)
if num_words_diff == 0:
return 0.0
else:
return num_words_diff / len(get_words_cached(ground_truth_output))
return num_words_diff / len(get_words_cached(self.ground_truth_output))
@functools.lru_cache(maxsize=2 ** 12)

View File

@@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction):
original_output: the original output of the model
"""
def __init__(self, model):
super().__init__(model)
def _goal_function_result_type(self):
""" Returns the class of this goal function's results. """
return TextToTextGoalFunctionResult

View File

@@ -20,9 +20,8 @@ class CSVLogger(Logger):
self._flushed = True
def log_attack_result(self, result):
if isinstance(result, FailedAttackResult):
return
original_text, perturbed_text = result.diff_color(self.color_method)
result_type = result.__class__.__name__[:-12]
row = {
"original_text": original_text,
"perturbed_text": perturbed_text,
@@ -30,7 +29,9 @@ class CSVLogger(Logger):
"perturbed_score": result.perturbed_result.score,
"original_output": result.original_result.output,
"perturbed_output": result.perturbed_result.output,
"ground_truth_output": result.original_result.ground_truth_output,
"num_queries": result.num_queries,
"result_type": result_type
}
self.df = self.df.append(row, ignore_index=True)
self._flushed = False
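Since failed results are no longer skipped and each row now records result_type, ground_truth_output, and num_queries, the CSV can be broken down by outcome after a run. A small sketch, assuming the logger was pointed at a file named results.csv:

import pandas as pd

df = pd.read_csv("results.csv")
# result_type strips the "AttackResult" suffix from the class name, so the
# values are "Successful", "Failed", "Skipped", or "Maximized".
print(df.groupby("result_type")["num_queries"].describe())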

View File

@@ -1,5 +1,6 @@
import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
def _perform_search(self, initial_result):
beam = [initial_result.attacked_text]
best_result = initial_result
while not best_result.succeeded:
while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
potential_next_beam = []
for text in beam:
transformations = self.get_transformations(
@@ -33,7 +34,7 @@ class BeamSearch(SearchMethod):
# If we did not find any possible perturbations, give up.
return best_result
results, search_over = self.get_goal_results(
potential_next_beam, initial_result.output
potential_next_beam
)
scores = np.array([r.score for r in results])
best_result = results[scores.argmax()]

View File

@@ -10,6 +10,7 @@ from copy import deepcopy
import numpy as np
import torch
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
if not len(transformations):
return False
orig_result, self.search_over = self.get_goal_results(
[pop_member.attacked_text], self.correct_output
[pop_member.attacked_text]
)
if self.search_over:
return False
new_x_results, self.search_over = self.get_goal_results(
transformations, self.correct_output
transformations
)
new_x_scores = torch.Tensor([r.score for r in new_x_results])
new_x_scores = new_x_scores - orig_result[0].score
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
cur_score = initial_result.score
for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results(
[pm.attacked_text for pm in pop], self.correct_output
[pm.attacked_text for pm in pop]
)
if self.search_over:
if not len(pop_results):
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
logits = ((-pop_scores) / self.temp).exp()
select_probs = (logits / logits.sum()).cpu().numpy()
if pop[0].result.succeeded:
if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return pop[0].result
if pop[0].result.score > cur_score:

View File

@@ -9,6 +9,7 @@ See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps
@@ -33,7 +34,7 @@ class GreedyWordSwapWIR(SearchMethod):
ranks in order of descending score.
"""
leave_one_results, search_over = self.get_goal_results(
texts, initial_result.output
texts
)
leave_one_scores = np.array([result.score for result in leave_one_results])
return leave_one_scores, search_over
@@ -78,7 +79,7 @@ class GreedyWordSwapWIR(SearchMethod):
if len(transformed_text_candidates) == 0:
continue
results, search_over = self.get_goal_results(
transformed_text_candidates, initial_result.output
transformed_text_candidates
)
results = sorted(results, key=lambda x: -x.score)
# Skip swaps which don't improve the score
@@ -87,12 +88,12 @@ class GreedyWordSwapWIR(SearchMethod):
else:
continue
# If we succeeded, return the index with best similarity.
if cur_result.succeeded:
if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
best_result = cur_result
# @TODO: Use vectorwise operations
max_similarity = -float("inf")
for result in results:
if not result.succeeded:
if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
break
candidate = result.attacked_text
try:

View File

@@ -6,10 +6,12 @@ import numpy as np
import textattack
from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils
@@ -170,17 +172,24 @@ class Attack:
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
Returns:
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
Either a ``SuccessfulAttackResult``, ``FailedAttackResult``,
or ``MaximizedAttackResult``.
"""
final_result = self.search_method(initial_result)
if final_result.succeeded:
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(
initial_result, final_result, self.goal_function.num_queries
initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(
initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(
initial_result, final_result,
)
else:
return FailedAttackResult(
initial_result, final_result, self.goal_function.num_queries
)
raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
def _get_examples_from_dataset(self, dataset, indices=None):
"""
@@ -212,14 +221,10 @@ class Attack:
attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names}
)
self.goal_function.num_queries = 0
self.goal_function.init_attack_example(attacked_text, ground_truth_output)
goal_function_result, _ = self.goal_function.get_result(
attacked_text, ground_truth_output
attacked_text
)
if goal_function_result.succeeded:
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield goal_function_result
except IndexError:
@@ -240,7 +245,7 @@ class Attack:
examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result in examples:
if goal_function_result.succeeded:
if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
yield SkippedAttackResult(goal_function_result)
else:
result = self.attack_one(goal_function_result)

View File

@@ -5,6 +5,7 @@ import pickle
import time
from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
@@ -101,6 +102,11 @@ class Checkpoint:
f"(Number of failed attacks): {self.num_failed_attacks}", 2
)
)
breakdown_lines.append(
utils.add_indent(
f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
)
)
breakdown_lines.append(
utils.add_indent(
f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
@@ -140,6 +146,12 @@ class Checkpoint:
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
)
@property
def num_maximized_attacks(self):
return sum(
isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
)
@property
def num_remaining_attacks(self):
if self.args.attack_n: