mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #350 from a1noack/fix_query_count
fix model query count for all search methods
This commit is contained in:
@@ -56,16 +56,16 @@ hugh grant , who has a good line in charm , has never been more [91mloveable[0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
+-------------------------------+--------+
|
+-------------------------------+---------+
|
||||||
| Attack Results | |
|
| Attack Results | |
|
||||||
+-------------------------------+--------+
|
+-------------------------------+---------+
|
||||||
| Number of successful attacks: | 2 |
|
| Number of successful attacks: | 2 |
|
||||||
| Number of failed attacks: | 1 |
|
| Number of failed attacks: | 1 |
|
||||||
| Number of skipped attacks: | 0 |
|
| Number of skipped attacks: | 0 |
|
||||||
| Original accuracy: | 100.0% |
|
| Original accuracy: | 100.0% |
|
||||||
| Accuracy under attack: | 33.33% |
|
| Accuracy under attack: | 33.33% |
|
||||||
| Attack success rate: | 66.67% |
|
| Attack success rate: | 66.67% |
|
||||||
| Average perturbed word %: | 17.34% |
|
| Average perturbed word %: | 17.34% |
|
||||||
| Average num. words per input: | 15.0 |
|
| Average num. words per input: | 15.0 |
|
||||||
| Avg num queries: | 551.67 |
|
| Avg num queries: | 1132.67 |
|
||||||
+-------------------------------+--------+
|
+-------------------------------+---------+
|
||||||
|
|||||||
@@ -101,9 +101,6 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
|||||||
|
|
||||||
new_results, self._search_over = self.get_goal_results(transformed_texts)
|
new_results, self._search_over = self.get_goal_results(transformed_texts)
|
||||||
|
|
||||||
if self._search_over:
|
|
||||||
break
|
|
||||||
|
|
||||||
diff_scores = (
|
diff_scores = (
|
||||||
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||||
)
|
)
|
||||||
@@ -119,6 +116,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
|||||||
|
|
||||||
word_select_prob_weights[idx] = 0
|
word_select_prob_weights[idx] = 0
|
||||||
iterations += 1
|
iterations += 1
|
||||||
|
|
||||||
|
if self._search_over:
|
||||||
|
break
|
||||||
|
|
||||||
return pop_member
|
return pop_member
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -37,8 +37,10 @@ class SearchMethod(ABC):
|
|||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"Search Method must have access to get_model method if it is a white-box method"
|
"Search Method must have access to get_model method if it is a white-box method"
|
||||||
)
|
)
|
||||||
|
result = self._perform_search(initial_result)
|
||||||
return self._perform_search(initial_result)
|
# ensure that the number of queries for this GoalFunctionResult is up-to-date
|
||||||
|
result.num_queries = self.goal_function.num_queries
|
||||||
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _perform_search(self, initial_result):
|
def _perform_search(self, initial_result):
|
||||||
|
|||||||
@@ -108,6 +108,8 @@ class Attack:
|
|||||||
|
|
||||||
# Give search method access to functions for getting transformations and evaluating them
|
# Give search method access to functions for getting transformations and evaluating them
|
||||||
self.search_method.get_transformations = self.get_transformations
|
self.search_method.get_transformations = self.get_transformations
|
||||||
|
# Give search method access to self.goal_function for model query count, etc.
|
||||||
|
self.search_method.goal_function = self.goal_function
|
||||||
# The search method only needs access to the first argument. The second is only used
|
# The search method only needs access to the first argument. The second is only used
|
||||||
# by the attack class when checking whether to skip the sample
|
# by the attack class when checking whether to skip the sample
|
||||||
self.search_method.get_goal_results = (
|
self.search_method.get_goal_results = (
|
||||||
|
|||||||
Reference in New Issue
Block a user