1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #350 from a1noack/fix_query_count

fix model query count for all search methods
This commit is contained in:
Jin Yong Yoo
2020-11-22 00:29:45 -05:00
committed by GitHub
4 changed files with 23 additions and 18 deletions

View File

@@ -56,16 +56,16 @@ hugh grant , who has a good line in charm , has never been more loveable[0
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 17.34% |
| Average num. words per input: | 15.0 |
| Avg num queries: | 551.67 |
+-------------------------------+--------+
+-------------------------------+---------+
| Attack Results | |
+-------------------------------+---------+
| Number of successful attacks: | 2 |
| Number of failed attacks: | 1 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 33.33% |
| Attack success rate: | 66.67% |
| Average perturbed word %: | 17.34% |
| Average num. words per input: | 15.0 |
| Avg num queries: | 1132.67 |
+-------------------------------+---------+

View File

@@ -101,9 +101,6 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
new_results, self._search_over = self.get_goal_results(transformed_texts)
if self._search_over:
break
diff_scores = (
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
)
@@ -119,6 +116,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
word_select_prob_weights[idx] = 0
iterations += 1
if self._search_over:
break
return pop_member
@abstractmethod

View File

@@ -37,8 +37,10 @@ class SearchMethod(ABC):
raise AttributeError(
"Search Method must have access to get_model method if it is a white-box method"
)
return self._perform_search(initial_result)
result = self._perform_search(initial_result)
# ensure that the number of queries for this GoalFunctionResult is up-to-date
result.num_queries = self.goal_function.num_queries
return result
@abstractmethod
def _perform_search(self, initial_result):

View File

@@ -108,6 +108,8 @@ class Attack:
# Give search method access to functions for getting transformations and evaluating them
self.search_method.get_transformations = self.get_transformations
# Give search method access to self.goal_function for model query count, etc.
self.search_method.goal_function = self.goal_function
# The search method only needs access to the first argument. The second is only used
# by the attack class when checking whether to skip the sample
self.search_method.get_goal_results = (