Mirror of https://github.com/QData/TextAttack.git (synced 2021-10-13 00:05:06 +03:00)
cache constraint results
@@ -12,10 +12,14 @@ class GoalFunction:
     Args:
         model: The PyTorch or TensorFlow model used for evaluation.
     """
-    def __init__(self, model):
+    def __init__(self, model, use_cache=False):
         self.model = model
+        self.use_cache = use_cache
         self.num_queries = 0
-        self._call_model_cache = lru.LRU(2**18)
+        if self.use_cache:
+            self._call_model_cache = lru.LRU(2**18)
+        else:
+            self._call_model_cache = None
 
     def should_skip(self, tokenized_text, correct_output):
         model_outputs = self._call_model([tokenized_text])
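Note on the cache object: `lru.LRU` comes from the `lru-dict` package, a dict-like container with a fixed capacity that evicts the least recently used entry on overflow. A minimal sketch of the structure the new `use_cache` flag toggles (the 2**18 capacity matches the diff; the key and value are illustrative):

    import lru  # pip install lru-dict

    cache = lru.LRU(2 ** 18)        # holds at most 2**18 entries
    cache["some attacked text"] = 0.97
    print("some attacked text" in cache)  # True; the membership test _call_model relies on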
@@ -53,7 +57,11 @@ class GoalFunction:
         if not len(tokenized_text_list):
             return torch.tensor([])
         ids = [t.ids for t in tokenized_text_list]
-        ids = torch.tensor(ids).to(utils.get_device())
+        if hasattr(self.model, 'model'):
+            model_device = next(self.model.model.parameters()).device
+        else:
+            model_device = next(self.model.parameters()).device
+        ids = torch.tensor(ids).to(model_device)
         #
         # shape of `ids` is (n, m, d)
         #   - n: number of elements in `tokenized_text_list`
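The device lookup above is the standard PyTorch idiom: `next(module.parameters())` yields the first parameter tensor, and its `.device` says where the module lives, so inputs can follow the model instead of a globally configured device. The `hasattr(self.model, 'model')` branch appears to handle wrappers that expose the underlying network as an inner `.model` attribute. A standalone sketch (the model and input are stand-ins):

    import torch

    model = torch.nn.Linear(4, 2)                    # stand-in for the real model
    model_device = next(model.parameters()).device   # e.g. device('cpu')
    ids = torch.tensor([[1.0, 2.0, 3.0, 4.0]]).to(model_device)
    print(model(ids).shape)                          # torch.Size([1, 2])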
@@ -72,6 +80,8 @@ class GoalFunction:
             batch = [batch_ids[:, x, :] for x in range(num_fields)]
             with torch.no_grad():
                 preds = self.model(*batch)
+            if isinstance(preds, tuple):
+                preds = preds[0]
             scores.append(preds)
         scores = torch.cat(scores, dim=0)
         # Validation check on model score dimensions
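The new `isinstance(preds, tuple)` guard handles forward passes that return extra values alongside the logits; HuggingFace transformers, for instance, return a tuple whose first element is the logits. An illustrative stand-in:

    import torch

    def forward_with_extras(x):
        logits = x * 2
        attentions = torch.ones_like(x)
        return logits, attentions       # tuple-returning forward pass

    preds = forward_with_extras(torch.tensor([1.0, 2.0]))
    if isinstance(preds, tuple):
        preds = preds[0]                # keep only the logits
    print(preds)                        # tensor([2., 4.])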
@@ -92,7 +102,9 @@ class GoalFunction:
             # set of numbers corresponding to probabilities, which should add
             # up to 1. Since they are `torch.float` values, allow a small
             # error in the summation.
-            raise ValueError('Model scores do not add up to 1.')
+            scores = torch.nn.functional.softmax(scores, dim=1)
+            if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
+                raise ValueError('Model scores do not add up to 1.')
         return scores
 
     def _call_model(self, tokenized_text_list):
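Rather than rejecting raw logits outright, the new code renormalizes them with softmax and raises only if the tolerance check still fails. A sketch of the check in isolation (the outer condition is reconstructed from context and not shown in the diff):

    import torch

    scores = torch.tensor([[2.0, 1.0, 0.1]])              # raw logits; rows don't sum to 1
    if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
        scores = torch.nn.functional.softmax(scores, dim=1)
        if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
            raise ValueError('Model scores do not add up to 1.')
    print(scores.sum(dim=1))                              # tensor([1.])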
@@ -101,16 +113,19 @@ class GoalFunction:
         Gets prediction from cache if possible. If prediction is not in the
         cache, queries model and stores prediction in cache.
         """
-        try:
-            self.num_queries += len(tokenized_text_list)
-        except AttributeError:
-            # If some outside class is just using the attack for its `call_model`
-            # function, then `self.num_queries` will not have been initialized.
-            # In this case, just continue.
-            pass
-        uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
-        scores = self._call_model_uncached(uncached_list)
-        for text, score in zip(uncached_list, scores):
-            self._call_model_cache[text] = score.cpu()
-        final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
-        return torch.stack(final_scores)
+        if not self.use_cache:
+            return self._call_model_uncached(tokenized_text_list)
+        else:
+            try:
+                self.num_queries += len(tokenized_text_list)
+            except AttributeError:
+                # If some outside class is just using the attack for its `call_model`
+                # function, then `self.num_queries` will not have been initialized.
+                # In this case, just continue.
+                pass
+            uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
+            scores = self._call_model_uncached(uncached_list)
+            for text, score in zip(uncached_list, scores):
+                self._call_model_cache[text] = score.cpu()
+            final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
+            return torch.stack(final_scores)
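The cached branch is a cache-aside pattern: query the model only for inputs missing from the cache, keep cached scores on the CPU to spare GPU memory, then assemble the full result batch from the cache. A self-contained sketch with a plain dict standing in for the LRU (all names are illustrative):

    import torch

    cache = {}                                     # stand-in for lru.LRU(2**18)

    def call_model_uncached(texts):
        # pretend scoring: one scalar per text
        return torch.tensor([[float(len(t))] for t in texts])

    def call_model(texts):
        uncached = [t for t in texts if t not in cache]
        for text, score in zip(uncached, call_model_uncached(uncached)):
            cache[text] = score.cpu()              # cached copies live off the GPU
        return torch.stack([cache[t] for t in texts])

    print(call_model(["a", "bb"]))   # both queried, then cached
    print(call_model(["a", "bb"]))   # served entirely from the cache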