diff --git a/textattack/search_methods/beam_search.py b/textattack/search_methods/beam_search.py index e0bcb335..715ee345 100644 --- a/textattack/search_methods/beam_search.py +++ b/textattack/search_methods/beam_search.py @@ -50,7 +50,7 @@ class BeamSearch(SearchMethod): return best_result @property - def is_blackbox(self): + def is_black_box(self): return True def extra_repr_keys(self): diff --git a/textattack/search_methods/genetic_algorithm.py b/textattack/search_methods/genetic_algorithm.py index d2cd49e5..a84cc170 100644 --- a/textattack/search_methods/genetic_algorithm.py +++ b/textattack/search_methods/genetic_algorithm.py @@ -286,7 +286,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC): return transformation_consists_of_word_swaps(transformation) @property - def is_blackbox(self): + def is_black_box(self): return True def extra_repr_keys(self): diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py index bbb80c24..f1399cfe 100644 --- a/textattack/search_methods/greedy_word_swap_wir.py +++ b/textattack/search_methods/greedy_word_swap_wir.py @@ -184,7 +184,7 @@ class GreedyWordSwapWIR(SearchMethod): return transformation_consists_of_word_swaps_and_deletions(transformation) @property - def is_blackbox(self): + def is_black_box(self): if self.wir_method == "gradient": return False else: diff --git a/textattack/search_methods/particle_swarm_optimization.py b/textattack/search_methods/particle_swarm_optimization.py index 8712542e..ad24e428 100644 --- a/textattack/search_methods/particle_swarm_optimization.py +++ b/textattack/search_methods/particle_swarm_optimization.py @@ -45,7 +45,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch): self.pop_size = pop_size self.post_turn_check = post_turn_check self.max_turn_retries = 20 - self.is_blackbox = True + self.is_black_box = True self._search_over = False self.omega_1 = 0.8 @@ -331,7 +331,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch): return transformation_consists_of_word_swaps(transformation) @property - def is_blackbox(self): + def is_black_box(self): return True def extra_repr_keys(self): diff --git a/textattack/search_methods/search_method.py b/textattack/search_methods/search_method.py index ee26b430..cb6b0784 100644 --- a/textattack/search_methods/search_method.py +++ b/textattack/search_methods/search_method.py @@ -33,7 +33,7 @@ class SearchMethod(ABC): "Search Method must have access to filter_transformations method" ) - if not self.is_blackbox and not hasattr(self, "get_model"): + if not self.is_black_box and not hasattr(self, "get_model"): raise AttributeError( "Search Method must have access to get_model method if it is a white-box method" ) @@ -55,7 +55,7 @@ class SearchMethod(ABC): return True @property - def is_blackbox(self): + def is_black_box(self): """Returns `True` if search method does not require access to victim model's internal states.""" raise NotImplementedError() diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py index 97f6f7ca..dc95785f 100644 --- a/textattack/shared/attack.py +++ b/textattack/shared/attack.py @@ -67,7 +67,9 @@ class Attack: self.transformation = transformation if not self.transformation: raise NameError("Cannot instantiate attack without transformation") - self.is_black_box = getattr(transformation, "is_black_box", True) + self.is_black_box = ( + getattr(transformation, "is_black_box", True) or search_method.is_black_box + ) if not self.search_method.check_transformation_compatibility( self.transformation @@ -114,7 +116,7 @@ class Attack: ) ) self.search_method.filter_transformations = self.filter_transformations - if not search_method.is_blackbox: + if not search_method.is_black_box: self.search_method.get_model = lambda: self.goal_function.model def clear_cache(self, recursive=True):