1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add query budget

This commit is contained in:
uvafan
2020-06-05 15:40:20 -04:00
parent 20f3769963
commit ee95b82895
8 changed files with 33 additions and 24 deletions

View File

@@ -139,9 +139,9 @@ class Attack:
"""
final_result = self.search_method(initial_result)
if final_result.succeeded:
return SuccessfulAttackResult(initial_result, final_result)
return SuccessfulAttackResult(initial_result, final_result, self.goal_function.num_queries)
else:
return FailedAttackResult(initial_result, final_result)
return FailedAttackResult(initial_result, final_result, self.goal_function.num_queries)
def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
attack_n=False, attack_skippable_examples=False):
@@ -173,6 +173,7 @@ class Attack:
for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.tokenizer)
self.goal_function.num_queries = 0
goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)
# We can skip examples for which the goal is already succeeded,
# unless `attack_skippable_examples` is True.
@@ -210,11 +211,7 @@ class Attack:
if was_skipped:
yield SkippedAttackResult(goal_function_result)
continue
# Start query count at 1 since we made a single query to determine
# that the prediction was correct.
self.goal_function.num_queries = 1
result = self.attack_one(goal_function_result)
result.num_queries = self.goal_function.num_queries
yield result
def __repr__(self):