mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add input reduction
This commit is contained in:
@@ -58,7 +58,7 @@ class Attack:
|
||||
self.transformation
|
||||
):
|
||||
raise ValueError(
|
||||
"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
|
||||
f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
|
||||
)
|
||||
|
||||
self.constraints = []
|
||||
@@ -220,10 +220,9 @@ class Attack:
|
||||
attacked_text = AttackedText(
|
||||
text, attack_attrs={"label_names": label_names}
|
||||
)
|
||||
self.goal_function.init_attack_example(
|
||||
goal_function_result, _ = self.goal_function.init_attack_example(
|
||||
attacked_text, ground_truth_output
|
||||
)
|
||||
goal_function_result, _ = self.goal_function.get_result(attacked_text)
|
||||
yield goal_function_result
|
||||
|
||||
except IndexError:
|
||||
@@ -244,7 +243,7 @@ class Attack:
|
||||
examples = self._get_examples_from_dataset(dataset, indices=indices)
|
||||
|
||||
for goal_function_result in examples:
|
||||
if goal_function_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
|
||||
yield SkippedAttackResult(goal_function_result)
|
||||
else:
|
||||
result = self.attack_one(goal_function_result)
|
||||
|
||||
Reference in New Issue
Block a user