1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

PR fix num_queries and default use_cache=True

This commit is contained in:
Jack Morris
2020-03-26 12:27:46 -04:00
parent c661180d30
commit 8c7d4711ae

View File

@@ -12,7 +12,7 @@ class GoalFunction:
Args:
model: The PyTorch or TensorFlow model used for evaluation.
"""
def __init__(self, model, use_cache=False):
def __init__(self, model, use_cache=True):
self.model = model
self.use_cache = use_cache
self.num_queries = 0
@@ -113,16 +113,16 @@ class GoalFunction:
Gets prediction from cache if possible. If prediction is not in the
cache, queries model and stores prediction in cache.
"""
try:
self.num_queries += len(tokenized_text_list)
except AttributeError:
# If some outside class is just using the attack for its `call_model`
# function, then `self.num_queries` will not have been initialized.
# In this case, just continue.
pass
if not self.use_cache:
return self._call_model_uncached(tokenized_text_list)
else:
try:
self.num_queries += len(tokenized_text_list)
except AttributeError:
# If some outside class is just using the attack for its `call_model`
# function, then `self.num_queries` will not have been initialized.
# In this case, just continue.
pass
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
scores = self._call_model_uncached(uncached_list)
for text, score in zip(uncached_list, scores):