1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix model

This commit is contained in:
uvafan
2020-03-05 22:29:07 -05:00
parent bc5e4d57a7
commit 0b6da2a475
4 changed files with 10 additions and 6 deletions

View File

@@ -14,12 +14,16 @@ class GoalFunction:
"""
def __init__(self, model):
self.model = model
self.num_queries = 0
self._call_model_cache = lru.LRU(2**18)
def should_skip(self, tokenized_text, correct_output):
model_outputs = self._call_model([tokenized_text])
return self._is_goal_complete(model_outputs[0], correct_output)
def get_output(self, tokenized_text):
return self._get_displayed_output(self._call_model([tokenized_text])[0])
def get_results(self, tokenized_text_list, correct_output):
"""
For each tokenized_text object in tokenized_text_list, returns a result consisting of whether or not the goal has been achieved, the output for display purposes, and a score.