1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #221 from QData/greedy-wir-cleanup

Clean up greedy-wir
This commit is contained in:
Jack Morris
2020-07-22 13:59:30 -04:00
committed by GitHub
3 changed files with 28 additions and 35 deletions

View File

@@ -25,5 +25,5 @@ def PWWSRen2019(model):
constraints = [RepeatModification(), StopwordModification()]
goal_function = UntargetedClassification(model)
# search over words based on a combination of their saliency score, and how efficient the WordSwap transform is
search_method = GreedyWordSwapWIR("pwws")
search_method = GreedyWordSwapWIR("weighted-saliency")
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -28,35 +28,24 @@ class GreedyWordSwapWIR(SearchMethod):
def __init__(self, wir_method="unk"):
self.wir_method = wir_method
def _get_index_order(self, initial_result, texts):
"""Queries model for list of attacked text objects ``text`` and ranks
in order of descending score."""
leave_one_results, search_over = self.get_goal_results(texts)
leave_one_scores = np.array([result.score for result in leave_one_results])
return leave_one_scores, search_over
def _perform_search(self, initial_result):
attacked_text = initial_result.attacked_text
cur_result = initial_result
# Sort words by order of importance
len_text = len(attacked_text.words)
def _get_index_order(self, initial_text):
"""Returns word indices of ``initial_text`` in descending order of
importance."""
len_text = len(initial_text.words)
if self.wir_method == "unk":
leave_one_texts = [
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
]
leave_one_scores, search_over = self._get_index_order(
initial_result, leave_one_texts
)
elif self.wir_method == "pwws":
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
attacked_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
]
saliency_scores, search_over = self._get_index_order(
initial_result, leave_one_texts
)
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
saliency_scores = np.array([result.score for result in leave_one_results])
softmax_saliency_scores = softmax(
torch.Tensor(saliency_scores), dim=0
@@ -66,9 +55,7 @@ class GreedyWordSwapWIR(SearchMethod):
delta_ps = []
for idx in range(len_text):
transformed_text_candidates = self.get_transformations(
cur_result.attacked_text,
original_text=initial_result.attacked_text,
indices_to_modify=[idx],
initial_text, original_text=initial_text, indices_to_modify=[idx],
)
if not transformed_text_candidates:
# no valid synonym substitutions for this word
@@ -79,26 +66,33 @@ class GreedyWordSwapWIR(SearchMethod):
max_score_change = np.max(score_change)
delta_ps.append(max_score_change)
leave_one_scores = softmax_saliency_scores * np.array(delta_ps)
index_scores = softmax_saliency_scores * np.array(delta_ps)
elif self.wir_method == "delete":
leave_one_texts = [
attacked_text.delete_word_at_index(i) for i in range(len_text)
initial_text.delete_word_at_index(i) for i in range(len_text)
]
leave_one_scores, search_over = self._get_index_order(
initial_result, leave_one_texts
)
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
elif self.wir_method == "random":
index_order = np.arange(len_text)
np.random.shuffle(index_order)
search_over = False
else:
raise ValueError(f"Unsupport WIR method {self.wir_method}")
raise ValueError(f"Unsupported WIR method {self.wir_method}")
if self.wir_method != "random":
index_order = (-leave_one_scores).argsort()
index_order = (-index_scores).argsort()
return index_order, search_over
def _perform_search(self, initial_result):
attacked_text = initial_result.attacked_text
# Sort words by order of importance
index_order, search_over = self._get_index_order(attacked_text)
i = 0
cur_result = initial_result
results = None
while i < len(index_order) and not search_over:
transformed_text_candidates = self.get_transformations(