1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add consistent is_black_box attributes

This commit is contained in:
Jin Yong Yoo
2020-10-06 16:42:21 -04:00
parent 79f3f4f8f0
commit 4893d47e6c
6 changed files with 11 additions and 9 deletions

View File

@@ -50,7 +50,7 @@ class BeamSearch(SearchMethod):
return best_result return best_result
@property @property
def is_blackbox(self): def is_black_box(self):
return True return True
def extra_repr_keys(self): def extra_repr_keys(self):

View File

@@ -286,7 +286,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
return transformation_consists_of_word_swaps(transformation) return transformation_consists_of_word_swaps(transformation)
@property @property
def is_blackbox(self): def is_black_box(self):
return True return True
def extra_repr_keys(self): def extra_repr_keys(self):

View File

@@ -184,7 +184,7 @@ class GreedyWordSwapWIR(SearchMethod):
return transformation_consists_of_word_swaps_and_deletions(transformation) return transformation_consists_of_word_swaps_and_deletions(transformation)
@property @property
def is_blackbox(self): def is_black_box(self):
if self.wir_method == "gradient": if self.wir_method == "gradient":
return False return False
else: else:

View File

@@ -45,7 +45,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
self.pop_size = pop_size self.pop_size = pop_size
self.post_turn_check = post_turn_check self.post_turn_check = post_turn_check
self.max_turn_retries = 20 self.max_turn_retries = 20
self.is_blackbox = True self.is_black_box = True
self._search_over = False self._search_over = False
self.omega_1 = 0.8 self.omega_1 = 0.8
@@ -331,7 +331,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
return transformation_consists_of_word_swaps(transformation) return transformation_consists_of_word_swaps(transformation)
@property @property
def is_blackbox(self): def is_black_box(self):
return True return True
def extra_repr_keys(self): def extra_repr_keys(self):

View File

@@ -33,7 +33,7 @@ class SearchMethod(ABC):
"Search Method must have access to filter_transformations method" "Search Method must have access to filter_transformations method"
) )
if not self.is_blackbox and not hasattr(self, "get_model"): if not self.is_black_box and not hasattr(self, "get_model"):
raise AttributeError( raise AttributeError(
"Search Method must have access to get_model method if it is a white-box method" "Search Method must have access to get_model method if it is a white-box method"
) )
@@ -55,7 +55,7 @@ class SearchMethod(ABC):
return True return True
@property @property
def is_blackbox(self): def is_black_box(self):
"""Returns `True` if search method does not require access to victim """Returns `True` if search method does not require access to victim
model's internal states.""" model's internal states."""
raise NotImplementedError() raise NotImplementedError()

View File

@@ -67,7 +67,9 @@ class Attack:
self.transformation = transformation self.transformation = transformation
if not self.transformation: if not self.transformation:
raise NameError("Cannot instantiate attack without transformation") raise NameError("Cannot instantiate attack without transformation")
self.is_black_box = getattr(transformation, "is_black_box", True) self.is_black_box = (
getattr(transformation, "is_black_box", True) or search_method.is_black_box
)
if not self.search_method.check_transformation_compatibility( if not self.search_method.check_transformation_compatibility(
self.transformation self.transformation
@@ -114,7 +116,7 @@ class Attack:
) )
) )
self.search_method.filter_transformations = self.filter_transformations self.search_method.filter_transformations = self.filter_transformations
if not search_method.is_blackbox: if not search_method.is_black_box:
self.search_method.get_model = lambda: self.goal_function.model self.search_method.get_model = lambda: self.goal_function.model
def clear_cache(self, recursive=True): def clear_cache(self, recursive=True):