1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/search_methods/greedy_word_swap_wir.py
2020-06-23 23:33:48 -04:00

122 lines
4.7 KiB
Python

"""
When the word-importance-ranking (WIR) method is set to ``unk``, this is a
reimplementation of the search method from the paper:
Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and
Entailment by Jin et al., 2019.
See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler.
"""
import numpy as np
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps
class GreedyWordSwapWIR(SearchMethod):
    """
    An attack that greedily chooses from a list of possible perturbations in
    order of index, after ranking indices by importance.

    Args:
        wir_method: method for ranking most important words. Supported values:
            ``"unk"`` (score each word by replacing it with ``"[UNK]"``),
            ``"delete"`` (score each word by deleting it) and ``"random"``
            (visit words in a random order).
        ascending: if True, ranks words from least-to-most important. (Default
            ranking shows the most important word first.)
    """

    def __init__(self, wir_method="unk", ascending=False):
        self.wir_method = wir_method
        self.ascending = ascending

    def _get_index_order(self, initial_result, texts):
        """Query the goal function once for every perturbed text in ``texts``.

        Despite its name, this method does not sort anything. It returns the
        raw goal-function scores, one per entry of ``texts`` and in the same
        order, and leaves the ranking to the caller.

        Args:
            initial_result: not used by this implementation.
            texts: list of attacked-text objects to score.

        Returns:
            A ``(scores, search_over)`` tuple. ``scores`` is a numpy array of
            each result's ``score``. ``search_over`` is the flag returned by
            ``get_goal_results``.
        """
        leave_one_results, search_over = self.get_goal_results(texts)
        leave_one_scores = np.array([result.score for result in leave_one_results])
        return leave_one_scores, search_over

    def _get_importance_scores(self, initial_result):
        """Compute one importance score per word of the initial text.

        Returns:
            A ``(scores, search_over)`` tuple. ``scores`` is a numpy array with
            one entry per word index. ``search_over`` is the flag returned by
            the goal-function queries; it is always ``False`` for
            ``"random"``, which makes no queries.

        Raises:
            ValueError: if ``self.wir_method`` is not a supported method.
        """
        attacked_text = initial_result.attacked_text
        len_text = len(attacked_text.words)

        if self.wir_method == "unk":
            leave_one_texts = [
                attacked_text.replace_word_at_index(i, "[UNK]")
                for i in range(len_text)
            ]
            return self._get_index_order(initial_result, leave_one_texts)

        if self.wir_method == "delete":
            leave_one_texts = [
                attacked_text.delete_word_at_index(i) for i in range(len_text)
            ]
            # _get_index_order returns a (scores, search_over) tuple; it must
            # be passed through intact rather than bound to a single name.
            return self._get_index_order(initial_result, leave_one_texts)

        if self.wir_method == "random":
            # argsort of i.i.d. uniform scores yields a random permutation
            # of the word indices, whichever direction is used.
            return np.random.rand(len_text), False

        raise ValueError(
            f"Unsupported WIR method {self.wir_method!r}; "
            "expected one of 'unk', 'delete', 'random'."
        )

    def _most_similar_success(self, results, default):
        """Pick the successful result with the highest similarity score.

        ``results`` is expected to be sorted by descending score. The scan
        stops at the first result that did not succeed, or at the first
        candidate without a ``similarity_score`` attack attribute. The latter
        happens when the attack runs without similarity metrics; in that case
        the current best (initially ``default``) is returned.
        """
        best_result = default
        max_similarity = -float("inf")
        for result in results:
            if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
                break
            try:
                similarity_score = result.attacked_text.attack_attrs[
                    "similarity_score"
                ]
            except KeyError:
                break
            if similarity_score > max_similarity:
                max_similarity = similarity_score
                best_result = result
        return best_result

    def _perform_search(self, initial_result):
        """Greedily perturb words in order of importance.

        Visits word indices in the order given by the importance ranking. At
        each index it tries all transformations and keeps the best-scoring
        candidate, but only if that candidate improves on the current score.
        Once a kept candidate reaches ``SUCCEEDED``, the most similar
        successful candidate from that step is returned. The search also
        stops when the goal function reports ``search_over``.
        """
        cur_result = initial_result
        leave_one_scores, search_over = self._get_importance_scores(initial_result)

        if self.ascending:
            index_order = leave_one_scores.argsort()
        else:
            index_order = (-leave_one_scores).argsort()

        for index in index_order:
            if search_over:
                break
            transformed_text_candidates = self.get_transformations(
                cur_result.attacked_text,
                original_text=initial_result.attacked_text,
                indices_to_modify=[index],
            )
            if len(transformed_text_candidates) == 0:
                continue
            results, search_over = self.get_goal_results(
                transformed_text_candidates
            )
            # Defensive: never index into an empty result list.
            if not results:
                continue
            results = sorted(results, key=lambda x: -x.score)
            # Skip swaps which don't improve the score.
            if results[0].score <= cur_result.score:
                continue
            cur_result = results[0]
            if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
                return self._most_similar_success(results, cur_result)
        return cur_result

    def check_transformation_compatibility(self, transformation):
        """
        Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swaps transformations.
        """
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        return ["wir_method"]