import torch

from textattack.goal_functions import GoalFunction


class ClassificationGoalFunction(GoalFunction):
    """
    A goal function defined on a model that outputs a probability for some
    number of classes.

    model: The PyTorch or TensorFlow model used for evaluation.
    """

    def _process_model_outputs(self, inputs, outputs):
        """
        Processes and validates a list of model outputs.

        This is a task-dependent operation. For example, classification
        outputs need to have a softmax applied.

        Args:
            inputs: Sized collection of the examples that were fed to the
                model; only its length is used here.
            outputs: Sequence of tensors (one per batch) that are
                concatenated along dimension 0.

        Returns:
            A 2-D tensor on the CPU with one row per input, where each row
            sums to 1 (within a tolerance of 1e-6).

        Raises:
            ValueError: If the concatenated scores do not have exactly one
                row per input, or if the rows still do not sum to 1 after a
                softmax has been applied.
        """
        scores = torch.cat(outputs, dim=0)
        num_inputs = len(inputs)
        # Built before any reshaping so the message reports the shape the
        # model actually returned.
        shape_error = (
            f'Model return score of shape {scores.shape} for {len(inputs)} inputs.'
        )

        # Validation check on model score dimensions.
        if scores.ndim == 1:
            # Unsqueeze prediction, if it's been squeezed by the model. A 1-D
            # result is only unambiguous when there is a single input.
            if num_inputs != 1:
                raise ValueError(shape_error)
            scores = scores.unsqueeze(dim=0)
        elif scores.ndim != 2:
            # If model somehow returns too many dimensions, throw an error.
            raise ValueError(shape_error)
        elif scores.shape[0] != num_inputs:
            # If model returns an incorrect number of scores, throw an error.
            raise ValueError(shape_error)

        # Values in each row should sum up to 1. The model should return a
        # set of numbers corresponding to probabilities, which should add
        # up to 1. Since they are `torch.float` values, allow a small
        # error in the summation. This check runs for every accepted shape,
        # including a freshly unsqueezed single prediction, so raw logits
        # are never returned as if they were probabilities.
        if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
            scores = torch.nn.functional.softmax(scores, dim=1)
            if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
                raise ValueError('Model scores do not add up to 1.')

        return scores.cpu()