1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

allow maximization goal functions

This commit is contained in:
uvafan
2020-06-23 23:33:48 -04:00
parent fe109267a1
commit 0fcfb51b7f
19 changed files with 115 additions and 78 deletions

View File

@@ -6,10 +6,12 @@ import numpy as np
import textattack
from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils
@@ -170,17 +172,24 @@ class Attack:
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
Returns:
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
Either a ``SuccessfulAttackResult``, ``FailedAttackResult``,
or ``MaximizedAttackResult``.
"""
final_result = self.search_method(initial_result)
if final_result.succeeded:
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(
initial_result, final_result, self.goal_function.num_queries
initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(
initial_result, final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(
initial_result, final_result,
)
else:
return FailedAttackResult(
initial_result, final_result, self.goal_function.num_queries
)
raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
def _get_examples_from_dataset(self, dataset, indices=None):
"""
@@ -212,14 +221,10 @@ class Attack:
attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names}
)
self.goal_function.num_queries = 0
self.goal_function.init_attack_example(attacked_text, ground_truth_output)
goal_function_result, _ = self.goal_function.get_result(
attacked_text, ground_truth_output
attacked_text
)
if goal_function_result.succeeded:
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield goal_function_result
except IndexError:
@@ -240,7 +245,7 @@ class Attack:
examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result in examples:
if goal_function_result.succeeded:
if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
yield SkippedAttackResult(goal_function_result)
else:
result = self.attack_one(goal_function_result)