Mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00.
Commit: "fix LRU cache bugs".
This commit is contained in:
@@ -10,7 +10,7 @@ sentence_transformers
 spacy
 torch
 transformers>=2.5.1
-tensorflow-gpu>=2
+tensorflow
 tensorflow_hub
 terminaltables
 tqdm
@@ -2,8 +2,10 @@ from textattack.attack_results import AttackResult
|
||||
from textattack.shared import utils
|
||||
|
||||
class FailedAttackResult(AttackResult):
    """Result of an attack that failed to produce an adversarial example.

    A failed attack leaves the input effectively unchanged, so the
    perturbed text/output default to the original ones when not given.
    (The earlier two-argument ``__init__`` was dead code shadowed by this
    definition and has been removed.)
    """

    def __init__(self, original_text, original_output, perturbed_text=None, perturbed_output=None):
        # Fall back to the originals when no perturbation was produced.
        perturbed_text = perturbed_text or original_text
        perturbed_output = perturbed_output or original_output
        super().__init__(original_text, perturbed_text, original_output, perturbed_output)
|
||||
|
||||
def __data__(self, color_method=None):
|
||||
data = (self.result_str(color_method), self.original_text.text)
|
||||
|
||||
@@ -21,10 +21,12 @@ class PartOfSpeech(Constraint):
|
||||
def _get_pos(self, before_ctx, word, after_ctx):
    """Return the POS tags for ``word`` in its context, with memoization.

    The joined context string is the cache key. On a hit, the cached tag
    list is returned directly; on a miss, the (expensive) NLTK tagger runs
    once and the result is stored before returning. This fixes the mangled
    merge where a ``not in`` guard contradicted an inner ``in`` check and
    two return statements were duplicated.
    """
    context_words = before_ctx + [word] + after_ctx
    context_key = ' '.join(context_words)
    if context_key in self._pos_tag_cache:
        pos_list = self._pos_tag_cache[context_key]
    else:
        # zip(*...) splits [(word, tag), ...] pairs into (words, tags).
        _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
        self._pos_tag_cache[context_key] = pos_list
    return pos_list
|
||||
|
||||
def __call__(self, x, x_adv, original_text=None):
|
||||
if not isinstance(x, TokenizedText):
|
||||
|
||||
@@ -124,14 +124,23 @@ class GoalFunction:
|
||||
if not self.use_cache:
|
||||
return self._call_model_uncached(tokenized_text_list)
|
||||
else:
|
||||
uncached_list = []
|
||||
for text in tokenized_text_list:
|
||||
if text in self._call_model_cache:
|
||||
# Re-write value in cache. This moves the key to the top of the
|
||||
# LRU cache and prevents the unlikely event that the text
|
||||
# is overwritten when we store the inputs from `uncached_list`.
|
||||
self._call_model_cache[text] = self._call_model_cache[text]
|
||||
else:
|
||||
uncached_list.append(text)
|
||||
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
|
||||
scores = self._call_model_uncached(uncached_list)
|
||||
for text, score in zip(uncached_list, scores):
|
||||
self._call_model_cache[text] = score.cpu()
|
||||
final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
|
||||
return torch.stack(final_scores)
|
||||
final_scores = [self._call_model_cache[text] for text in tokenized_text_list]
|
||||
return torch.stack(final_scores).to(utils.get_device())
|
||||
|
||||
def extra_repr_keys(self):
    """Attribute names to include in the repr; subclasses may override.

    Returns an empty list by default, meaning no extra attributes are
    shown beyond the class name.
    """
    return []

# Route both repr() and str() through the shared project helper.
# (The diff paste had duplicated this assignment; one copy suffices.)
__repr__ = __str__ = default_class_repr
|
||||
|
||||
Reference in New Issue
Block a user