1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Add gradient-based white-box search

This commit is contained in:
Jin Yong Yoo
2020-10-05 18:46:47 -04:00
parent 6273b19c19
commit bdbeae80d2
7 changed files with 71 additions and 8 deletions

View File

@@ -49,5 +49,9 @@ class BeamSearch(SearchMethod):
beam = [potential_next_beam[i] for i in best_indices]
return best_result
@property
def is_blackbox(self):
    """Beam search only needs the goal function's scores on candidate
    texts — it never inspects the victim model's internals — so it is a
    black-box search method."""
    return True
def extra_repr_keys(self):
    """Attribute names to include in this search method's ``__repr__``."""
    keys = ["beam_width"]
    return keys

View File

@@ -285,6 +285,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_blackbox(self):
    """Genetic search queries only the goal function's output scores, not
    the victim model's internal state, so it is a black-box method."""
    return True
def extra_repr_keys(self):
return [
"pop_size",

View File

@@ -28,6 +28,7 @@ class GreedyWordSwapWIR(SearchMethod):
Args:
wir_method: method for ranking most important words
model_wrapper: model wrapper used for gradient-based ranking
"""
def __init__(self, wir_method="unk"):
@@ -44,6 +45,7 @@ class GreedyWordSwapWIR(SearchMethod):
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
@@ -74,12 +76,47 @@ class GreedyWordSwapWIR(SearchMethod):
delta_ps.append(max_score_change)
index_scores = softmax_saliency_scores * np.array(delta_ps)
elif self.wir_method == "delete":
leave_one_texts = [
initial_text.delete_word_at_index(i) for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "gradient":
victim_model = self.get_model()
index_scores = np.zeros(initial_text.num_words)
grad_output = victim_model.get_grad(initial_text.tokenizer_input)
gradient = grad_output["gradient"]
j = 0
last_matched = 0
for i, word in enumerate(initial_text.words):
word = initial_text.words[i].lower()
matched_tokens = []
a = []
while j < len(grad_output["tokens"]) and len(word) > 0:
token = grad_output["tokens"][j].lower()
# remove "##" if it's a subword
token = token.replace("##", "")
idx = word.find(token)
if idx == 0:
word = word[idx + len(token) :]
matched_tokens.append(j)
a.append(token)
last_matched = j
j += 1
if not matched_tokens:
# Reset j to most recent match
j = last_matched
index_scores[i] = 0.0
else:
agg_grad = np.mean(gradient[matched_tokens], axis=0)
index_scores[i] = np.linalg.norm(agg_grad, ord=1)
search_over = False
elif self.wir_method == "random":
index_order = np.arange(len_text)
np.random.shuffle(index_order)
@@ -146,5 +183,12 @@ class GreedyWordSwapWIR(SearchMethod):
limited to word swap and deletion transformations."""
return transformation_consists_of_word_swaps_and_deletions(transformation)
@property
def is_blackbox(self):
    """Whether this search only needs model outputs.

    Every word-importance-ranking method except ``"gradient"`` works from
    goal-function scores alone; the gradient ranking reads the victim
    model's gradients and is therefore white-box.
    """
    return self.wir_method != "gradient"
def extra_repr_keys(self):
    """Attribute names to include in this search method's ``__repr__``."""
    keys = ["wir_method"]
    return keys

View File

@@ -45,6 +45,7 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
self.pop_size = pop_size
self.post_turn_check = post_turn_check
self.max_turn_retries = 20
self.is_blackbox = True
self._search_over = False
self.omega_1 = 0.8
@@ -329,6 +330,10 @@ class ParticleSwarmOptimization(PopulationBasedSearch):
substitutions."""
return transformation_consists_of_word_swaps(transformation)
@property
def is_blackbox(self):
    """PSO queries only the goal function's output scores, so it is a
    black-box search method."""
    # NOTE(review): __init__ in this same commit also assigns
    # ``self.is_blackbox = True``. With this setter-less property on the
    # class, that instance-attribute assignment raises AttributeError at
    # construction time — the line in __init__ should be removed.
    return True
def extra_repr_keys(self):
    """Attribute names to include in this search method's ``__repr__``."""
    keys = ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]
    return keys

View File

@@ -32,6 +32,12 @@ class SearchMethod(ABC):
raise AttributeError(
"Search Method must have access to filter_transformations method"
)
if not self.is_blackbox and not hasattr(self, "get_model"):
raise AttributeError(
"Search Method must have access to get_model method if it is a white-box method"
)
return self._perform_search(initial_result)
@abstractmethod
@@ -48,6 +54,12 @@ class SearchMethod(ABC):
``transformation``."""
return True
@property
def is_blackbox(self):
    """``True`` when the search needs only the victim model's outputs
    (black-box); ``False`` when it reads internal state such as gradients
    (white-box). Every concrete ``SearchMethod`` must override this.
    """
    raise NotImplementedError()
def extra_repr_keys(self):
    """Attribute names to include in ``__repr__``; the base class has none."""
    return list()

View File

@@ -114,6 +114,8 @@ class Attack:
)
)
self.search_method.filter_transformations = self.filter_transformations
if not search_method.is_blackbox:
self.search_method.get_model = lambda: self.goal_function.model
def clear_cache(self, recursive=True):
self.constraints_cache.clear()

View File

@@ -100,14 +100,6 @@ def load_textattack_model_from_path(model_name, model_path):
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("bert"):
model_path, num_labels = model_path
textattack.shared.logger.info(
f"Loading pre-trained TextAttack BERT model: {colored_model_name}"
)
model = textattack.models.helpers.BERTForClassification(
model_path=model_path, num_labels=num_labels
)
elif model_name.startswith("t5"):
model = textattack.models.helpers.T5ForTextToText(model_path)
else: