mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)

Commit: add gradient-based white-box search

@@ -49,5 +49,9 @@ class BeamSearch(SearchMethod):
             beam = [potential_next_beam[i] for i in best_indices]
 
         return best_result
 
+    @property
+    def is_blackbox(self):
+        return True
+
     def extra_repr_keys(self):
         return ["beam_width"]

@@ -285,6 +285,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
         substitutions."""
         return transformation_consists_of_word_swaps(transformation)
 
+    @property
+    def is_blackbox(self):
+        return True
+
     def extra_repr_keys(self):
         return [
             "pop_size",

@@ -28,6 +28,7 @@ class GreedyWordSwapWIR(SearchMethod):
 
     Args:
         wir_method: method for ranking most important words
+        model_wrapper: model wrapper used for gradient-based ranking
     """
 
     def __init__(self, wir_method="unk"):

@@ -44,6 +45,7 @@ class GreedyWordSwapWIR(SearchMethod):
             ]
             leave_one_results, search_over = self.get_goal_results(leave_one_texts)
             index_scores = np.array([result.score for result in leave_one_results])
+
         elif self.wir_method == "weighted-saliency":
             # first, compute word saliency
             leave_one_texts = [

@@ -74,12 +76,47 @@ class GreedyWordSwapWIR(SearchMethod):
                 delta_ps.append(max_score_change)
 
             index_scores = softmax_saliency_scores * np.array(delta_ps)
 
         elif self.wir_method == "delete":
             leave_one_texts = [
                 initial_text.delete_word_at_index(i) for i in range(len_text)
             ]
             leave_one_results, search_over = self.get_goal_results(leave_one_texts)
             index_scores = np.array([result.score for result in leave_one_results])
 
+        elif self.wir_method == "gradient":
+            victim_model = self.get_model()
+            index_scores = np.zeros(initial_text.num_words)
+            grad_output = victim_model.get_grad(initial_text.tokenizer_input)
+            gradient = grad_output["gradient"]
+            j = 0
+            last_matched = 0
+            for i, word in enumerate(initial_text.words):
+                word = initial_text.words[i].lower()
+                matched_tokens = []
+                a = []
+                while j < len(grad_output["tokens"]) and len(word) > 0:
+                    token = grad_output["tokens"][j].lower()
+                    # remove "##" if it's a subword
+                    token = token.replace("##", "")
+                    idx = word.find(token)
+                    if idx == 0:
+                        word = word[idx + len(token) :]
+                        matched_tokens.append(j)
+                        a.append(token)
+                        last_matched = j
+                    j += 1
+
+                if not matched_tokens:
+                    # Reset j to most recent match
+                    j = last_matched
+                    index_scores[i] = 0.0
+                else:
+                    agg_grad = np.mean(gradient[matched_tokens], axis=0)
+                    index_scores[i] = np.linalg.norm(agg_grad, ord=1)
+
+            search_over = False
+
         elif self.wir_method == "random":
             index_order = np.arange(len_text)
             np.random.shuffle(index_order)

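The "gradient" branch above ranks each word by the L1 norm of the loss gradient at its position, aggregating subword gradients back onto whole words by prefix-matching WordPiece tokens against the word. Below is a minimal, self-contained sketch of that alignment step; the token list and gradient matrix are made-up stand-ins for a real model's get_grad() output, and the last_matched backtracking fallback is omitted for brevity.

    import numpy as np

    # Hypothetical stand-in for victim_model.get_grad(...): WordPiece tokens
    # (with "##" continuation markers) and one gradient row per token.
    words = ["unbelievable", "movie"]
    grad_output = {
        "tokens": ["un", "##believ", "##able", "movie"],
        "gradient": np.array([[0.1, -0.2], [0.4, 0.1], [-0.3, 0.2], [0.05, 0.0]]),
    }

    index_scores = np.zeros(len(words))
    j = 0
    for i, word in enumerate(words):
        word = word.lower()
        matched_tokens = []
        # Consume subtokens that prefix-match the remaining part of the word.
        while j < len(grad_output["tokens"]) and len(word) > 0:
            token = grad_output["tokens"][j].lower().replace("##", "")
            if word.find(token) == 0:
                word = word[len(token):]
                matched_tokens.append(j)
            j += 1
        if matched_tokens:
            # Mean gradient over the word's subtokens, scored by L1 norm.
            agg_grad = np.mean(grad_output["gradient"][matched_tokens], axis=0)
            index_scores[i] = np.linalg.norm(agg_grad, ord=1)

    print(index_scores)  # ~[0.1  0.05]: higher score = more important word
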
@@ -146,5 +183,12 @@ class GreedyWordSwapWIR(SearchMethod):
         limited to word swap and deletion transformations."""
         return transformation_consists_of_word_swaps_and_deletions(transformation)
 
+    @property
+    def is_blackbox(self):
+        if self.wir_method == "gradient":
+            return False
+        else:
+            return True
+
     def extra_repr_keys(self):
         return ["wir_method"]

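As a quick illustration of the property above (a hypothetical usage sketch, assuming this commit's API and the usual textattack.search_methods import path): the search stays black-box for every ranking method except "gradient".

    from textattack.search_methods import GreedyWordSwapWIR

    assert GreedyWordSwapWIR(wir_method="unk").is_blackbox is True
    assert GreedyWordSwapWIR(wir_method="gradient").is_blackbox is False
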
@@ -45,6 +45,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
         self.pop_size = pop_size
         self.post_turn_check = post_turn_check
         self.max_turn_retries = 20
+        self.is_blackbox = True
 
         self._search_over = False
         self.omega_1 = 0.8

@@ -329,6 +330,10 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
         substitutions."""
         return transformation_consists_of_word_swaps(transformation)
 
+    @property
+    def is_blackbox(self):
+        return True
+
     def extra_repr_keys(self):
         return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]

@@ -32,6 +32,12 @@ class SearchMethod(ABC):
             raise AttributeError(
                 "Search Method must have access to filter_transformations method"
             )
+
+        if not self.is_blackbox and not hasattr(self, "get_model"):
+            raise AttributeError(
+                "Search Method must have access to get_model method if it is a white-box method"
+            )
+
         return self._perform_search(initial_result)
 
     @abstractmethod

@@ -48,6 +54,12 @@ class SearchMethod(ABC):
         ``transformation``."""
         return True
 
+    @property
+    def is_blackbox(self):
+        """Returns `True` if search method does not require access to victim
+        model's internal states."""
+        raise NotImplementedError()
+
     def extra_repr_keys(self):
         return []

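Together, the two SearchMethod hunks define the white-box contract: every concrete search method must declare is_blackbox, and a method that declares itself white-box only runs once a get_model accessor has been attached to it. A minimal sketch of a conforming white-box subclass (hypothetical class name and body, not part of the diff):

    from textattack.search_methods import SearchMethod

    class MyGradientSearch(SearchMethod):
        def _perform_search(self, initial_result):
            # get_model is attached externally (see the Attack hunk below),
            # so the hasattr check in SearchMethod.__call__ passes.
            victim_model = self.get_model()
            ...

        @property
        def is_blackbox(self):
            return False
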
@@ -114,6 +114,8 @@ class Attack:
             )
         )
         self.search_method.filter_transformations = self.filter_transformations
+        if not search_method.is_blackbox:
+            self.search_method.get_model = lambda: self.goal_function.model
 
     def clear_cache(self, recursive=True):
         self.constraints_cache.clear()

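The design choice here is that the Attack exposes the goal function's victim model to a white-box search method through a zero-argument closure rather than a direct reference, which satisfies the hasattr(self, "get_model") check in SearchMethod.__call__. A toy sketch of that wiring with stub classes (all names hypothetical, runnable as-is):

    class StubGoalFunction:
        def __init__(self, model):
            self.model = model

    class StubSearch:
        is_blackbox = False

    goal_function = StubGoalFunction(model="victim-model")
    search_method = StubSearch()
    if not search_method.is_blackbox:
        # Same closure pattern as the Attack hunk above.
        search_method.get_model = lambda: goal_function.model

    assert search_method.get_model() == "victim-model"
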
@@ -100,14 +100,6 @@ def load_textattack_model_from_path(model_name, model_path):
         model = textattack.models.helpers.WordCNNForClassification(
             model_path=model_path, num_labels=num_labels
         )
-    elif model_name.startswith("bert"):
-        model_path, num_labels = model_path
-        textattack.shared.logger.info(
-            f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
-        )
-        model = textattack.models.helpers.BERTForClassification(
-            model_path=model_path, num_labels=num_labels
-        )
     elif model_name.startswith("t5"):
         model = textattack.models.helpers.T5ForTextToText(model_path)
     else: