1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

allow maximization goal functions

This commit is contained in:
uvafan
2020-06-23 23:33:48 -04:00
parent fe109267a1
commit 0fcfb51b7f
19 changed files with 115 additions and 78 deletions

View File

@@ -1,3 +1,4 @@
from .maximized_attack_result import MaximizedAttackResult
from .failed_attack_result import FailedAttackResult from .failed_attack_result import FailedAttackResult
from .skipped_attack_result import SkippedAttackResult from .skipped_attack_result import SkippedAttackResult
from .successful_attack_result import SuccessfulAttackResult from .successful_attack_result import SuccessfulAttackResult

View File

@@ -13,7 +13,7 @@ class AttackResult:
perturbed text. May or may not have been successful. perturbed text. May or may not have been successful.
""" """
def __init__(self, original_result, perturbed_result, num_queries=0): def __init__(self, original_result, perturbed_result):
if original_result is None: if original_result is None:
raise ValueError("Attack original result cannot be None") raise ValueError("Attack original result cannot be None")
elif not isinstance(original_result, GoalFunctionResult): elif not isinstance(original_result, GoalFunctionResult):
@@ -27,7 +27,7 @@ class AttackResult:
self.original_result = original_result self.original_result = original_result
self.perturbed_result = perturbed_result self.perturbed_result = perturbed_result
self.num_queries = num_queries self.num_queries = perturbed_result.num_queries
# We don't want the AttackedText attributes sticking around clogging up # We don't want the AttackedText attributes sticking around clogging up
# space on our devices. Delete them here, if they're still present, # space on our devices. Delete them here, if they're still present,

View File

@@ -6,9 +6,9 @@ from .attack_result import AttackResult
class FailedAttackResult(AttackResult): class FailedAttackResult(AttackResult):
"""The result of a failed attack.""" """The result of a failed attack."""
def __init__(self, original_result, perturbed_result=None, num_queries=0): def __init__(self, original_result, perturbed_result=None):
perturbed_result = perturbed_result or original_result perturbed_result = perturbed_result or original_result
super().__init__(original_result, perturbed_result, num_queries) super().__init__(original_result, perturbed_result)
def str_lines(self, color_method=None): def str_lines(self, color_method=None):
lines = ( lines = (

View File

@@ -0,0 +1,5 @@
from .attack_result import AttackResult
class MaximizedAttackResult(AttackResult):
""" The result of a successful attack. """

View File

@@ -129,7 +129,7 @@ def run(args):
pbar.update() pbar.update()
num_results += 1 num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult: if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
num_successes += 1 num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult: if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1 num_failures += 1

View File

@@ -114,7 +114,7 @@ def run(args):
num_results += 1 num_results += 1
if type(result) == textattack.attack_results.SuccessfulAttackResult: if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
num_successes += 1 num_successes += 1
if type(result) == textattack.attack_results.FailedAttackResult: if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1 num_failures += 1

View File

@@ -1,4 +1,4 @@
from .goal_function_result import GoalFunctionResult from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus
from .classification_goal_function_result import ClassificationGoalFunctionResult from .classification_goal_function_result import ClassificationGoalFunctionResult
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult from .text_to_text_goal_function_result import TextToTextGoalFunctionResult

View File

@@ -1,5 +1,9 @@
import torch import torch
class GoalFunctionResultStatus:
SUCCEEDED = 0
SEARCHING = 1 # In process of searching for a success
MAXIMIZING = 2
class GoalFunctionResult: class GoalFunctionResult:
""" """
@@ -8,16 +12,20 @@ class GoalFunctionResult:
Args: Args:
attacked_text: The sequence that was evaluated. attacked_text: The sequence that was evaluated.
output: The display-friendly output. output: The display-friendly output.
succeeded: Whether the goal has been achieved. goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
score: A score representing how close the model is to achieving its goal. score: A score representing how close the model is to achieving its goal.
num_queries: How many model queries have been used
ground_truth_output: The ground truth output
""" """
def __init__(self, attacked_text, raw_output, output, succeeded, score): def __init__(self, attacked_text, raw_output, output, goal_status, score, num_queries, ground_truth_output):
self.attacked_text = attacked_text self.attacked_text = attacked_text
self.raw_output = raw_output self.raw_output = raw_output
self.output = output self.output = output
self.score = score self.score = score
self.succeeded = succeeded self.goal_status = goal_status
self.num_queries = num_queries
self.ground_truth_output = ground_truth_output
if isinstance(self.raw_output, torch.Tensor): if isinstance(self.raw_output, torch.Tensor):
self.raw_output = self.raw_output.cpu() self.raw_output = self.raw_output.cpu()
@@ -25,9 +33,6 @@ class GoalFunctionResult:
if isinstance(self.score, torch.Tensor): if isinstance(self.score, torch.Tensor):
self.score = self.score.item() self.score = self.score.item()
if isinstance(self.succeeded, torch.Tensor):
self.succeeded = self.succeeded.item()
def get_text_color_input(self): def get_text_color_input(self):
""" A string representing the color this result's changed """ A string representing the color this result's changed
portion should be if it represents the original input. portion should be if it represents the original input.

View File

@@ -7,16 +7,16 @@ class TargetedClassification(ClassificationGoalFunction):
score of the target label until it is the predicted label. score of the target label until it is the predicted label.
""" """
def __init__(self, model, target_class=0): def __init__(self, *args, target_class=0, **kwargs):
super().__init__(model) super().__init__(*args, **kwargs)
self.target_class = target_class self.target_class = target_class
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output):
return ( return (
self.target_class == model_output.argmax() self.target_class == model_output.argmax()
) or ground_truth_output == self.target_class ) or self.ground_truth_output == self.target_class
def _get_score(self, model_output, _): def _get_score(self, model_output):
if self.target_class < 0 or self.target_class >= len(model_output): if self.target_class < 0 or self.target_class >= len(model_output):
raise ValueError( raise ValueError(
f"target class set to {self.target_class} with {len(model_output)} classes." f"target class set to {self.target_class} with {len(model_output)} classes."

View File

@@ -16,23 +16,23 @@ class UntargetedClassification(ClassificationGoalFunction):
self.target_max_score = target_max_score self.target_max_score = target_max_score
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output):
if self.target_max_score: if self.target_max_score:
return model_output[ground_truth_output] < self.target_max_score return model_output[self.ground_truth_output] < self.target_max_score
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float): elif (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
return abs(ground_truth_output - model_output.item()) >= ( return abs(self.ground_truth_output - model_output.item()) >= (
self.target_max_score or 0.5 self.target_max_score or 0.5
) )
else: else:
return model_output.argmax() != ground_truth_output return model_output.argmax() != self.ground_truth_output
def _get_score(self, model_output, ground_truth_output): def _get_score(self, model_output):
# If the model outputs a single number and the ground truth output is # If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task. # a float, we assume that this is a regression task.
if (model_output.numel() is 1) and isinstance(ground_truth_output, float): if (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
return abs(model_output.item() - ground_truth_output) return abs(model_output.item() - self.ground_truth_output)
else: else:
return 1 - model_output[ground_truth_output] return 1 - model_output[self.ground_truth_output]
def _get_displayed_output(self, raw_output): def _get_displayed_output(self, raw_output):
return int(raw_output.argmax()) return int(raw_output.argmax())

View File

@@ -4,26 +4,29 @@ import lru
import numpy as np import numpy as np
import torch import torch
from textattack.goal_function_results.goal_function_result import GoalFunctionResultStatus
from textattack.shared import utils, validators from textattack.shared import utils, validators
from textattack.shared.utils import batch_model_predict, default_class_repr from textattack.shared.utils import batch_model_predict, default_class_repr
class GoalFunction: class GoalFunction:
""" """
Evaluates how well a perturbed attacked_text object is achieving a specified goal. Evaluates how well a perturbed attacked_text object is achieving a specified goal.
Args: Args:
model: The PyTorch or TensorFlow model used for evaluation. model: The PyTorch or TensorFlow model used for evaluation.
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
of success or failure.
query_budget: The maximum number of model queries allowed. query_budget: The maximum number of model queries allowed.
""" """
def __init__( def __init__(
self, model, tokenizer=None, use_cache=True, query_budget=float("inf") self, model, maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf")
): ):
validators.validate_model_goal_function_compatibility( validators.validate_model_goal_function_compatibility(
self.__class__, model.__class__ self.__class__, model.__class__
) )
self.model = model self.model = model
self.maximizable = maximizable
self.tokenizer = tokenizer self.tokenizer = tokenizer
if not self.tokenizer: if not self.tokenizer:
if hasattr(self.model, "tokenizer"): if hasattr(self.model, "tokenizer"):
@@ -33,20 +36,16 @@ class GoalFunction:
if not hasattr(self.tokenizer, "encode"): if not hasattr(self.tokenizer, "encode"):
raise TypeError("Tokenizer must contain `encode()` method") raise TypeError("Tokenizer must contain `encode()` method")
self.use_cache = use_cache self.use_cache = use_cache
self.num_queries = 0
self.query_budget = query_budget self.query_budget = query_budget
if self.use_cache: if self.use_cache:
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE")) self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
else: else:
self._call_model_cache = None self._call_model_cache = None
def should_skip(self, attacked_text, ground_truth_output): def init_attack_example(self, attacked_text, ground_truth_output):
""" self.initial_attacked_text = attacked_text
Returns whether or not the goal has already been completed for ``attacked_text``\, self.ground_truth_output = ground_truth_output
due to misprediction by the model. self.num_queries = 0
"""
model_outputs = self._call_model([attacked_text])
return self._is_goal_complete(model_outputs[0], ground_truth_output)
def get_output(self, attacked_text): def get_output(self, attacked_text):
""" """
@@ -54,16 +53,16 @@ class GoalFunction:
""" """
return self._get_displayed_output(self._call_model([attacked_text])[0]) return self._get_displayed_output(self._call_model([attacked_text])[0])
def get_result(self, attacked_text, ground_truth_output): def get_result(self, attacked_text):
""" """
A helper method that queries `self.get_results` with a single A helper method that queries `self.get_results` with a single
``AttackedText`` object. ``AttackedText`` object.
""" """
results, search_over = self.get_results([attacked_text], ground_truth_output) results, search_over = self.get_results([attacked_text])
result = results[0] if len(results) else None result = results[0] if len(results) else None
return result, search_over return result, search_over
def get_results(self, attacked_text_list, ground_truth_output): def get_results(self, attacked_text_list):
""" """
For each attacked_text object in attacked_text_list, returns a result For each attacked_text object in attacked_text_list, returns a result
consisting of whether or not the goal has been achieved, the output for consisting of whether or not the goal has been achieved, the output for
@@ -78,23 +77,32 @@ class GoalFunction:
model_outputs = self._call_model(attacked_text_list) model_outputs = self._call_model(attacked_text_list)
for attacked_text, raw_output in zip(attacked_text_list, model_outputs): for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
displayed_output = self._get_displayed_output(raw_output) displayed_output = self._get_displayed_output(raw_output)
succeeded = self._is_goal_complete(raw_output, ground_truth_output) goal_status = self._get_goal_status(raw_output)
goal_function_score = self._get_score(raw_output, ground_truth_output) goal_function_score = self._get_score(raw_output)
results.append( results.append(
self._goal_function_result_type()( self._goal_function_result_type()(
attacked_text, attacked_text,
raw_output, raw_output,
displayed_output, displayed_output,
succeeded, goal_status,
goal_function_score, goal_function_score,
self.num_queries,
self.ground_truth_output,
) )
) )
return results, self.num_queries == self.query_budget return results, self.num_queries == self.query_budget
def _is_goal_complete(self, model_output, ground_truth_output): def _get_goal_status(self, model_output):
if self.maximizable:
return GoalFunctionResultStatus.MAXIMIZING
if self._is_goal_complete(model_output):
return GoalFunctionResultStatus.SUCCEEDED
return GoalFunctionResultStatus.SEARCHING
def _is_goal_complete(self, model_output):
raise NotImplementedError() raise NotImplementedError()
def _get_score(self, model_output, ground_truth_output): def _get_score(self, model_output):
raise NotImplementedError() raise NotImplementedError()
def _get_displayed_output(self, raw_output): def _get_displayed_output(self, raw_output):

View File

@@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction):
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3). Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
""" """
def _is_goal_complete(self, model_output, ground_truth_output): def _is_goal_complete(self, model_output):
return self._get_score(model_output, ground_truth_output) == 1.0 return self._get_score(model_output, self.ground_truth_output) == 1.0
def _get_score(self, model_output, ground_truth_output): def _get_score(self, model_output):
num_words_diff = word_difference_score(model_output, ground_truth_output) num_words_diff = word_difference_score(model_output, self.ground_truth_output)
if num_words_diff == 0: if num_words_diff == 0:
return 0.0 return 0.0
else: else:
return num_words_diff / len(get_words_cached(ground_truth_output)) return num_words_diff / len(get_words_cached(self.ground_truth_output))
@functools.lru_cache(maxsize=2 ** 12) @functools.lru_cache(maxsize=2 ** 12)

View File

@@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction):
original_output: the original output of the model original_output: the original output of the model
""" """
def __init__(self, model):
super().__init__(model)
def _goal_function_result_type(self): def _goal_function_result_type(self):
""" Returns the class of this goal function's results. """ """ Returns the class of this goal function's results. """
return TextToTextGoalFunctionResult return TextToTextGoalFunctionResult

View File

@@ -20,9 +20,8 @@ class CSVLogger(Logger):
self._flushed = True self._flushed = True
def log_attack_result(self, result): def log_attack_result(self, result):
if isinstance(result, FailedAttackResult):
return
original_text, perturbed_text = result.diff_color(self.color_method) original_text, perturbed_text = result.diff_color(self.color_method)
result_type = result.__class__.__name__[:-12]
row = { row = {
"original_text": original_text, "original_text": original_text,
"perturbed_text": perturbed_text, "perturbed_text": perturbed_text,
@@ -30,7 +29,9 @@ class CSVLogger(Logger):
"perturbed_score": result.perturbed_result.score, "perturbed_score": result.perturbed_result.score,
"original_output": result.original_result.output, "original_output": result.original_result.output,
"perturbed_output": result.perturbed_result.output, "perturbed_output": result.perturbed_result.output,
"ground_truth_output": result.original_result.ground_truth_output,
"num_queries": result.num_queries, "num_queries": result.num_queries,
"result_type": result_type
} }
self.df = self.df.append(row, ignore_index=True) self.df = self.df.append(row, ignore_index=True)
self._flushed = False self._flushed = False

View File

@@ -1,5 +1,6 @@
import numpy as np import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod from textattack.search_methods import SearchMethod
@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
def _perform_search(self, initial_result): def _perform_search(self, initial_result):
beam = [initial_result.attacked_text] beam = [initial_result.attacked_text]
best_result = initial_result best_result = initial_result
while not best_result.succeeded: while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
potential_next_beam = [] potential_next_beam = []
for text in beam: for text in beam:
transformations = self.get_transformations( transformations = self.get_transformations(
@@ -33,7 +34,7 @@ class BeamSearch(SearchMethod):
# If we did not find any possible perturbations, give up. # If we did not find any possible perturbations, give up.
return best_result return best_result
results, search_over = self.get_goal_results( results, search_over = self.get_goal_results(
potential_next_beam, initial_result.output potential_next_beam
) )
scores = np.array([r.score for r in results]) scores = np.array([r.score for r in results])
best_result = results[scores.argmax()] best_result = results[scores.argmax()]

View File

@@ -10,6 +10,7 @@ from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps from textattack.shared.validators import transformation_consists_of_word_swaps
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
if not len(transformations): if not len(transformations):
return False return False
orig_result, self.search_over = self.get_goal_results( orig_result, self.search_over = self.get_goal_results(
[pop_member.attacked_text], self.correct_output [pop_member.attacked_text]
) )
if self.search_over: if self.search_over:
return False return False
new_x_results, self.search_over = self.get_goal_results( new_x_results, self.search_over = self.get_goal_results(
transformations, self.correct_output transformations
) )
new_x_scores = torch.Tensor([r.score for r in new_x_results]) new_x_scores = torch.Tensor([r.score for r in new_x_results])
new_x_scores = new_x_scores - orig_result[0].score new_x_scores = new_x_scores - orig_result[0].score
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
cur_score = initial_result.score cur_score = initial_result.score
for i in range(self.max_iters): for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results( pop_results, self.search_over = self.get_goal_results(
[pm.attacked_text for pm in pop], self.correct_output [pm.attacked_text for pm in pop]
) )
if self.search_over: if self.search_over:
if not len(pop_results): if not len(pop_results):
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
logits = ((-pop_scores) / self.temp).exp() logits = ((-pop_scores) / self.temp).exp()
select_probs = (logits / logits.sum()).cpu().numpy() select_probs = (logits / logits.sum()).cpu().numpy()
if pop[0].result.succeeded: if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return pop[0].result return pop[0].result
if pop[0].result.score > cur_score: if pop[0].result.score > cur_score:

View File

@@ -9,6 +9,7 @@ See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
import numpy as np import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps from textattack.shared.validators import transformation_consists_of_word_swaps
@@ -33,7 +34,7 @@ class GreedyWordSwapWIR(SearchMethod):
ranks in order of descending score. ranks in order of descending score.
""" """
leave_one_results, search_over = self.get_goal_results( leave_one_results, search_over = self.get_goal_results(
texts, initial_result.output texts
) )
leave_one_scores = np.array([result.score for result in leave_one_results]) leave_one_scores = np.array([result.score for result in leave_one_results])
return leave_one_scores, search_over return leave_one_scores, search_over
@@ -78,7 +79,7 @@ class GreedyWordSwapWIR(SearchMethod):
if len(transformed_text_candidates) == 0: if len(transformed_text_candidates) == 0:
continue continue
results, search_over = self.get_goal_results( results, search_over = self.get_goal_results(
transformed_text_candidates, initial_result.output transformed_text_candidates
) )
results = sorted(results, key=lambda x: -x.score) results = sorted(results, key=lambda x: -x.score)
# Skip swaps which don't improve the score # Skip swaps which don't improve the score
@@ -87,12 +88,12 @@ class GreedyWordSwapWIR(SearchMethod):
else: else:
continue continue
# If we succeeded, return the index with best similarity. # If we succeeded, return the index with best similarity.
if cur_result.succeeded: if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
best_result = cur_result best_result = cur_result
# @TODO: Use vectorwise operations # @TODO: Use vectorwise operations
max_similarity = -float("inf") max_similarity = -float("inf")
for result in results: for result in results:
if not result.succeeded: if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
break break
candidate = result.attacked_text candidate = result.attacked_text
try: try:

View File

@@ -6,10 +6,12 @@ import numpy as np
import textattack import textattack
from textattack.attack_results import ( from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult, FailedAttackResult,
SkippedAttackResult, SkippedAttackResult,
SuccessfulAttackResult, SuccessfulAttackResult,
) )
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils from textattack.shared import AttackedText, utils
@@ -170,17 +172,24 @@ class Attack:
initial_result: The initial ``GoalFunctionResult`` from which to perturb. initial_result: The initial ``GoalFunctionResult`` from which to perturb.
Returns: Returns:
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``. Either a ``SuccessfulAttackResult``, ``FailedAttackResult``,
or ``MaximizedAttackResult``.
""" """
final_result = self.search_method(initial_result) final_result = self.search_method(initial_result)
if final_result.succeeded: if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult( return SuccessfulAttackResult(
initial_result, final_result, self.goal_function.num_queries initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(
initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(
initial_result, final_result,
) )
else: else:
return FailedAttackResult( raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
initial_result, final_result, self.goal_function.num_queries
)
def _get_examples_from_dataset(self, dataset, indices=None): def _get_examples_from_dataset(self, dataset, indices=None):
""" """
@@ -212,14 +221,10 @@ class Attack:
attacked_text = AttackedText( attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names} text, attack_attrs={"label_names": label_names}
) )
self.goal_function.num_queries = 0 self.goal_function.init_attack_example(attacked_text, ground_truth_output)
goal_function_result, _ = self.goal_function.get_result( goal_function_result, _ = self.goal_function.get_result(
attacked_text, ground_truth_output attacked_text
) )
if goal_function_result.succeeded:
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield goal_function_result yield goal_function_result
except IndexError: except IndexError:
@@ -240,7 +245,7 @@ class Attack:
examples = self._get_examples_from_dataset(dataset, indices=indices) examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result in examples: for goal_function_result in examples:
if goal_function_result.succeeded: if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
yield SkippedAttackResult(goal_function_result) yield SkippedAttackResult(goal_function_result)
else: else:
result = self.attack_one(goal_function_result) result = self.attack_one(goal_function_result)

View File

@@ -5,6 +5,7 @@ import pickle
import time import time
from textattack.attack_results import ( from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult, FailedAttackResult,
SkippedAttackResult, SkippedAttackResult,
SuccessfulAttackResult, SuccessfulAttackResult,
@@ -101,6 +102,11 @@ class Checkpoint:
f"(Number of failed attacks): {self.num_failed_attacks}", 2 f"(Number of failed attacks): {self.num_failed_attacks}", 2
) )
) )
breakdown_lines.append(
utils.add_indent(
f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
)
)
breakdown_lines.append( breakdown_lines.append(
utils.add_indent( utils.add_indent(
f"(Number of skipped attacks): {self.num_skipped_attacks}", 2 f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
@@ -140,6 +146,12 @@ class Checkpoint:
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
) )
@property
def num_maximized_attacks(self):
return sum(
isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
)
@property @property
def num_remaining_attacks(self): def num_remaining_attacks(self):
if self.args.attack_n: if self.args.attack_n: