1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/attack_results/attack_result.py
2020-07-03 16:01:19 -04:00

133 lines
5.1 KiB
Python

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils
class AttackResult:
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (GoalFunctionResult): Result of the goal function
            applied to the original text.
        perturbed_result (GoalFunctionResult): Result of the goal function
            applied to the perturbed text. May or may not have been
            successful.

    Raises:
        ValueError: If either result is ``None``.
        TypeError: If either result is not a ``GoalFunctionResult``.
    """

    def __init__(self, original_result, perturbed_result):
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        elif not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        elif not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Return the printable text of ``self.original_result``.

        Helper method; ``color_method`` is forwarded as ``key_color_method``.
        """
        return self.original_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Return the printable text of ``self.perturbed_result``.

        Helper method; ``color_method`` is forwarded as ``key_color_method``.
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """Return the list of lines that make up this result's string form.

        The first entry summarises the goal function outputs; the remaining
        entries are the (optionally colored) original and perturbed texts.
        """
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Return a string illustrating the results of the goal function,
        formatted as ``<original output> --> <perturbed output>``.
        """
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    def diff_color(self, color_method=None):
        """Highlight the difference between the two texts using color.

        Args:
            color_method: Passed through to ``utils.color_text`` and
                ``printable_text``. When ``None`` no highlighting is done.

        Returns:
            A 2-tuple ``(original_text, perturbed_text)`` of printable
            strings, with deleted, inserted and swapped words colored when
            ``color_method`` is given.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        if color_method is None:
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        # Presumably maps each original word index to its index in the
        # perturbed text, with -1 marking a deleted word -- verify against
        # AttackedText. It is not modified here, so look it up once.
        index_map = t2.attack_attrs["original_index_map"]
        map_len = len(index_map)

        words_1 = []
        words_1_idxs = []
        words_2 = []
        words_2_idxs = []
        i1 = 0
        i2 = 0

        # NOTE(review): once i1 runs past the index map, neither inner loop
        # fires, so words inserted after the last original word are not
        # highlighted -- confirm whether that is intended.
        while i1 < t1.num_words or i2 < t2.num_words:
            # show deletions
            while i1 < map_len and index_map[i1] == -1:
                words_1.append(utils.color_text(t1.words[i1], color_1, color_method))
                words_1_idxs.append(i1)
                i1 += 1
            # show insertions
            while i1 < map_len and i2 < index_map[i1]:
                # Inserted words exist only in the perturbed text, so they
                # must be read from t2 (indexing t1 with i2 picked the wrong
                # word and could run past the end of t1).
                words_2.append(utils.color_text(t2.words[i2], color_2, color_method))
                words_2_idxs.append(i2)
                i2 += 1
            # show swaps
            if i1 < t1.num_words and i2 < t2.num_words:
                word_1 = t1.words[i1]
                word_2 = t2.words[i2]
                if word_1 != word_2:
                    words_1.append(utils.color_text(word_1, color_1, color_method))
                    words_2.append(utils.color_text(word_2, color_2, color_method))
                    words_1_idxs.append(i1)
                    words_2_idxs.append(i2)
            # Always advance both cursors so the outer loop terminates.
            i1 += 1
            i2 += 1

        t1 = self.original_result.attacked_text.replace_words_at_indices(
            words_1_idxs, words_1
        )
        t2 = self.perturbed_result.attacked_text.replace_words_at_indices(
            words_2_idxs, words_2
        )

        key_color = ("bold", "underline")
        return (
            t1.printable_text(key_color=key_color, key_color_method=color_method),
            t2.printable_text(key_color=key_color, key_color_method=color_method),
        )