mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add gradient-based white-box search
This commit is contained in:
@@ -49,5 +49,9 @@ class BeamSearch(SearchMethod):
|
||||
beam = [potential_next_beam[i] for i in best_indices]
|
||||
return best_result
|
||||
|
||||
@property
def is_blackbox(self):
    """Whether this search method is black-box.

    Beam search never needs the victim model's internal states, so
    this is always ``True``.
    """
    return True
|
||||
|
||||
def extra_repr_keys(self):
    """Names of attributes to include in this object's ``__repr__``."""
    return ["beam_width"]
|
||||
|
||||
@@ -285,6 +285,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
substitutions."""
|
||||
return transformation_consists_of_word_swaps(transformation)
|
||||
|
||||
@property
def is_blackbox(self):
    """Whether this search method is black-box.

    The genetic algorithm only queries model outputs, never internal
    states, so this is always ``True``.
    """
    return True
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return [
|
||||
"pop_size",
|
||||
|
||||
@@ -28,6 +28,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
|
||||
Args:
|
||||
wir_method: method for ranking most important words
|
||||
model_wrapper: model wrapper used for gradient-based ranking
|
||||
"""
|
||||
|
||||
def __init__(self, wir_method="unk"):
|
||||
@@ -44,6 +45,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
]
|
||||
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
|
||||
index_scores = np.array([result.score for result in leave_one_results])
|
||||
|
||||
elif self.wir_method == "weighted-saliency":
|
||||
# first, compute word saliency
|
||||
leave_one_texts = [
|
||||
@@ -74,12 +76,47 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
delta_ps.append(max_score_change)
|
||||
|
||||
index_scores = softmax_saliency_scores * np.array(delta_ps)
|
||||
|
||||
elif self.wir_method == "delete":
|
||||
leave_one_texts = [
|
||||
initial_text.delete_word_at_index(i) for i in range(len_text)
|
||||
]
|
||||
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
|
||||
index_scores = np.array([result.score for result in leave_one_results])
|
||||
|
||||
elif self.wir_method == "gradient":
|
||||
victim_model = self.get_model()
|
||||
index_scores = np.zeros(initial_text.num_words)
|
||||
grad_output = victim_model.get_grad(initial_text.tokenizer_input)
|
||||
gradient = grad_output["gradient"]
|
||||
j = 0
|
||||
last_matched = 0
|
||||
for i, word in enumerate(initial_text.words):
|
||||
word = initial_text.words[i].lower()
|
||||
matched_tokens = []
|
||||
a = []
|
||||
while j < len(grad_output["tokens"]) and len(word) > 0:
|
||||
token = grad_output["tokens"][j].lower()
|
||||
# remove "##" if it's a subword
|
||||
token = token.replace("##", "")
|
||||
idx = word.find(token)
|
||||
if idx == 0:
|
||||
word = word[idx + len(token) :]
|
||||
matched_tokens.append(j)
|
||||
a.append(token)
|
||||
last_matched = j
|
||||
j += 1
|
||||
|
||||
if not matched_tokens:
|
||||
# Reset j to most recent match
|
||||
j = last_matched
|
||||
index_scores[i] = 0.0
|
||||
else:
|
||||
agg_grad = np.mean(gradient[matched_tokens], axis=0)
|
||||
index_scores[i] = np.linalg.norm(agg_grad, ord=1)
|
||||
|
||||
search_over = False
|
||||
|
||||
elif self.wir_method == "random":
|
||||
index_order = np.arange(len_text)
|
||||
np.random.shuffle(index_order)
|
||||
@@ -146,5 +183,12 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
limited to word swap and deletion transformations."""
|
||||
return transformation_consists_of_word_swaps_and_deletions(transformation)
|
||||
|
||||
@property
def is_blackbox(self):
    """Whether this search method is black-box.

    Every word-importance-ranking strategy except ``"gradient"`` relies
    only on model outputs; the ``"gradient"`` strategy needs the victim
    model's gradients, making it white-box.
    """
    return self.wir_method != "gradient"
|
||||
|
||||
def extra_repr_keys(self):
    """Names of attributes to include in this object's ``__repr__``."""
    return ["wir_method"]
|
||||
|
||||
@@ -45,6 +45,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
|
||||
self.pop_size = pop_size
|
||||
self.post_turn_check = post_turn_check
|
||||
self.max_turn_retries = 20
|
||||
self.is_blackbox = True
|
||||
|
||||
self._search_over = False
|
||||
self.omega_1 = 0.8
|
||||
@@ -329,6 +330,10 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
|
||||
substitutions."""
|
||||
return transformation_consists_of_word_swaps(transformation)
|
||||
|
||||
@property
def is_blackbox(self):
    """Returns `True`: particle swarm optimization only queries model
    outputs, not internal states.

    NOTE(review): ``__init__`` also assigns ``self.is_blackbox = True``;
    assigning through a read-only property raises AttributeError at
    construction time — confirm one of the two should be removed.
    """
    return True
|
||||
|
||||
def extra_repr_keys(self):
    """Names of attributes to include in this object's ``__repr__``."""
    return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]
|
||||
|
||||
|
||||
@@ -32,6 +32,12 @@ class SearchMethod(ABC):
|
||||
raise AttributeError(
|
||||
"Search Method must have access to filter_transformations method"
|
||||
)
|
||||
|
||||
if not self.is_blackbox and not hasattr(self, "get_model"):
|
||||
raise AttributeError(
|
||||
"Search Method must have access to get_model method if it is a white-box method"
|
||||
)
|
||||
|
||||
return self._perform_search(initial_result)
|
||||
|
||||
@abstractmethod
|
||||
@@ -48,6 +54,12 @@ class SearchMethod(ABC):
|
||||
``transformation``."""
|
||||
return True
|
||||
|
||||
@property
def is_blackbox(self):
    """Returns `True` if search method does not require access to victim
    model's internal states."""
    # Abstract: every concrete SearchMethod subclass must override this
    # property; white-box methods (returning False) are additionally
    # required to expose a get_model() hook.
    raise NotImplementedError()
|
||||
|
||||
def extra_repr_keys(self):
    """Names of extra attributes to include in ``__repr__``.

    The base class contributes none; subclasses override this to
    surface their configuration (e.g. beam width, population size).
    """
    return []
|
||||
|
||||
|
||||
@@ -114,6 +114,8 @@ class Attack:
|
||||
)
|
||||
)
|
||||
self.search_method.filter_transformations = self.filter_transformations
|
||||
if not search_method.is_blackbox:
|
||||
self.search_method.get_model = lambda: self.goal_function.model
|
||||
|
||||
def clear_cache(self, recursive=True):
|
||||
self.constraints_cache.clear()
|
||||
|
||||
@@ -100,14 +100,6 @@ def load_textattack_model_from_path(model_name, model_path):
|
||||
model = textattack.models.helpers.WordCNNForClassification(
|
||||
model_path=model_path, num_labels=num_labels
|
||||
)
|
||||
elif model_name.startswith("bert"):
|
||||
model_path, num_labels = model_path
|
||||
textattack.shared.logger.info(
|
||||
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
|
||||
)
|
||||
model = textattack.models.helpers.BERTForClassification(
|
||||
model_path=model_path, num_labels=num_labels
|
||||
)
|
||||
elif model_name.startswith("t5"):
|
||||
model = textattack.models.helpers.T5ForTextToText(model_path)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user