1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
Files
textattack-nlp-transformer/textattack/goal_functions/classification/targeted_classification.py

28 lines
1.0 KiB
Python

from .classification_goal_function import ClassificationGoalFunction
class TargetedClassification(ClassificationGoalFunction):
    """
    A targeted attack on classification models: the goal is reached once the
    model predicts ``target_class``, and the score being maximized is the
    model's output at the ``target_class`` index.
    """

    def __init__(self, model, target_class=0):
        """Forward ``model`` to the base class and remember the target label.

        Args:
            model: Passed unchanged to the parent constructor.
            target_class: Index of the label the attack tries to reach
                (defaults to 0).
        """
        super().__init__(model)
        self.target_class = target_class

    def _is_goal_complete(self, model_output, correct_output):
        """Report whether the target label has been reached.

        The result is truthy when the argmax of ``model_output`` equals
        ``target_class``, or when ``correct_output`` already equals
        ``target_class``.
        """
        predicted_label = model_output.argmax()
        prediction_hits_target = self.target_class == predicted_label
        # Kept as one short-circuit expression so the returned object is
        # exactly the first truthy operand (or the second comparison).
        return prediction_hits_target or correct_output == self.target_class

    def _get_score(self, model_output, _):
        """Return the entry of ``model_output`` at the target class index.

        The second positional argument is ignored.

        Raises:
            ValueError: If ``target_class`` is negative or not smaller than
                ``len(model_output)``.
        """
        num_classes = len(model_output)
        # Guard clause: reject an index that falls outside the output.
        if self.target_class < 0 or self.target_class >= num_classes:
            raise ValueError(f'target class set to {self.target_class} with {len(model_output)} classes.')
        return model_output[self.target_class]

    def _get_displayed_output(self, raw_output):
        """Return the argmax index of ``raw_output`` as a plain ``int``."""
        top_index = raw_output.argmax()
        return int(top_index)

    def extra_repr_keys(self):
        """Return the attribute names reported as extra repr keys."""
        keys = ['target_class']
        return keys