mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
"""
classification_goal_function_result
====================================
"""
|
|
|
|
import torch
|
|
|
|
import textattack
|
|
from textattack.shared import utils
|
|
|
|
from .goal_function_result import GoalFunctionResult
|
|
|
|
|
|
class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Represents the result of a classification goal function."""

    def _predicted_label(self):
        """Return the index of the highest-scoring entry of ``raw_output`` as an ``int``.

        Every display helper in this class goes through this method, so the
        label name, the color and the confidence percentage always refer to
        the same class.
        """
        # int() turns a 0-d torch tensor (or numpy integer) into a plain index.
        return int(self.raw_output.argmax())

    @property
    def _processed_output(self):
        """Takes a model output (like `1`) and returns the class labeled output
        (like `positive`), if possible.

        Also returns the associated color.

        Returns:
            A ``(output, color)`` tuple. If
            ``attacked_text.attack_attrs["label_names"]`` is set and non-empty,
            ``output`` is that list's entry for the predicted class, passed
            through ``process_label_name``, and ``color`` comes from
            ``color_from_output``. Otherwise ``output`` is the predicted class
            index as an ``int`` and ``color`` comes from ``color_from_label``.
        """
        output_label = self._predicted_label()
        label_names = self.attacked_text.attack_attrs.get("label_names")
        if label_names:
            # Look the name up with the same index used for the color (and the
            # confidence in get_colored_output), so the three always agree.
            output = label_names[output_label]
            output = textattack.shared.utils.process_label_name(output)
            color = textattack.shared.utils.color_from_output(output, output_label)
            return output, color
        color = textattack.shared.utils.color_from_label(output_label)
        return output_label, color

    def get_text_color_input(self):
        """A string representing the color this result's changed portion should
        be if it represents the original input."""
        _, color = self._processed_output
        return color

    def get_text_color_perturbed(self):
        """A string representing the color this result's changed portion should
        be if it represents the perturbed input."""
        _, color = self._processed_output
        return color

    def get_colored_output(self, color_method=None):
        """Returns a string representation of this result's output, colored
        according to `color_method`.

        The text has the form ``"<output> (<score>%)"``. ``<output>`` is the
        value from ``_processed_output``. ``<score>`` is the ``raw_output``
        entry of the predicted class, formatted with ``.0%`` (so it is
        presumably a probability in [0, 1] — confirm against the model
        wrapper). ``color_method`` is forwarded unchanged to
        ``utils.color_text``.
        """
        output_label = self._predicted_label()
        confidence_score = self.raw_output[output_label]
        if isinstance(confidence_score, torch.Tensor):
            confidence_score = confidence_score.item()
        output, color = self._processed_output
        # concatenate with label and convert confidence score to percent, like '33%'
        output_str = f"{output} ({confidence_score:.0%})"
        return utils.color_text(output_str, color=color, method=color_method)
|