mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
allow maximization goal functions
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .maximized_attack_result import MaximizedAttackResult
|
||||
from .failed_attack_result import FailedAttackResult
|
||||
from .skipped_attack_result import SkippedAttackResult
|
||||
from .successful_attack_result import SuccessfulAttackResult
|
||||
|
||||
@@ -13,7 +13,7 @@ class AttackResult:
|
||||
perturbed text. May or may not have been successful.
|
||||
"""
|
||||
|
||||
def __init__(self, original_result, perturbed_result, num_queries=0):
|
||||
def __init__(self, original_result, perturbed_result):
|
||||
if original_result is None:
|
||||
raise ValueError("Attack original result cannot be None")
|
||||
elif not isinstance(original_result, GoalFunctionResult):
|
||||
@@ -27,7 +27,7 @@ class AttackResult:
|
||||
|
||||
self.original_result = original_result
|
||||
self.perturbed_result = perturbed_result
|
||||
self.num_queries = num_queries
|
||||
self.num_queries = perturbed_result.num_queries
|
||||
|
||||
# We don't want the AttackedText attributes sticking around clogging up
|
||||
# space on our devices. Delete them here, if they're still present,
|
||||
|
||||
@@ -6,9 +6,9 @@ from .attack_result import AttackResult
|
||||
class FailedAttackResult(AttackResult):
|
||||
"""The result of a failed attack."""
|
||||
|
||||
def __init__(self, original_result, perturbed_result=None, num_queries=0):
|
||||
def __init__(self, original_result, perturbed_result=None):
|
||||
perturbed_result = perturbed_result or original_result
|
||||
super().__init__(original_result, perturbed_result, num_queries)
|
||||
super().__init__(original_result, perturbed_result)
|
||||
|
||||
def str_lines(self, color_method=None):
|
||||
lines = (
|
||||
|
||||
5
textattack/attack_results/maximized_attack_result.py
Normal file
5
textattack/attack_results/maximized_attack_result.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .attack_result import AttackResult
|
||||
|
||||
|
||||
class MaximizedAttackResult(AttackResult):
|
||||
""" The result of a successful attack. """
|
||||
@@ -129,7 +129,7 @@ def run(args):
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
|
||||
num_successes += 1
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
|
||||
@@ -114,7 +114,7 @@ def run(args):
|
||||
|
||||
num_results += 1
|
||||
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult:
|
||||
if type(result) == textattack.attack_results.SuccessfulAttackResult or type(result) == textattack.attack_results.MaximizedAttackResult:
|
||||
num_successes += 1
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .goal_function_result import GoalFunctionResult
|
||||
from .goal_function_result import GoalFunctionResult, GoalFunctionResultStatus
|
||||
|
||||
from .classification_goal_function_result import ClassificationGoalFunctionResult
|
||||
from .text_to_text_goal_function_result import TextToTextGoalFunctionResult
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import torch
|
||||
|
||||
class GoalFunctionResultStatus:
|
||||
SUCCEEDED = 0
|
||||
SEARCHING = 1 # In process of searching for a success
|
||||
MAXIMIZING = 2
|
||||
|
||||
class GoalFunctionResult:
|
||||
"""
|
||||
@@ -8,16 +12,20 @@ class GoalFunctionResult:
|
||||
Args:
|
||||
attacked_text: The sequence that was evaluated.
|
||||
output: The display-friendly output.
|
||||
succeeded: Whether the goal has been achieved.
|
||||
goal_status: The ``GoalFunctionResultStatus`` representing the status of the achievement of the goal.
|
||||
score: A score representing how close the model is to achieving its goal.
|
||||
num_queries: How many model queries have been used
|
||||
ground_truth_output: The ground truth output
|
||||
"""
|
||||
|
||||
def __init__(self, attacked_text, raw_output, output, succeeded, score):
|
||||
def __init__(self, attacked_text, raw_output, output, goal_status, score, num_queries, ground_truth_output):
|
||||
self.attacked_text = attacked_text
|
||||
self.raw_output = raw_output
|
||||
self.output = output
|
||||
self.score = score
|
||||
self.succeeded = succeeded
|
||||
self.goal_status = goal_status
|
||||
self.num_queries = num_queries
|
||||
self.ground_truth_output = ground_truth_output
|
||||
|
||||
if isinstance(self.raw_output, torch.Tensor):
|
||||
self.raw_output = self.raw_output.cpu()
|
||||
@@ -25,9 +33,6 @@ class GoalFunctionResult:
|
||||
if isinstance(self.score, torch.Tensor):
|
||||
self.score = self.score.item()
|
||||
|
||||
if isinstance(self.succeeded, torch.Tensor):
|
||||
self.succeeded = self.succeeded.item()
|
||||
|
||||
def get_text_color_input(self):
|
||||
""" A string representing the color this result's changed
|
||||
portion should be if it represents the original input.
|
||||
|
||||
@@ -7,16 +7,16 @@ class TargetedClassification(ClassificationGoalFunction):
|
||||
score of the target label until it is the predicted label.
|
||||
"""
|
||||
|
||||
def __init__(self, model, target_class=0):
|
||||
super().__init__(model)
|
||||
def __init__(self, *args, target_class=0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.target_class = target_class
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
def _is_goal_complete(self, model_output):
|
||||
return (
|
||||
self.target_class == model_output.argmax()
|
||||
) or ground_truth_output == self.target_class
|
||||
) or self.ground_truth_output == self.target_class
|
||||
|
||||
def _get_score(self, model_output, _):
|
||||
def _get_score(self, model_output):
|
||||
if self.target_class < 0 or self.target_class >= len(model_output):
|
||||
raise ValueError(
|
||||
f"target class set to {self.target_class} with {len(model_output)} classes."
|
||||
|
||||
@@ -16,23 +16,23 @@ class UntargetedClassification(ClassificationGoalFunction):
|
||||
self.target_max_score = target_max_score
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
def _is_goal_complete(self, model_output):
|
||||
if self.target_max_score:
|
||||
return model_output[ground_truth_output] < self.target_max_score
|
||||
elif (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
||||
return abs(ground_truth_output - model_output.item()) >= (
|
||||
return model_output[self.ground_truth_output] < self.target_max_score
|
||||
elif (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
|
||||
return abs(self.ground_truth_output - model_output.item()) >= (
|
||||
self.target_max_score or 0.5
|
||||
)
|
||||
else:
|
||||
return model_output.argmax() != ground_truth_output
|
||||
return model_output.argmax() != self.ground_truth_output
|
||||
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
def _get_score(self, model_output):
|
||||
# If the model outputs a single number and the ground truth output is
|
||||
# a float, we assume that this is a regression task.
|
||||
if (model_output.numel() is 1) and isinstance(ground_truth_output, float):
|
||||
return abs(model_output.item() - ground_truth_output)
|
||||
if (model_output.numel() is 1) and isinstance(self.ground_truth_output, float):
|
||||
return abs(model_output.item() - self.ground_truth_output)
|
||||
else:
|
||||
return 1 - model_output[ground_truth_output]
|
||||
return 1 - model_output[self.ground_truth_output]
|
||||
|
||||
def _get_displayed_output(self, raw_output):
|
||||
return int(raw_output.argmax())
|
||||
|
||||
@@ -4,26 +4,29 @@ import lru
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results.goal_function_result import GoalFunctionResultStatus
|
||||
from textattack.shared import utils, validators
|
||||
from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||
|
||||
|
||||
class GoalFunction:
|
||||
"""
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
|
||||
of success or failure.
|
||||
query_budget: The maximum number of model queries allowed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||
self, model, maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||
):
|
||||
validators.validate_model_goal_function_compatibility(
|
||||
self.__class__, model.__class__
|
||||
)
|
||||
self.model = model
|
||||
self.maximizable = maximizable
|
||||
self.tokenizer = tokenizer
|
||||
if not self.tokenizer:
|
||||
if hasattr(self.model, "tokenizer"):
|
||||
@@ -33,20 +36,16 @@ class GoalFunction:
|
||||
if not hasattr(self.tokenizer, "encode"):
|
||||
raise TypeError("Tokenizer must contain `encode()` method")
|
||||
self.use_cache = use_cache
|
||||
self.num_queries = 0
|
||||
self.query_budget = query_budget
|
||||
if self.use_cache:
|
||||
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``\,
|
||||
due to misprediction by the model.
|
||||
"""
|
||||
model_outputs = self._call_model([attacked_text])
|
||||
return self._is_goal_complete(model_outputs[0], ground_truth_output)
|
||||
def init_attack_example(self, attacked_text, ground_truth_output):
|
||||
self.initial_attacked_text = attacked_text
|
||||
self.ground_truth_output = ground_truth_output
|
||||
self.num_queries = 0
|
||||
|
||||
def get_output(self, attacked_text):
|
||||
"""
|
||||
@@ -54,16 +53,16 @@ class GoalFunction:
|
||||
"""
|
||||
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
||||
|
||||
def get_result(self, attacked_text, ground_truth_output):
|
||||
def get_result(self, attacked_text):
|
||||
"""
|
||||
A helper method that queries `self.get_results` with a single
|
||||
``AttackedText`` object.
|
||||
"""
|
||||
results, search_over = self.get_results([attacked_text], ground_truth_output)
|
||||
results, search_over = self.get_results([attacked_text])
|
||||
result = results[0] if len(results) else None
|
||||
return result, search_over
|
||||
|
||||
def get_results(self, attacked_text_list, ground_truth_output):
|
||||
def get_results(self, attacked_text_list):
|
||||
"""
|
||||
For each attacked_text object in attacked_text_list, returns a result
|
||||
consisting of whether or not the goal has been achieved, the output for
|
||||
@@ -78,23 +77,32 @@ class GoalFunction:
|
||||
model_outputs = self._call_model(attacked_text_list)
|
||||
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
||||
displayed_output = self._get_displayed_output(raw_output)
|
||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
||||
goal_function_score = self._get_score(raw_output, ground_truth_output)
|
||||
goal_status = self._get_goal_status(raw_output)
|
||||
goal_function_score = self._get_score(raw_output)
|
||||
results.append(
|
||||
self._goal_function_result_type()(
|
||||
attacked_text,
|
||||
raw_output,
|
||||
displayed_output,
|
||||
succeeded,
|
||||
goal_status,
|
||||
goal_function_score,
|
||||
self.num_queries,
|
||||
self.ground_truth_output,
|
||||
)
|
||||
)
|
||||
return results, self.num_queries == self.query_budget
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
def _get_goal_status(self, model_output):
|
||||
if self.maximizable:
|
||||
return GoalFunctionResultStatus.MAXIMIZING
|
||||
if self._is_goal_complete(model_output):
|
||||
return GoalFunctionResultStatus.SUCCEEDED
|
||||
return GoalFunctionResultStatus.SEARCHING
|
||||
|
||||
def _is_goal_complete(self, model_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
def _get_score(self, model_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_displayed_output(self, raw_output):
|
||||
|
||||
@@ -14,15 +14,15 @@ class NonOverlappingOutput(TextToTextGoalFunction):
|
||||
Defined in seq2sick (https://arxiv.org/pdf/1803.01128.pdf), equation (3).
|
||||
"""
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
return self._get_score(model_output, ground_truth_output) == 1.0
|
||||
def _is_goal_complete(self, model_output):
|
||||
return self._get_score(model_output, self.ground_truth_output) == 1.0
|
||||
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
num_words_diff = word_difference_score(model_output, ground_truth_output)
|
||||
def _get_score(self, model_output):
|
||||
num_words_diff = word_difference_score(model_output, self.ground_truth_output)
|
||||
if num_words_diff == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return num_words_diff / len(get_words_cached(ground_truth_output))
|
||||
return num_words_diff / len(get_words_cached(self.ground_truth_output))
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=2 ** 12)
|
||||
|
||||
@@ -9,9 +9,6 @@ class TextToTextGoalFunction(GoalFunction):
|
||||
original_output: the original output of the model
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__(model)
|
||||
|
||||
def _goal_function_result_type(self):
|
||||
""" Returns the class of this goal function's results. """
|
||||
return TextToTextGoalFunctionResult
|
||||
|
||||
@@ -20,9 +20,8 @@ class CSVLogger(Logger):
|
||||
self._flushed = True
|
||||
|
||||
def log_attack_result(self, result):
|
||||
if isinstance(result, FailedAttackResult):
|
||||
return
|
||||
original_text, perturbed_text = result.diff_color(self.color_method)
|
||||
result_type = result.__class__.__name__[:-12]
|
||||
row = {
|
||||
"original_text": original_text,
|
||||
"perturbed_text": perturbed_text,
|
||||
@@ -30,7 +29,9 @@ class CSVLogger(Logger):
|
||||
"perturbed_score": result.perturbed_result.score,
|
||||
"original_output": result.original_result.output,
|
||||
"perturbed_output": result.perturbed_result.output,
|
||||
"ground_truth_output": result.original_result.ground_truth_output,
|
||||
"num_queries": result.num_queries,
|
||||
"result_type": result_type
|
||||
}
|
||||
self.df = self.df.append(row, ignore_index=True)
|
||||
self._flushed = False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import SearchMethod
|
||||
|
||||
|
||||
@@ -21,7 +22,7 @@ class BeamSearch(SearchMethod):
|
||||
def _perform_search(self, initial_result):
|
||||
beam = [initial_result.attacked_text]
|
||||
best_result = initial_result
|
||||
while not best_result.succeeded:
|
||||
while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
potential_next_beam = []
|
||||
for text in beam:
|
||||
transformations = self.get_transformations(
|
||||
@@ -33,7 +34,7 @@ class BeamSearch(SearchMethod):
|
||||
# If we did not find any possible perturbations, give up.
|
||||
return best_result
|
||||
results, search_over = self.get_goal_results(
|
||||
potential_next_beam, initial_result.output
|
||||
potential_next_beam
|
||||
)
|
||||
scores = np.array([r.score for r in results])
|
||||
best_result = results[scores.argmax()]
|
||||
|
||||
@@ -10,6 +10,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
|
||||
if not len(transformations):
|
||||
return False
|
||||
orig_result, self.search_over = self.get_goal_results(
|
||||
[pop_member.attacked_text], self.correct_output
|
||||
[pop_member.attacked_text]
|
||||
)
|
||||
if self.search_over:
|
||||
return False
|
||||
new_x_results, self.search_over = self.get_goal_results(
|
||||
transformations, self.correct_output
|
||||
transformations
|
||||
)
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results(
|
||||
[pm.attacked_text for pm in pop], self.correct_output
|
||||
[pm.attacked_text for pm in pop]
|
||||
)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
logits = ((-pop_scores) / self.temp).exp()
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
if pop[0].result.succeeded:
|
||||
if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
return pop[0].result
|
||||
|
||||
if pop[0].result.score > cur_score:
|
||||
|
||||
@@ -9,6 +9,7 @@ See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
@@ -33,7 +34,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
ranks in order of descending score.
|
||||
"""
|
||||
leave_one_results, search_over = self.get_goal_results(
|
||||
texts, initial_result.output
|
||||
texts
|
||||
)
|
||||
leave_one_scores = np.array([result.score for result in leave_one_results])
|
||||
return leave_one_scores, search_over
|
||||
@@ -78,7 +79,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
if len(transformed_text_candidates) == 0:
|
||||
continue
|
||||
results, search_over = self.get_goal_results(
|
||||
transformed_text_candidates, initial_result.output
|
||||
transformed_text_candidates
|
||||
)
|
||||
results = sorted(results, key=lambda x: -x.score)
|
||||
# Skip swaps which don't improve the score
|
||||
@@ -87,12 +88,12 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
else:
|
||||
continue
|
||||
# If we succeeded, return the index with best similarity.
|
||||
if cur_result.succeeded:
|
||||
if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
best_result = cur_result
|
||||
# @TODO: Use vectorwise operations
|
||||
max_similarity = -float("inf")
|
||||
for result in results:
|
||||
if not result.succeeded:
|
||||
if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
|
||||
break
|
||||
candidate = result.attacked_text
|
||||
try:
|
||||
|
||||
@@ -6,10 +6,12 @@ import numpy as np
|
||||
|
||||
import textattack
|
||||
from textattack.attack_results import (
|
||||
MaximizedAttackResult,
|
||||
FailedAttackResult,
|
||||
SkippedAttackResult,
|
||||
SuccessfulAttackResult,
|
||||
)
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.shared import AttackedText, utils
|
||||
|
||||
|
||||
@@ -170,17 +172,24 @@ class Attack:
|
||||
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
||||
|
||||
Returns:
|
||||
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
|
||||
Either a ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
||||
or ``MaximizedAttackResult``.
|
||||
"""
|
||||
final_result = self.search_method(initial_result)
|
||||
if final_result.succeeded:
|
||||
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
return SuccessfulAttackResult(
|
||||
initial_result, final_result, self.goal_function.num_queries
|
||||
initial_result, final_result,
|
||||
)
|
||||
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
|
||||
return FailedAttackResult(
|
||||
initial_result, final_result,
|
||||
)
|
||||
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
|
||||
return MaximizedAttackResult(
|
||||
initial_result, final_result,
|
||||
)
|
||||
else:
|
||||
return FailedAttackResult(
|
||||
initial_result, final_result, self.goal_function.num_queries
|
||||
)
|
||||
raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
|
||||
|
||||
def _get_examples_from_dataset(self, dataset, indices=None):
|
||||
"""
|
||||
@@ -212,14 +221,10 @@ class Attack:
|
||||
attacked_text = AttackedText(
|
||||
text, attack_attrs={"label_names": label_names}
|
||||
)
|
||||
self.goal_function.num_queries = 0
|
||||
self.goal_function.init_attack_example(attacked_text, ground_truth_output)
|
||||
goal_function_result, _ = self.goal_function.get_result(
|
||||
attacked_text, ground_truth_output
|
||||
attacked_text
|
||||
)
|
||||
if goal_function_result.succeeded:
|
||||
# Store the true output on the goal function so that the
|
||||
# SkippedAttackResult has the correct output, not the incorrect.
|
||||
goal_function_result.output = ground_truth_output
|
||||
yield goal_function_result
|
||||
|
||||
except IndexError:
|
||||
@@ -240,7 +245,7 @@ class Attack:
|
||||
examples = self._get_examples_from_dataset(dataset, indices=indices)
|
||||
|
||||
for goal_function_result in examples:
|
||||
if goal_function_result.succeeded:
|
||||
if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
yield SkippedAttackResult(goal_function_result)
|
||||
else:
|
||||
result = self.attack_one(goal_function_result)
|
||||
|
||||
@@ -5,6 +5,7 @@ import pickle
|
||||
import time
|
||||
|
||||
from textattack.attack_results import (
|
||||
MaximizedAttackResult,
|
||||
FailedAttackResult,
|
||||
SkippedAttackResult,
|
||||
SuccessfulAttackResult,
|
||||
@@ -101,6 +102,11 @@ class Checkpoint:
|
||||
f"(Number of failed attacks): {self.num_failed_attacks}", 2
|
||||
)
|
||||
)
|
||||
breakdown_lines.append(
|
||||
utils.add_indent(
|
||||
f"(Number of maximized attacks): {self.num_maximized_attacks}", 2
|
||||
)
|
||||
)
|
||||
breakdown_lines.append(
|
||||
utils.add_indent(
|
||||
f"(Number of skipped attacks): {self.num_skipped_attacks}", 2
|
||||
@@ -140,6 +146,12 @@ class Checkpoint:
|
||||
isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results
|
||||
)
|
||||
|
||||
@property
|
||||
def num_maximized_attacks(self):
|
||||
return sum(
|
||||
isinstance(r, MaximizedAttackResult) for r in self.log_manager.results
|
||||
)
|
||||
|
||||
@property
|
||||
def num_remaining_attacks(self):
|
||||
if self.args.attack_n:
|
||||
|
||||
Reference in New Issue
Block a user