mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge [Tests pass]
This commit is contained in:
@@ -138,9 +138,9 @@ class Attack:
|
||||
"""
|
||||
final_result = self.search_method(initial_result)
|
||||
if final_result.succeeded:
|
||||
return SuccessfulAttackResult(initial_result, final_result)
|
||||
return SuccessfulAttackResult(initial_result, final_result, self.goal_function.num_queries)
|
||||
else:
|
||||
return FailedAttackResult(initial_result, final_result)
|
||||
return FailedAttackResult(initial_result, final_result, self.goal_function.num_queries)
|
||||
|
||||
def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False,
|
||||
attack_n=False, attack_skippable_examples=False):
|
||||
@@ -172,7 +172,8 @@ class Attack:
|
||||
|
||||
for text, ground_truth_output in dataset:
|
||||
tokenized_text = TokenizedText(text, self.goal_function.tokenizer)
|
||||
goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)
|
||||
self.goal_function.num_queries = 0
|
||||
goal_function_result, _ = self.goal_function.get_result(tokenized_text, ground_truth_output)
|
||||
# We can skip examples for which the goal is already succeeded,
|
||||
# unless `attack_skippable_examples` is True.
|
||||
if (not attack_skippable_examples) and (goal_function_result.succeeded):
|
||||
@@ -209,11 +210,7 @@ class Attack:
|
||||
if was_skipped:
|
||||
yield SkippedAttackResult(goal_function_result)
|
||||
continue
|
||||
# Start query count at 1 since we made a single query to determine
|
||||
# that the prediction was correct.
|
||||
self.goal_function.num_queries = 1
|
||||
result = self.attack_one(goal_function_result)
|
||||
result.num_queries = self.goal_function.num_queries
|
||||
yield result
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
Reference in New Issue
Block a user