1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge from master

This commit is contained in:
Jin Yong Yoo
2020-07-06 10:50:17 -04:00
110 changed files with 1622 additions and 648 deletions

View File

@@ -7,9 +7,11 @@ import numpy as np
import textattack
from textattack.attack_results import (
FailedAttackResult,
MaximizedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.shared import AttackedText, utils
@@ -56,7 +58,7 @@ class Attack:
self.transformation
):
raise ValueError(
"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
)
self.constraints = []
@@ -74,7 +76,12 @@ class Attack:
# Give search method access to functions for getting transformations and evaluating them
self.search_method.get_transformations = self.get_transformations
self.search_method.get_goal_results = self.goal_function.get_results
# The search method only needs access to the first argument. The second is only used
# by the attack class when checking whether to skip the sample
self.search_method.get_goal_results = lambda attacked_text_list: self.goal_function.get_results(
attacked_text_list
)
self.search_method.filter_transformations = self.filter_transformations
def get_transformations(self, current_text, original_text=None, **kwargs):
"""
@@ -102,7 +109,7 @@ class Attack:
**kwargs,
)
)
return self._filter_transformations(
return self.filter_transformations(
transformed_texts, current_text, original_text
)
@@ -138,7 +145,7 @@ class Attack:
self.constraints_cache[(current_text, filtered_text)] = True
return filtered_texts
def _filter_transformations(
def filter_transformations(
self, transformed_texts, current_text, original_text=None
):
"""
@@ -180,17 +187,18 @@ class Attack:
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
Returns:
Either a ``SuccessfulAttackResult`` or ``FailedAttackResult``.
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
or ``MaximizedAttackResult``.
"""
final_result = self.search_method(initial_result)
if final_result.succeeded:
return SuccessfulAttackResult(
initial_result, final_result, self.goal_function.num_queries
)
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(initial_result, final_result,)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(initial_result, final_result,)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(initial_result, final_result,)
else:
return FailedAttackResult(
initial_result, final_result, self.goal_function.num_queries
)
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
def _get_examples_from_dataset(self, dataset, indices=None):
"""
@@ -222,14 +230,9 @@ class Attack:
attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names}
)
self.goal_function.num_queries = 0
goal_function_result, _ = self.goal_function.get_result(
goal_function_result, _ = self.goal_function.init_attack_example(
attacked_text, ground_truth_output
)
if goal_function_result.succeeded:
# Store the true output on the goal function so that the
# SkippedAttackResult has the correct output, not the incorrect.
goal_function_result.output = ground_truth_output
yield goal_function_result
except IndexError:
@@ -250,7 +253,7 @@ class Attack:
examples = self._get_examples_from_dataset(dataset, indices=indices)
for goal_function_result in examples:
if goal_function_result.succeeded:
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
yield SkippedAttackResult(goal_function_result)
else:
result = self.attack_one(goal_function_result)