From 05b32b87f0f7baf59a6c05ce665a90ee2886350f Mon Sep 17 00:00:00 2001 From: uvafan Date: Fri, 3 Jul 2020 14:28:03 -0400 Subject: [PATCH] update get_goal_results signature --- README.md | 2 +- textattack/goal_functions/goal_function.py | 7 +++++++ textattack/shared/attack.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7995f723..6a38fa36 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ The `attack_one` method in an `Attack` takes as input an `AttackedText`, and out ### Goal Functions -A `GoalFunction` takes as input an `AttackedText` object and the ground truth output, and determines whether the attack has succeeded, returning a `GoalFunctionResult`. +A `GoalFunction` takes as input an `AttackedText` object, scores it, and determines whether the attack has succeeded, returning a `GoalFunctionResult`. ### Constraints diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index 572f316b..07e1e4b1 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -114,6 +114,13 @@ class GoalFunction: ) return results, self.num_queries == self.query_budget + def get_results_from_search_method(self, attacked_text_list): + """ + The search method doesn't have access to the ``check_skip`` argument of + ``get_results``. + """ + return self.get_results(attacked_text_list) + def _get_goal_status(self, model_output, attacked_text, check_skip=False): should_skip = check_skip and self._should_skip(model_output, attacked_text) if should_skip: diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py index ca05cafc..6f390615 100644 --- a/textattack/shared/attack.py +++ b/textattack/shared/attack.py @@ -77,7 +77,7 @@ class Attack: # Give search method access to functions for getting transformations and evaluating them self.search_method.get_transformations = self.get_transformations - self.search_method.get_goal_results = self.goal_function.get_results + self.search_method.get_goal_results = self.goal_function.get_results_from_search_method self.search_method.filter_transformations = self.filter_transformations def get_transformations(self, current_text, original_text=None, **kwargs):