mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #350 from a1noack/fix_query_count
fix model query count for all search methods
This commit is contained in:
@@ -101,9 +101,6 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
|
||||
new_results, self._search_over = self.get_goal_results(transformed_texts)
|
||||
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
diff_scores = (
|
||||
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||
)
|
||||
@@ -119,6 +116,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
|
||||
word_select_prob_weights[idx] = 0
|
||||
iterations += 1
|
||||
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
return pop_member
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -37,8 +37,10 @@ class SearchMethod(ABC):
|
||||
raise AttributeError(
|
||||
"Search Method must have access to get_model method if it is a white-box method"
|
||||
)
|
||||
|
||||
return self._perform_search(initial_result)
|
||||
result = self._perform_search(initial_result)
|
||||
# ensure that the number of queries for this GoalFunctionResult is up-to-date
|
||||
result.num_queries = self.goal_function.num_queries
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _perform_search(self, initial_result):
|
||||
|
||||
@@ -108,6 +108,8 @@ class Attack:
|
||||
|
||||
# Give search method access to functions for getting transformations and evaluating them
|
||||
self.search_method.get_transformations = self.get_transformations
|
||||
# Give search method access to self.goal_function for model query count, etc.
|
||||
self.search_method.goal_function = self.goal_function
|
||||
# The search method only needs access to the first argument. The second is only used
|
||||
# by the attack class when checking whether to skip the sample
|
||||
self.search_method.get_goal_results = (
|
||||
|
||||
Reference in New Issue
Block a user