mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
Update AttackedText references; tokenization still needs to be updated
This commit is contained in:
@@ -10,7 +10,7 @@ from textattack.shared.utils import batch_model_predict, default_class_repr
|
||||
|
||||
class GoalFunction:
|
||||
"""
|
||||
Evaluates how well a perturbed tokenized_text object is achieving a specified goal.
|
||||
Evaluates how well a perturbed attacked_text object is achieving a specified goal.
|
||||
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
@@ -40,32 +40,32 @@ class GoalFunction:
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, tokenized_text, ground_truth_output):
|
||||
def should_skip(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
Returns whether or not the goal has already been completed for ``tokenized_text``,
|
||||
Returns whether or not the goal has already been completed for ``attacked_text``,
|
||||
due to misprediction by the model.
|
||||
"""
|
||||
model_outputs = self._call_model([tokenized_text])
|
||||
model_outputs = self._call_model([attacked_text])
|
||||
return self._is_goal_complete(model_outputs[0], ground_truth_output)
|
||||
|
||||
def get_output(self, tokenized_text):
|
||||
def get_output(self, attacked_text):
|
||||
"""
|
||||
Returns output for display based on the result of calling the model.
|
||||
"""
|
||||
return self._get_displayed_output(self._call_model([tokenized_text])[0])
|
||||
return self._get_displayed_output(self._call_model([attacked_text])[0])
|
||||
|
||||
def get_result(self, tokenized_text, ground_truth_output):
|
||||
def get_result(self, attacked_text, ground_truth_output):
|
||||
"""
|
||||
A helper method that queries `self.get_results` with a single
|
||||
``AttackedText`` object.
|
||||
"""
|
||||
results, search_over = self.get_results([tokenized_text], ground_truth_output)
|
||||
results, search_over = self.get_results([attacked_text], ground_truth_output)
|
||||
result = results[0] if len(results) else None
|
||||
return result, search_over
|
||||
|
||||
def get_results(self, tokenized_text_list, ground_truth_output):
|
||||
def get_results(self, attacked_text_list, ground_truth_output):
|
||||
"""
|
||||
For each tokenized_text object in tokenized_text_list, returns a result
|
||||
For each attacked_text object in attacked_text_list, returns a result
|
||||
consisting of whether or not the goal has been achieved, the output for
|
||||
display purposes, and a score. Additionally returns whether the search
|
||||
is over due to the query budget.
|
||||
@@ -73,16 +73,16 @@ class GoalFunction:
|
||||
results = []
|
||||
if self.query_budget < float("inf"):
|
||||
queries_left = self.query_budget - self.num_queries
|
||||
tokenized_text_list = tokenized_text_list[:queries_left]
|
||||
self.num_queries += len(tokenized_text_list)
|
||||
model_outputs = self._call_model(tokenized_text_list)
|
||||
for tokenized_text, raw_output in zip(tokenized_text_list, model_outputs):
|
||||
attacked_text_list = attacked_text_list[:queries_left]
|
||||
self.num_queries += len(attacked_text_list)
|
||||
model_outputs = self._call_model(attacked_text_list)
|
||||
for attacked_text, raw_output in zip(attacked_text_list, model_outputs):
|
||||
displayed_output = self._get_displayed_output(raw_output)
|
||||
succeeded = self._is_goal_complete(raw_output, ground_truth_output)
|
||||
goal_function_score = self._get_score(raw_output, ground_truth_output)
|
||||
results.append(
|
||||
self._goal_function_result_type()(
|
||||
tokenized_text, displayed_output, succeeded, goal_function_score
|
||||
attacked_text, displayed_output, succeeded, goal_function_score
|
||||
)
|
||||
)
|
||||
return results, self.num_queries == self.query_budget
|
||||
@@ -111,31 +111,31 @@ class GoalFunction:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _call_model_uncached(self, tokenized_text_list):
|
||||
def _call_model_uncached(self, attacked_text_list):
|
||||
"""
|
||||
Queries model and returns outputs for a list of AttackedText
|
||||
objects.
|
||||
"""
|
||||
if not len(tokenized_text_list):
|
||||
if not len(attacked_text_list):
|
||||
return []
|
||||
ids = utils.batch_tokenize(self.tokenizer, tokenized_text_list)
|
||||
ids = utils.batch_tokenize(self.tokenizer, attacked_text_list)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = batch_model_predict(self.model, ids)
|
||||
|
||||
return self._process_model_outputs(tokenized_text_list, outputs)
|
||||
return self._process_model_outputs(attacked_text_list, outputs)
|
||||
|
||||
def _call_model(self, tokenized_text_list):
|
||||
def _call_model(self, attacked_text_list):
|
||||
""" Gets predictions for a list of `AttackedText` objects.
|
||||
|
||||
Gets prediction from cache if possible. If prediction is not in the
|
||||
cache, queries model and stores prediction in cache.
|
||||
"""
|
||||
if not self.use_cache:
|
||||
return self._call_model_uncached(tokenized_text_list)
|
||||
return self._call_model_uncached(attacked_text_list)
|
||||
else:
|
||||
uncached_list = []
|
||||
for text in tokenized_text_list:
|
||||
for text in attacked_text_list:
|
||||
if text in self._call_model_cache:
|
||||
# Re-write value in cache. This moves the key to the top of the
|
||||
# LRU cache and prevents the unlikely event that the text
|
||||
@@ -145,13 +145,13 @@ class GoalFunction:
|
||||
uncached_list.append(text)
|
||||
uncached_list = [
|
||||
text
|
||||
for text in tokenized_text_list
|
||||
for text in attacked_text_list
|
||||
if text not in self._call_model_cache
|
||||
]
|
||||
outputs = self._call_model_uncached(uncached_list)
|
||||
for text, output in zip(uncached_list, outputs):
|
||||
self._call_model_cache[text] = output
|
||||
all_outputs = [self._call_model_cache[text] for text in tokenized_text_list]
|
||||
all_outputs = [self._call_model_cache[text] for text in attacked_text_list]
|
||||
return all_outputs
|
||||
|
||||
def extra_repr_keys(self):
|
||||
|
||||
Reference in New Issue
Block a user