1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

format and finish merge

This commit is contained in:
uvafan
2020-06-28 21:07:28 -04:00
parent b30fdcd1da
commit 54632c3c7b
11 changed files with 45 additions and 45 deletions

View File

@@ -6,8 +6,8 @@ import numpy as np
import textattack
from textattack.attack_results import (
MaximizedAttackResult,
FailedAttackResult,
MaximizedAttackResult,
SkippedAttackResult,
SuccessfulAttackResult,
)
@@ -182,19 +182,13 @@ class Attack:
"""
final_result = self.search_method(initial_result)
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(
initial_result, final_result,
)
return SuccessfulAttackResult(initial_result, final_result,)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(
initial_result, final_result,
)
return FailedAttackResult(initial_result, final_result,)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(
initial_result, final_result,
)
return MaximizedAttackResult(initial_result, final_result,)
else:
raise ValueError(f'Unrecognized goal status {final_result.goal_status}')
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
def _get_examples_from_dataset(self, dataset, indices=None):
"""
@@ -226,10 +220,10 @@ class Attack:
attacked_text = AttackedText(
text, attack_attrs={"label_names": label_names}
)
self.goal_function.init_attack_example(attacked_text, ground_truth_output)
goal_function_result, _ = self.goal_function.get_result(
attacked_text
self.goal_function.init_attack_example(
attacked_text, ground_truth_output
)
goal_function_result, _ = self.goal_function.get_result(attacked_text)
yield goal_function_result
except IndexError: