1
0
Mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00

fix LRU cache bugs

This commit is contained in:
Jack Morris
2020-04-22 15:31:48 -04:00
parent 4887b51fef
commit 6d20d44ac0
4 changed files with 21 additions and 8 deletions

View File

@@ -10,7 +10,7 @@ sentence_transformers
spacy
torch
transformers>=2.5.1
tensorflow-gpu>=2
tensorflow
tensorflow_hub
terminaltables
tqdm

View File

@@ -2,8 +2,10 @@ from textattack.attack_results import AttackResult
from textattack.shared import utils
class FailedAttackResult(AttackResult):
def __init__(self, original_text, original_output, perturbed_text=None, perturbed_output=None):
    """Record a failed attack result.

    Defaults the perturbed text/output to the originals so that a failed
    attack can still be rendered and compared like any other AttackResult.

    Args:
        original_text: the unmodified input text.
        original_output: the model's output on the original text.
        perturbed_text: the best perturbation found, if any
            (defaults to the original text).
        perturbed_output: the model's output on the perturbed text
            (defaults to the original output).
    """
    # `or` falls back to the original when the caller passed None.
    # NOTE(review): this also replaces falsy non-None values — assumed intentional.
    perturbed_text = perturbed_text or original_text
    perturbed_output = perturbed_output or original_output
    super().__init__(original_text, perturbed_text, original_output, perturbed_output)
def __data__(self, color_method=None):
data = (self.result_str(color_method), self.original_text.text)

View File

@@ -21,10 +21,12 @@ class PartOfSpeech(Constraint):
def _get_pos(self, before_ctx, word, after_ctx):
    """Part-of-speech-tag `word` in its surrounding context, with memoization.

    The full context window is tagged in one pass and the resulting tag
    tuple is cached under the joined context string, so repeated queries
    for the same window never re-run the tagger.

    Args:
        before_ctx: list of words preceding `word`.
        word: the word whose POS tag is of interest.
        after_ctx: list of words following `word`.

    Returns:
        The sequence of POS tags for the whole context window
        (before_ctx + [word] + after_ctx).
    """
    context_words = before_ctx + [word] + after_ctx
    context_key = ' '.join(context_words)
    if context_key in self._pos_tag_cache:
        # Cache hit: read the stored value once instead of re-tagging.
        pos_list = self._pos_tag_cache[context_key]
    else:
        # zip(*) splits the (word, tag) pairs into parallel tuples;
        # we keep only the tags and store them for next time.
        _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
        self._pos_tag_cache[context_key] = pos_list
    return pos_list
def __call__(self, x, x_adv, original_text=None):
if not isinstance(x, TokenizedText):

View File

@@ -124,14 +124,23 @@ class GoalFunction:
if not self.use_cache:
return self._call_model_uncached(tokenized_text_list)
else:
uncached_list = []
for text in tokenized_text_list:
if text in self._call_model_cache:
# Re-write value in cache. This moves the key to the top of the
# LRU cache and prevents the unlikely event that the text
# is overwritten when we store the inputs from `uncached_list`.
self._call_model_cache[text] = self._call_model_cache[text]
else:
uncached_list.append(text)
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
scores = self._call_model_uncached(uncached_list)
for text, score in zip(uncached_list, scores):
self._call_model_cache[text] = score.cpu()
final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
return torch.stack(final_scores)
final_scores = [self._call_model_cache[text] for text in tokenized_text_list]
return torch.stack(final_scores).to(utils.get_device())
def extra_repr_keys(self):
    """Names of extra attributes to show in this object's repr; none by default."""
    return list()
__repr__ = __str__ = default_class_repr
__repr__ = __str__ = default_class_repr