mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
PR fix num_queries and default use_cache=True
This commit is contained in:
@@ -12,7 +12,7 @@ class GoalFunction:
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
"""
|
||||
def __init__(self, model, use_cache=False):
|
||||
def __init__(self, model, use_cache=True):
|
||||
self.model = model
|
||||
self.use_cache = use_cache
|
||||
self.num_queries = 0
|
||||
@@ -113,16 +113,16 @@ class GoalFunction:
|
||||
Gets prediction from cache if possible. If prediction is not in the
|
||||
cache, queries model and stores prediction in cache.
|
||||
"""
|
||||
try:
|
||||
self.num_queries += len(tokenized_text_list)
|
||||
except AttributeError:
|
||||
# If some outside class is just using the attack for its `call_model`
|
||||
# function, then `self.num_queries` will not have been initialized.
|
||||
# In this case, just continue.
|
||||
pass
|
||||
if not self.use_cache:
|
||||
return self._call_model_uncached(tokenized_text_list)
|
||||
else:
|
||||
try:
|
||||
self.num_queries += len(tokenized_text_list)
|
||||
except AttributeError:
|
||||
# If some outside class is just using the attack for its `call_model`
|
||||
# function, then `self.num_queries` will not have been initialized.
|
||||
# In this case, just continue.
|
||||
pass
|
||||
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
|
||||
scores = self._call_model_uncached(uncached_list)
|
||||
for text, score in zip(uncached_list, scores):
|
||||
|
||||
Reference in New Issue
Block a user