1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge pull request #78 from QData/fix-cache-pattern

fix LRU cache bugs
This commit is contained in:
Jack Morris
2020-04-22 15:32:10 -04:00
committed by GitHub
3 changed files with 17 additions and 6 deletions

View File

@@ -11,7 +11,7 @@ sentence_transformers
spacy
torch
transformers>=2.5.1
tensorflow-gpu>=2
tensorflow
tensorflow_hub
terminaltables
tqdm

View File

@@ -21,10 +21,12 @@ class PartOfSpeech(Constraint):
def _get_pos(self, before_ctx, word, after_ctx):
context_words = before_ctx + [word] + after_ctx
context_key = ' '.join(context_words)
if context_key not in self._pos_tag_cache:
if context_key in self._pos_tag_cache:
pos_list = self._pos_tag_cache[context_key]
else:
_, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
self._pos_tag_cache[context_key] = pos_list
return self._pos_tag_cache[context_key]
return pos_list
def __call__(self, x, x_adv, original_text=None):
if not isinstance(x, TokenizedText):

View File

@@ -124,14 +124,23 @@ class GoalFunction:
if not self.use_cache:
return self._call_model_uncached(tokenized_text_list)
else:
uncached_list = []
for text in tokenized_text_list:
if text in self._call_model_cache:
# Re-write value in cache. This moves the key to the top of the
# LRU cache and prevents the unlikely event that the text
# is overwritten when we store the inputs from `uncached_list`.
self._call_model_cache[text] = self._call_model_cache[text]
else:
uncached_list.append(text)
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
scores = self._call_model_uncached(uncached_list)
for text, score in zip(uncached_list, scores):
self._call_model_cache[text] = score.cpu()
final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
return torch.stack(final_scores)
final_scores = [self._call_model_cache[text] for text in tokenized_text_list]
return torch.stack(final_scores).to(utils.get_device())
def extra_repr_keys(self):
return []
__repr__ = __str__ = default_class_repr
__repr__ = __str__ = default_class_repr