mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
import torch
|
|
|
|
from textattack.goal_function_results import ClassificationGoalFunctionResult
|
|
from textattack.goal_functions import GoalFunction
|
|
|
|
|
|
class ClassificationGoalFunction(GoalFunction):
    """A goal function defined on a model that outputs a probability for some
    number of classes.

    Args:
        model: The PyTorch or TensorFlow model used for evaluation.
    """

    def _process_model_outputs(self, inputs, scores):
        """Processes and validates a list of model outputs.

        This is a task-dependent operation. For example, classification
        outputs need to have a softmax applied.

        Args:
            inputs: The sized collection of inputs the model was called on;
                only its length is used here.
            scores: The model output. It is used through the tensor API
                (``ndim``, ``shape``, ``unsqueeze``, ``sum(dim=1)``,
                ``cpu``), so it is expected to be a ``torch.Tensor``.

        Returns:
            The validated scores, moved to the CPU. A 2-D input whose rows
            do not already sum to 1 is passed through a softmax over
            ``dim=1`` first.

        Raises:
            ValueError: If ``scores`` is 1-D but there is more than one
                input, if it has more than two dimensions, if its first
                dimension does not match ``len(inputs)``, or if the rows
                still do not sum to 1 after the softmax.
        """
        # Validation check on model score dimensions
        if scores.ndim == 1:
            # Unsqueeze prediction, if it's been squeezed by the model.
            # The length of `inputs` is what must equal 1; the previous
            # `len(inputs == 1)` took the length of a comparison result.
            # NOTE(review): a row unsqueezed here skips the sum-to-1
            # normalization below, because that check lives in the same
            # if/elif chain -- confirm this is intended.
            if len(inputs) == 1:
                scores = scores.unsqueeze(dim=0)
            else:
                raise ValueError(
                    f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
                )
        elif scores.ndim != 2:
            # If model somehow returns too many dimensions, throw an error.
            raise ValueError(
                f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
            )
        elif scores.shape[0] != len(inputs):
            # If model returns an incorrect number of scores, throw an error.
            raise ValueError(
                f"Model return score of shape {scores.shape} for {len(inputs)} inputs."
            )
        elif not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
            # Values in each row should sum up to 1. The model should return a
            # set of numbers corresponding to probabilities, which should add
            # up to 1. Since they are `torch.float` values, allow a small
            # error in the summation.
            scores = torch.nn.functional.softmax(scores, dim=1)
            if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
                raise ValueError("Model scores do not add up to 1.")
        return scores.cpu()

    def _goal_function_result_type(self):
        """Returns the class of this goal function's results."""
        return ClassificationGoalFunctionResult

    def extra_repr_keys(self):
        """Returns an empty list: this class contributes no extra repr keys."""
        return []

    def _get_displayed_output(self, raw_output):
        """Returns the index of the largest entry of ``raw_output`` as an int."""
        return int(raw_output.argmax())
|