mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
allow maximization goal functions
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from .maximized_attack_result import MaximizedAttackResult
|
||||||
from .failed_attack_result import FailedAttackResult
|
from .failed_attack_result import FailedAttackResult
|
||||||
from .skipped_attack_result import SkippedAttackResult
|
from .skipped_attack_result import SkippedAttackResult
|
||||||
from .successful_attack_result import SuccessfulAttackResult
|
from .successful_attack_result import SuccessfulAttackResult
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class AttackResult:
|
|||||||
perturbed text. May or may not have been successful.
|
perturbed text. May or may not have been successful.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, original_result, perturbed_result, num_queries=0):
|
def __init__(self, original_result, perturbed_result):
|
||||||
if original_result is None:
|
if original_result is None:
|
||||||
raise ValueError("Attack original result cannot be None")
|
raise ValueError("Attack original result cannot be None")
|
||||||
elif not isinstance(original_result, GoalFunctionResult):
|
elif not isinstance(original_result, GoalFunctionResult):
|
||||||
@@ -27,7 +27,7 @@ class AttackResult:
|
|||||||
|
|
||||||
self.original_result = original_result
|
self.original_result = original_result
|
||||||
self.perturbed_result = perturbed_result
|
self.perturbed_result = perturbed_result
|
||||||
self.num_queries = num_queries
|
self.num_queries = perturbed_result.num_queries
|
||||||
|
|
||||||
# We don't want the AttackedText attributes sticking around clogging up
|
# We don't want the AttackedText attributes sticking around clogging up
|
||||||
# space on our devices. Delete them here, if they're still present,
|
# space on our devices. Delete them here, if they're still present,
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ from .attack_result import AttackResult
|
|||||||
class FailedAttackResult(AttackResult):
|
class FailedAttackResult(AttackResult):
|
||||||
"""The result of a failed attack."""
|
"""The result of a failed attack."""
|
||||||
|
|
||||||
def __init__(self, original_result, perturbed_result=None, num_queries=0):
|
def __init__(self, original_result, perturbed_result=None):
|
||||||
perturbed_result = perturbed_result or original_result
|
perturbed_result = perturbed_result or original_result
|
||||||
super().__init__(original_result, perturbed_result, num_queries)
|
super().__init__(original_result, perturbed_result)
|
||||||
|
|
||||||
def str_lines(self, color_method=None):
|
def str_lines(self, color_method=None):
|
||||||
lines = (
|
lines = (
|
||||||
|
|||||||
5
textattack/attack_results/maximized_attack_result.py
Normal file
5
textattack/attack_results/maximized_attack_result.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .attack_result import AttackResult
|
||||||
|
|
||||||
|
|
||||||
|
class MaximizedAttackResult(AttackResult):
|
||||||
|
""" The result of an attack whose goal function score was maximized. """
|
||||||
@@ -129,7 +129,7 @@ def run(args):
|
|||||||
pbar.update()
|
pbar.update()
|
||||||
num_results += 1
|
num_results += 1
|
||||||
|
|
||||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
|
||||||
num_successes += 1
|
num_successes += 1
|
||||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||||
num_failures += 1
|
num_failures += 1
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ def run(args):
|
|||||||
|
|
||||||
num_results += 1
|
num_results += 1
|
||||||
|
|
||||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
|
||||||
num_successes += 1
|
num_successes += 1
|
||||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||||
num_failures += 1
|
num_failures += 1
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .goal_function_result import GoalFunctionResult
|
from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus
|
||||||
|
|
||||||
from .classification_goal_function_result import ClassificationGoalFunctionResult
|
from .classification_goal_function_result import ClassificationGoalFunctionResult
|
||||||
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult
|
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
class GoalFunctionResultStatus:
|
||||||
|
SUCCEEDED = 0
|
||||||
|
SEARCHING = 1 # In process of searching for a success
|
||||||
|
MAXIMIZING = 2
|
||||||
|
|
||||||
class GoalFunctionResult:
|
class GoalFunctionResult:
|
||||||
"""
|
"""
|
||||||
@@ -8,16 +12,20 @@ class GoalFunctionResult:
|
|||||||
Args:
|
Args:
|
||||||
attacked_text: The sequence that was evaluated.
|
attacked_text: The sequence that was evaluated.
|
||||||
output: The display-friendly output.
|
output: The display-friendly output.
|
||||||
succeeded: Whether the goal has been achieved.
|
goal_status: A ``GoalFunctionResultStatus`` value indicating whether the goal has succeeded, is still being searched for, or is being maximized.
|
||||||
score: A score representing how close the model is to achieving its goal.
|
score: A score representing how close the model is to achieving its goal.
|
||||||
|
num_queries: The number of model queries used when this result was produced.
|
||||||
|
ground_truth_output: The ground truth output
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, attacked_text, raw_output, output, succeeded, score):
|
def __init__(self, attacked_text, raw_output, output, goal_status, score, num_queries, ground_truth_output):
|
||||||
self.attacked_text = attacked_text
|
self.attacked_text = attacked_text
|
||||||
self.raw_output = raw_output
|
self.raw_output = raw_output
|
||||||
self.output = output
|
self.output = output
|
||||||
self.score = score
|
self.score = score
|
||||||
self.succeeded = succeeded
|
self.goal_status = goal_status
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.ground_truth_output = ground_truth_output
|
||||||
|
|
||||||
if isinstance(self.raw_output, torch.Tensor):
|
if isinstance(self.raw_output, torch.Tensor):
|
||||||
self.raw_output = self.raw_output.cpu()
|
self.raw_output = self.raw_output.cpu()
|
||||||
@@ -25,9 +33,6 @@ class GoalFunctionResult:
|
|||||||
if isinstance(self.score, torch.Tensor):
|
if isinstance(self.score, torch.Tensor):
|
||||||
self.score = self.score.item()
|
self.score = self.score.item()
|
||||||
|
|
||||||
if isinstance(self.succeeded, torch.Tensor):
|
|
||||||
self.succeeded = self.succeeded.item()
|
|
||||||
|
|
||||||
def get_text_color_input(self):
|
def get_text_color_input(self):
|
||||||
""" A string representing the color this result's changed
|
""" A string representing the color this result's changed
|
||||||
portion should be if it represents the original input.
|
portion should be if it represents the original input.
|
||||||
|
|||||||
@@ -7,16 +7,16 @@ class TargetedClassification(ClassificationGoalFunction):
|
|||||||
score of the target label until it is the predicted label.
|
score of the target label until it is the predicted label.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, target_class=0):
|
def __init__(self, *args, target_class=0, **kwargs):
|
||||||
super().__init__(model)
|
super().__init__(*args, **kwargs)
|
||||||
self.target_class = target_class
|
self.target_class = target_class
|
||||||
|
|
||||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
def _is_goal_complete(self, model_output):
|
||||||
return (
|
return (
|
||||||
self.target_class == model_output.argmax()
|
self.target_class == model_output.argmax()
|
||||||
) or ground_truth_output == self.target_class
|
) or self.ground_truth_output == self.target_class
|
||||||
|
|
||||||
def _get_score(self, model_output, _):
|
def _get_score(self, model_output):
|
||||||
if self.target_class < 0 or self.target_class >= len(model_output):
|
if self.target_class < 0 or self.target_class >= len(model_output):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"target class set to {self.target_class} with {len(model_output)} classes."
|
f"target class set to {self.target_class} with {len(model_output)} classes."
|
||||||
|
|||||||
@@ -16,23 +16,23 @@ class UntargetedClassification(ClassificationGoalFunction):
|
|||||||
self.target_max_score = target_max_score
|
self.target_max_score = target_max_score
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
def _is_goal_complete(self, model_output):
|
||||||
if self.target_max_score:
|
if self.target_max_score:
|
||||||
return model_output[ground_truth_output] < self.target_max_score
|
return model_output[self.ground_truth_output] < self.target_max_score
|
||||||
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
elif (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
|
||||||
return abs(ground_truth_output - model_output.item()) >= (
|
return abs(self.ground_truth_output - model_output.item()) >= (
|
||||||
self.target_max_score or 0.5
|
self.target_max_score or 0.5
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return model_output.argmax() != ground_truth_output
|
return model_output.argmax() != self.ground_truth_output
|
||||||
|
|
||||||
def _get_score(self, model_output, ground_truth_output):
|
def _get_score(self, model_output):
|
||||||
# If the model outputs a single number and the ground truth output is
|
# If the model outputs a single number and the ground truth output is
|
||||||
# a float, we assume that this is a regression task.
|
# a float, we assume that this is a regression task.
|
||||||
if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
if (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
|
||||||
return abs(model_output.item() - ground_truth_output)
|
return abs(model_output.item() - self.ground_truth_output)
|
||||||
else:
|
else:
|
||||||
return 1 - model_output[ground_truth_output]
|
return 1 - model_output[self.ground_truth_output]
|
||||||
|
|
||||||
def _get_displayed_output(self, raw_output):
|
def _get_displayed_output(self, raw_output):
|
||||||
return int(raw_output.argmax())
|
return int(raw_output.argmax())
|
||||||
|
|||||||
@@ -4,26 +4,29 @@ import lru
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from textattack.goal_function_results.goal_function_result import GoalFunctionResultStatus
|
||||||
from textattack.shared import utils, validators
|
from textattack.shared import utils, validators
|
||||||
from textattack.shared.utils import batch_model_predict, default_class_repr
|
from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||||
|
|
||||||
|
|
||||||
class GoalFunction:
|
class GoalFunction:
|
||||||
"""
|
"""
|
||||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The PyTorch or TensorFlow model used for evaluation.
|
model: The PyTorch or TensorFlow model used for evaluation.
|
||||||
|
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
|
||||||
|
of success or failure.
|
||||||
query_budget: The maximum number of model queries allowed.
|
query_budget: The maximum number of model queries allowed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
|
self, model, maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||||
):
|
):
|
||||||
validators.validate_model_goal_function_compatibility(
|
validators.validate_model_goal_function_compatibility(
|
||||||
self.__class__, model.__class__
|
self.__class__, model.__class__
|
||||||
)
|
)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.maximizable = maximizable
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
if not self.tokenizer:
|
if not self.tokenizer:
|
||||||
if hasattr(self.model, "tokenizer"):
|
if hasattr(self.model, "tokenizer"):
|
||||||
@@ -33,20 +36,16 @@ class GoalFunction:
|
|||||||
if not hasattr(self.tokenizer, "encode"):
|
if not hasattr(self.tokenizer, "encode"):
|
||||||
raise TypeError("Tokenizer must contain `encode()` method")
|
raise TypeError("Tokenizer must contain `encode()` method")
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.num_queries = 0
|
|
||||||
self.query_budget = query_budget
|
self.query_budget = query_budget
|
||||||
if self.use_cache:
|
if self.use_cache:
|
||||||
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
|
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
|
||||||
else:
|
else:
|
||||||
self._call_model_cache = None
|
self._call_model_cache = None
|
||||||
|
|
||||||
def should_skip(self, attacked_text, ground_truth_output):
|
def init_attack_example(self, attacked_text, ground_truth_output):
|
||||||
"""
|
self.initial_attacked_text = attacked_text
|
||||||
Returns whether or not the goal has already been completed for ``attacked_text``\,
|
self.ground_truth_output = ground_truth_output
|
||||||
due to misprediction by the model.
|
self.num_queries = 0
|
||||||
"""
|
|
||||||
model_outputs = self._call_model([attacked_text])
|
|
||||||
return self._is_goal_complete(model_outputs[0], ground_truth_output)
|
|
||||||
|
|
||||||
def get_output(self, attacked_text):
|
def get_output(self, attacked_text):
|
||||||
"""
|
"""
|
||||||
@@ -54,16 +53,16 @@ class GoalFunction:
|
|||||||
"""
|
"""
|
||||||
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
||||||
|
|
||||||
def get_result(self, attacked_text, ground_truth_output):
|
def get_result(self, attacked_text):
|
||||||
"""
|
"""
|
||||||
A helper method that queries `self.get_results` with a single
|
A helper method that queries `self.get_results` with a single
|
||||||
``AttackedText`` object.
|
``AttackedText`` object.
|
||||||
"""
|
"""
|
||||||
results, search_over = self.get_results([attacked_text], ground_truth_output)
|
results, search_over = self.get_results([attacked_text])
|
||||||
result = results[0] if len(results) else None
|
result = results[0] if len(results) else None
|
||||||
return result, search_over
|
return result, search_over
|
||||||
|
|
||||||
def get_results(self, attacked_text_list, ground_truth_output):
|
def get_results(self, attacked_text_list):
|
||||||
"""
|
"""
|
||||||
For each attacked_text object in attacked_text_list, returns a result
|
For each attacked_text object in attacked_text_list, returns a result
|
||||||
consisting of whether or not the goal has been achieved, the output for
|
consisting of whether or not the goal has been achieved, the output for
|
||||||
@@ -78,23 +77,32 @@ class GoalFunction:
|
|||||||
model_outputs = self._call_model(attacked_text_list)
|
model_outputs = self._call_model(attacked_text_list)
|
||||||
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
||||||
displayed_output = self._get_displayed_output(raw_output)
|
displayed_output = self._get_displayed_output(raw_output)
|
||||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
goal_status = self._get_goal_status(raw_output)
|
||||||
goal_function_score = self._get_score(raw_output, ground_truth_output)
|
goal_function_score = self._get_score(raw_output)
|
||||||
results.append(
|
results.append(
|
||||||
self._goal_function_result_type()(
|
self._goal_function_result_type()(
|
||||||
attacked_text,
|
attacked_text,
|
||||||
raw_output,
|
raw_output,
|
||||||
displayed_output,
|
displayed_output,
|
||||||
succeeded,
|
goal_status,
|
||||||
goal_function_score,
|
goal_function_score,
|
||||||
|
self.num_queries,
|
||||||
|
self.ground_truth_output,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return results, self.num_queries == self.query_budget
|
return results, self.num_queries == self.query_budget
|
||||||
|
|
||||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
def _get_goal_status(self, model_output):
|
||||||
|
if self.maximizable:
|
||||||
|
return GoalFunctionResultStatus.MAXIMIZING
|
||||||
|
if self._is_goal_complete(model_output):
|
||||||
|
return GoalFunctionResultStatus.SUCCEEDED
|
||||||
|
return GoalFunctionResultStatus.SEARCHING
|
||||||
|
|
||||||
|
def _is_goal_complete(self, model_output):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _get_score(self, model_output, ground_truth_output):
|
def _get_score(self, model_output):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _get_displayed_output(self, raw_output):
|
def _get_displayed_output(self, raw_output):
|
||||||
|
|||||||
@@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction):
|
|||||||
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
|
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
def _is_goal_complete(self, model_output):
|
||||||
return self._get_score(model_output, ground_truth_output) == 1.0
|
return self._get_score(model_output, self.ground_truth_output) == 1.0
|
||||||
|
|
||||||
def _get_score(self, model_output, ground_truth_output):
|
def _get_score(self, model_output):
|
||||||
num_words_diff = word_difference_score(model_output, ground_truth_output)
|
num_words_diff = word_difference_score(model_output, self.ground_truth_output)
|
||||||
if num_words_diff == 0:
|
if num_words_diff == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
else:
|
else:
|
||||||
return num_words_diff / len(get_words_cached(ground_truth_output))
|
return num_words_diff / len(get_words_cached(self.ground_truth_output))
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=2 ** 12)
|
@functools.lru_cache(maxsize=2 ** 12)
|
||||||
|
|||||||
@@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction):
|
|||||||
original_output: the original output of the model
|
original_output: the original output of the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__(model)
|
|
||||||
|
|
||||||
def _goal_function_result_type(self):
|
def _goal_function_result_type(self):
|
||||||
""" Returns the class of this goal function's results. """
|
""" Returns the class of this goal function's results. """
|
||||||
return TextToTextGoalFunctionResult
|
return TextToTextGoalFunctionResult
|
||||||
|
|||||||
@@ -20,9 +20,8 @@ class CSVLogger(Logger):
|
|||||||
self._flushed = True
|
self._flushed = True
|
||||||
|
|
||||||
def log_attack_result(self, result):
|
def log_attack_result(self, result):
|
||||||
if isinstance(result, FailedAttackResult):
|
|
||||||
return
|
|
||||||
original_text, perturbed_text = result.diff_color(self.color_method)
|
original_text, perturbed_text = result.diff_color(self.color_method)
|
||||||
|
result_type = result.__class__.__name__[:-12]
|
||||||
row = {
|
row = {
|
||||||
"original_text": original_text,
|
"original_text": original_text,
|
||||||
"perturbed_text": perturbed_text,
|
"perturbed_text": perturbed_text,
|
||||||
@@ -30,7 +29,9 @@ class CSVLogger(Logger):
|
|||||||
"perturbed_score": result.perturbed_result.score,
|
"perturbed_score": result.perturbed_result.score,
|
||||||
"original_output": result.original_result.output,
|
"original_output": result.original_result.output,
|
||||||
"perturbed_output": result.perturbed_result.output,
|
"perturbed_output": result.perturbed_result.output,
|
||||||
|
"ground_truth_output": result.original_result.ground_truth_output,
|
||||||
"num_queries": result.num_queries,
|
"num_queries": result.num_queries,
|
||||||
|
"result_type": result_type
|
||||||
}
|
}
|
||||||
self.df = self.df.append(row, ignore_index=True)
|
self.df = self.df.append(row, ignore_index=True)
|
||||||
self._flushed = False
|
self._flushed = False
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||||
from textattack.search_methods import SearchMethod
|
from textattack.search_methods import SearchMethod
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
|
|||||||
def _perform_search(self, initial_result):
|
def _perform_search(self, initial_result):
|
||||||
beam = [initial_result.attacked_text]
|
beam = [initial_result.attacked_text]
|
||||||
best_result = initial_result
|
best_result = initial_result
|
||||||
while not best_result.succeeded:
|
while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||||
potential_next_beam = []
|
potential_next_beam = []
|
||||||
for text in beam:
|
for text in beam:
|
||||||
transformations = self.get_transformations(
|
transformations = self.get_transformations(
|
||||||
@@ -33,7 +34,7 @@ class BeamSearch(SearchMethod):
|
|||||||
# If we did not find any possible perturbations, give up.
|
# If we did not find any possible perturbations, give up.
|
||||||
return best_result
|
return best_result
|
||||||
results, search_over = self.get_goal_results(
|
results, search_over = self.get_goal_results(
|
||||||
potential_next_beam, initial_result.output
|
potential_next_beam
|
||||||
)
|
)
|
||||||
scores = np.array([r.score for r in results])
|
scores = np.array([r.score for r in results])
|
||||||
best_result = results[scores.argmax()]
|
best_result = results[scores.argmax()]
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from copy import deepcopy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||||
from textattack.search_methods import SearchMethod
|
from textattack.search_methods import SearchMethod
|
||||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||||
|
|
||||||
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
if not len(transformations):
|
if not len(transformations):
|
||||||
return False
|
return False
|
||||||
orig_result, self.search_over = self.get_goal_results(
|
orig_result, self.search_over = self.get_goal_results(
|
||||||
[pop_member.attacked_text], self.correct_output
|
[pop_member.attacked_text]
|
||||||
)
|
)
|
||||||
if self.search_over:
|
if self.search_over:
|
||||||
return False
|
return False
|
||||||
new_x_results, self.search_over = self.get_goal_results(
|
new_x_results, self.search_over = self.get_goal_results(
|
||||||
transformations, self.correct_output
|
transformations
|
||||||
)
|
)
|
||||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||||
new_x_scores = new_x_scores - orig_result[0].score
|
new_x_scores = new_x_scores - orig_result[0].score
|
||||||
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
cur_score = initial_result.score
|
cur_score = initial_result.score
|
||||||
for i in range(self.max_iters):
|
for i in range(self.max_iters):
|
||||||
pop_results, self.search_over = self.get_goal_results(
|
pop_results, self.search_over = self.get_goal_results(
|
||||||
[pm.attacked_text for pm in pop], self.correct_output
|
[pm.attacked_text for pm in pop]
|
||||||
)
|
)
|
||||||
if self.search_over:
|
if self.search_over:
|
||||||
if not len(pop_results):
|
if not len(pop_results):
|
||||||
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
logits = ((-pop_scores) / self.temp).exp()
|
logits = ((-pop_scores) / self.temp).exp()
|
||||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||||
|
|
||||||
if pop[0].result.succeeded:
|
if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||||
return pop[0].result
|
return pop[0].result
|
||||||
|
|
||||||
if pop[0].result.score > cur_score:
|
if pop[0].result.score > cur_score:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||||
from textattack.search_methods import SearchMethod
|
from textattack.search_methods import SearchMethod
|
||||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
|||||||
ranks in order of descending score.
|
ranks in order of descending score.
|
||||||
"""
|
"""
|
||||||
leave_one_results, search_over = self.get_goal_results(
|
leave_one_results, search_over = self.get_goal_results(
|
||||||
texts, initial_result.output
|
texts
|
||||||
)
|
)
|
||||||
leave_one_scores = np.array([result.score for result in leave_one_results])
|
leave_one_scores = np.array([result.score for result in leave_one_results])
|
||||||
return leave_one_scores, search_over
|
return leave_one_scores, search_over
|
||||||
@@ -78,7 +79,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
|||||||
if len(transformed_text_candidates) == 0:
|
if len(transformed_text_candidates) == 0:
|
||||||
continue
|
continue
|
||||||
results, search_over = self.get_goal_results(
|
results, search_over = self.get_goal_results(
|
||||||
transformed_text_candidates, initial_result.output
|
transformed_text_candidates
|
||||||
)
|
)
|
||||||
results = sorted(results, key=lambda x: -x.score)
|
results = sorted(results, key=lambda x: -x.score)
|
||||||
# Skip swaps which don't improve the score
|
# Skip swaps which don't improve the score
|
||||||
@@ -87,12 +88,12 @@ class GreedyWordSwapWIR(SearchMethod):
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
# If we succeeded, return the index with best similarity.
|
# If we succeeded, return the index with best similarity.
|
||||||
if cur_result.succeeded:
|
if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||||
best_result = cur_result
|
best_result = cur_result
|
||||||
# @TODO: Use vectorwise operations
|
# @TODO: Use vectorwise operations
|
||||||
max_similarity = -float("inf")
|
max_similarity = -float("inf")
|
||||||
for result in results:
|
for result in results:
|
||||||
if not result.succeeded:
|
if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
|
||||||
break
|
break
|
||||||
candidate = result.attacked_text
|
candidate = result.attacked_text
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import numpy as np
|
|||||||
|
|
||||||
import textattack
|
import textattack
|
||||||
from textattack.attack_results import (
|
from textattack.attack_results import (
|
||||||
|
MaximizedAttackResult,
|
||||||
FailedAttackResult,
|
FailedAttackResult,
|
||||||
SkippedAttackResult,
|
SkippedAttackResult,
|
||||||
SuccessfulAttackResult,
|
SuccessfulAttackResult,
|
||||||
)
|
)
|
||||||
|
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||||
from textattack.shared import AttackedText, utils
|
from textattack.shared import AttackedText, utils
|
||||||
|
|
||||||
|
|
||||||
@@ -170,17 +172,24 @@ class Attack:
|
|||||||
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
|
Either a ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
||||||
|
or ``MaximizedAttackResult``.
|
||||||
"""
|
"""
|
||||||
final_result = self.search_method(initial_result)
|
final_result = self.search_method(initial_result)
|
||||||
if final_result.succeeded:
|
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||||
return SuccessfulAttackResult(
|
return SuccessfulAttackResult(
|
||||||
initial_result, final_result, self.goal_function.num_queries
|
initial_result, final_result,
|
||||||
|
)
|
||||||
|
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
|
||||||
|
return FailedAttackResult(
|
||||||
|
initial_result, final_result,
|
||||||
|
)
|
||||||
|
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
|
||||||
|
return MaximizedAttackResult(
|
||||||
|
initial_result, final_result,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return FailedAttackResult(
|
raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
|
||||||
initial_result, final_result, self.goal_function.num_queries
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_examples_from_dataset(self, dataset, indices=None):
|
def _get_examples_from_dataset(self, dataset, indices=None):
|
||||||
"""
|
"""
|
||||||
@@ -212,14 +221,10 @@ class Attack:
|
|||||||
attacked_text = AttackedText(
|
attacked_text = AttackedText(
|
||||||
text, attack_attrs={"label_names": label_names}
|
text, attack_attrs={"label_names": label_names}
|
||||||
)
|
)
|
||||||
self.goal_function.num_queries = 0
|
self.goal_function.init_attack_example(attacked_text, ground_truth_output)
|
||||||
goal_function_result, _ = self.goal_function.get_result(
|
goal_function_result, _ = self.goal_function.get_result(
|
||||||
attacked_text, ground_truth_output
|
attacked_text
|
||||||
)
|
)
|
||||||
if goal_function_result.succeeded:
|
|
||||||
# Store the true output on the goal function so that the
|
|
||||||
# SkippedAttackResult has the correct output, not the incorrect.
|
|
||||||
goal_function_result.output = ground_truth_output
|
|
||||||
yield goal_function_result
|
yield goal_function_result
|
||||||
|
|
||||||
except IndexError:
|
except IndexError:
|
||||||
@@ -240,7 +245,7 @@ class Attack:
|
|||||||
examples = self._get_examples_from_dataset(dataset, indices=indices)
|
examples = self._get_examples_from_dataset(dataset, indices=indices)
|
||||||
|
|
||||||
for goal_function_result in examples:
|
for goal_function_result in examples:
|
||||||
if goal_function_result.succeeded:
|
if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||||
yield SkippedAttackResult(goal_function_result)
|
yield SkippedAttackResult(goal_function_result)
|
||||||
else:
|
else:
|
||||||
result = self.attack_one(goal_function_result)
|
result = self.attack_one(goal_function_result)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import pickle
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from textattack.attack_results import (
|
from textattack.attack_results import (
|
||||||
|
MaximizedAttackResult,
|
||||||
FailedAttackResult,
|
FailedAttackResult,
|
||||||
SkippedAttackResult,
|
SkippedAttackResult,
|
||||||
SuccessfulAttackResult,
|
SuccessfulAttackResult,
|
||||||
@@ -101,6 +102,11 @@ class Checkpoint:
|
|||||||
f"(Number of failed attacks): {self.num_failed_attacks}", 2
|
f"(Number of failed attacks): {self.num_failed_attacks}", 2
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
breakdown_lines.append(
|
||||||
|
utils.add_indent(
|
||||||
|
f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
|
||||||
|
)
|
||||||
|
)
|
||||||
breakdown_lines.append(
|
breakdown_lines.append(
|
||||||
utils.add_indent(
|
utils.add_indent(
|
||||||
f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
|
f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
|
||||||
@@ -140,6 +146,12 @@ class Checkpoint:
|
|||||||
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
|
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_maximized_attacks(self):
|
||||||
|
return sum(
|
||||||
|
isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_remaining_attacks(self):
|
def num_remaining_attacks(self):
|
||||||
if self.args.attack_n:
|
if self.args.attack_n:
|
||||||
|
|||||||
Reference in New Issue
Block a user