1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

merge attack_str with master

This commit is contained in:
Jack Morris
2020-03-27 12:47:25 -04:00
5 changed files with 61 additions and 18 deletions

View File

@@ -13,10 +13,14 @@ class GoalFunction:
Args:
model: The PyTorch or TensorFlow model used for evaluation.
"""
def __init__(self, model):
def __init__(self, model, use_cache=True):
self.model = model
self.use_cache = use_cache
self.num_queries = 0
self._call_model_cache = lru.LRU(2**18)
if self.use_cache:
self._call_model_cache = lru.LRU(2**18)
else:
self._call_model_cache = None
def should_skip(self, tokenized_text, correct_output):
model_outputs = self._call_model([tokenized_text])
@@ -54,7 +58,11 @@ class GoalFunction:
if not len(tokenized_text_list):
return torch.tensor([])
ids = [t.ids for t in tokenized_text_list]
ids = torch.tensor(ids).to(utils.get_device())
if hasattr(self.model, 'model'):
model_device = next(self.model.model.parameters()).device
else:
model_device = next(self.model.parameters()).device
ids = torch.tensor(ids).to(model_device)
#
# shape of `ids` is (n, m, d)
# - n: number of elements in `tokenized_text_list`
@@ -73,6 +81,8 @@ class GoalFunction:
batch = [batch_ids[:, x, :] for x in range(num_fields)]
with torch.no_grad():
preds = self.model(*batch)
if isinstance(preds, tuple):
preds = preds[0]
scores.append(preds)
scores = torch.cat(scores, dim=0)
# Validation check on model score dimensions
@@ -93,7 +103,9 @@ class GoalFunction:
# set of numbers corresponding to probabilities, which should add
# up to 1. Since they are `torch.float` values, allow a small
# error in the summation.
raise ValueError('Model scores do not add up to 1.')
scores = torch.nn.functional.softmax(scores, dim=1)
if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
raise ValueError('Model scores do not add up to 1.')
return scores
def _call_model(self, tokenized_text_list):
@@ -109,14 +121,17 @@ class GoalFunction:
# function, then `self.num_queries` will not have been initialized.
# In this case, just continue.
pass
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
scores = self._call_model_uncached(uncached_list)
for text, score in zip(uncached_list, scores):
self._call_model_cache[text] = score.cpu()
final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
return torch.stack(final_scores)
if not self.use_cache:
return self._call_model_uncached(tokenized_text_list)
else:
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
scores = self._call_model_uncached(uncached_list)
for text, score in zip(uncached_list, scores):
self._call_model_cache[text] = score.cpu()
final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
return torch.stack(final_scores)
def extra_repr_keys(self):
return []
__repr__ = __str__ = default_class_repr
__repr__ = __str__ = default_class_repr