mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Merge pull request #221 from QData/greedy-wir-cleanup
Clean up greedy-wir
This commit is contained in:
@@ -25,5 +25,5 @@ def PWWSRen2019(model):
|
||||
constraints = [RepeatModification(), StopwordModification()]
|
||||
goal_function = UntargetedClassification(model)
|
||||
# search over words based on a combination of their saliency score, and how efficient the WordSwap transform is
|
||||
search_method = GreedyWordSwapWIR("pwws")
|
||||
search_method = GreedyWordSwapWIR("weighted-saliency")
|
||||
return Attack(goal_function, constraints, transformation, search_method)
|
||||
|
||||
@@ -28,35 +28,24 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
def __init__(self, wir_method="unk"):
|
||||
self.wir_method = wir_method
|
||||
|
||||
def _get_index_order(self, initial_result, texts):
|
||||
"""Queries model for list of attacked text objects ``text`` and ranks
|
||||
in order of descending score."""
|
||||
leave_one_results, search_over = self.get_goal_results(texts)
|
||||
leave_one_scores = np.array([result.score for result in leave_one_results])
|
||||
return leave_one_scores, search_over
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
attacked_text = initial_result.attacked_text
|
||||
cur_result = initial_result
|
||||
|
||||
# Sort words by order of importance
|
||||
len_text = len(attacked_text.words)
|
||||
def _get_index_order(self, initial_text):
|
||||
"""Returns word indices of ``initial_text`` in descending order of
|
||||
importance."""
|
||||
len_text = len(initial_text.words)
|
||||
|
||||
if self.wir_method == "unk":
|
||||
leave_one_texts = [
|
||||
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
|
||||
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
|
||||
]
|
||||
leave_one_scores, search_over = self._get_index_order(
|
||||
initial_result, leave_one_texts
|
||||
)
|
||||
elif self.wir_method == "pwws":
|
||||
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
|
||||
index_scores = np.array([result.score for result in leave_one_results])
|
||||
elif self.wir_method == "weighted-saliency":
|
||||
# first, compute word saliency
|
||||
leave_one_texts = [
|
||||
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
|
||||
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
|
||||
]
|
||||
saliency_scores, search_over = self._get_index_order(
|
||||
initial_result, leave_one_texts
|
||||
)
|
||||
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
|
||||
saliency_scores = np.array([result.score for result in leave_one_results])
|
||||
|
||||
softmax_saliency_scores = softmax(
|
||||
torch.Tensor(saliency_scores), dim=0
|
||||
@@ -66,9 +55,7 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
delta_ps = []
|
||||
for idx in range(len_text):
|
||||
transformed_text_candidates = self.get_transformations(
|
||||
cur_result.attacked_text,
|
||||
original_text=initial_result.attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
initial_text, original_text=initial_text, indices_to_modify=[idx],
|
||||
)
|
||||
if not transformed_text_candidates:
|
||||
# no valid synonym substitutions for this word
|
||||
@@ -79,26 +66,33 @@ class GreedyWordSwapWIR(SearchMethod):
|
||||
max_score_change = np.max(score_change)
|
||||
delta_ps.append(max_score_change)
|
||||
|
||||
leave_one_scores = softmax_saliency_scores * np.array(delta_ps)
|
||||
|
||||
index_scores = softmax_saliency_scores * np.array(delta_ps)
|
||||
elif self.wir_method == "delete":
|
||||
leave_one_texts = [
|
||||
attacked_text.delete_word_at_index(i) for i in range(len_text)
|
||||
initial_text.delete_word_at_index(i) for i in range(len_text)
|
||||
]
|
||||
leave_one_scores, search_over = self._get_index_order(
|
||||
initial_result, leave_one_texts
|
||||
)
|
||||
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
|
||||
index_scores = np.array([result.score for result in leave_one_results])
|
||||
elif self.wir_method == "random":
|
||||
index_order = np.arange(len_text)
|
||||
np.random.shuffle(index_order)
|
||||
search_over = False
|
||||
else:
|
||||
raise ValueError(f"Unsupport WIR method {self.wir_method}")
|
||||
raise ValueError(f"Unsupported WIR method {self.wir_method}")
|
||||
|
||||
if self.wir_method != "random":
|
||||
index_order = (-leave_one_scores).argsort()
|
||||
index_order = (-index_scores).argsort()
|
||||
|
||||
return index_order, search_over
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
attacked_text = initial_result.attacked_text
|
||||
|
||||
# Sort words by order of importance
|
||||
index_order, search_over = self._get_index_order(attacked_text)
|
||||
|
||||
i = 0
|
||||
cur_result = initial_result
|
||||
results = None
|
||||
while i < len(index_order) and not search_over:
|
||||
transformed_text_candidates = self.get_transformations(
|
||||
|
||||
Reference in New Issue
Block a user