mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
allow maximization goal functions
This commit is contained in:
@@ -4,26 +4,29 @@ import lru
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results.goal_function_result import GoalFunctionResultStatus
|
||||
from textattack.shared import utils, validators
|
||||
from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||
|
||||
|
||||
class GoalFunction:
|
||||
"""
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
maximizable: Whether the goal function is maximizable, as opposed to a boolean result
|
||||
of success or failure.
|
||||
query_budget: The maximum number of model queries allowed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||
self, model, maximizable=False, tokenizer=None, use_cache=True, query_budget=float("inf")
|
||||
):
|
||||
validators.validate_model_goal_function_compatibility(
|
||||
self.__class__, model.__class__
|
||||
)
|
||||
self.model = model
|
||||
self.maximizable = maximizable
|
||||
self.tokenizer = tokenizer
|
||||
if not self.tokenizer:
|
||||
if hasattr(self.model, "tokenizer"):
|
||||
@@ -33,20 +36,16 @@ class GoalFunction:
|
||||
if not hasattr(self.tokenizer, "encode"):
|
||||
raise TypeError("Tokenizer must contain `encode()` method")
|
||||
self.use_cache = use_cache
|
||||
self.num_queries = 0
|
||||
self.query_budget = query_budget
|
||||
if self.use_cache:
|
||||
self._call_model_cache = lru.LRU(utils.config("MODEL_CACHE_SIZE"))
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``\,
|
||||
due to misprediction by the model.
|
||||
"""
|
||||
model_outputs = self._call_model([attacked_text])
|
||||
return self._is_goal_complete(model_outputs[0], ground_truth_output)
|
||||
def init_attack_example(self, attacked_text, ground_truth_output):
|
||||
self.initial_attacked_text = attacked_text
|
||||
self.ground_truth_output = ground_truth_output
|
||||
self.num_queries = 0
|
||||
|
||||
def get_output(self, attacked_text):
|
||||
"""
|
||||
@@ -54,16 +53,16 @@ class GoalFunction:
|
||||
"""
|
||||
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
||||
|
||||
def get_result(self, attacked_text, ground_truth_output):
|
||||
def get_result(self, attacked_text):
|
||||
"""
|
||||
A helper method that queries `self.get_results` with a single
|
||||
``AttackedText`` object.
|
||||
"""
|
||||
results, search_over = self.get_results([attacked_text], ground_truth_output)
|
||||
results, search_over = self.get_results([attacked_text])
|
||||
result = results[0] if len(results) else None
|
||||
return result, search_over
|
||||
|
||||
def get_results(self, attacked_text_list, ground_truth_output):
|
||||
def get_results(self, attacked_text_list):
|
||||
"""
|
||||
For each attacked_text object in attacked_text_list, returns a result
|
||||
consisting of whether or not the goal has been achieved, the output for
|
||||
@@ -78,23 +77,32 @@ class GoalFunction:
|
||||
model_outputs = self._call_model(attacked_text_list)
|
||||
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
||||
displayed_output = self._get_displayed_output(raw_output)
|
||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
||||
goal_function_score = self._get_score(raw_output, ground_truth_output)
|
||||
goal_status = self._get_goal_status(raw_output)
|
||||
goal_function_score = self._get_score(raw_output)
|
||||
results.append(
|
||||
self._goal_function_result_type()(
|
||||
attacked_text,
|
||||
raw_output,
|
||||
displayed_output,
|
||||
succeeded,
|
||||
goal_status,
|
||||
goal_function_score,
|
||||
self.num_queries,
|
||||
self.ground_truth_output,
|
||||
)
|
||||
)
|
||||
return results, self.num_queries == self.query_budget
|
||||
|
||||
def _is_goal_complete(self, model_output, ground_truth_output):
|
||||
def _get_goal_status(self, model_output):
|
||||
if self.maximizable:
|
||||
return GoalFunctionResultStatus.MAXIMIZING
|
||||
if self._is_goal_complete(model_output):
|
||||
return GoalFunctionResultStatus.SUCCEEDED
|
||||
return GoalFunctionResultStatus.SEARCHING
|
||||
|
||||
def _is_goal_complete(self, model_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_score(self, model_output, ground_truth_output):
|
||||
def _get_score(self, model_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_displayed_output(self, raw_output):
|
||||
|
||||
Reference in New Issue
Block a user