LRU caching for model and POS tagging
@@ -1,5 +1,6 @@
 filelock
 language_check
+lru-dict
 nltk
 numpy<1.17
 pandas
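The lru-dict package added above provides the fixed-capacity lru.LRU mapping that both caches in this commit are built on. A minimal sketch of its dict-like behavior (the capacity of 3 is arbitrary; the code in this commit uses 2**12):

import lru

cache = lru.LRU(3)            # fixed capacity; least-recently-used keys are evicted
for key in ['a', 'b', 'c', 'd']:
    cache[key] = key.upper()

print('a' in cache)           # False: 'a' was evicted when 'd' was inserted
print(len(cache))             # 3: the size never exceeds the capacity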
@@ -1,3 +1,4 @@
+import lru
 import math
 import numpy as np
 import os
@@ -39,6 +40,7 @@ class Attack:
         self.transformation = transformation
         self.constraints = constraints
         self.is_black_box = is_black_box
+        self._call_model_cache = lru.LRU(2**12)
 
     def get_transformations(self, text, original_text=None,
                             apply_constraints=True, **kwargs):
@@ -86,7 +88,7 @@ class Attack:
         """
         raise NotImplementedError()
 
-    def _call_model(self, tokenized_text_list, batch_size=8):
+    def _call_model_uncached(self, tokenized_text_list, batch_size=8):
         """
         Returns model predictions for a list of TokenizedText objects.
 
@@ -142,6 +144,14 @@ class Attack:
             # error in the summation.
             raise ValueError('Model scores do not add up to 1.')
         return scores
+
+    def _call_model(self, tokenized_text_list):
+        uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
+        scores = self._call_model_uncached(uncached_list)
+        for text, score in zip(uncached_list, scores):
+            self._call_model_cache[text] = score.cpu()
+        final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
+        return torch.stack(final_scores)
 
     def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False):
         examples = []
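The new _call_model wrapper only scores texts that miss the cache, stores each result on the CPU, and reassembles scores in the original input order. A hedged, self-contained sketch of that filter-compute-store-reassemble pattern with plain tensors; score_fn and the string keys are illustrative stand-ins for _call_model_uncached and TokenizedText, and the CPU/GPU round-trip is omitted:

import lru
import torch

cache = lru.LRU(2**12)

def score_fn(items):
    # stand-in for Attack._call_model_uncached: one score vector per input
    return torch.stack([torch.tensor([float(len(x)), 1.0]) for x in items])

def call_model_cached(items):
    # same pattern as the new Attack._call_model: score only the cache
    # misses, store their results, then rebuild scores in input order
    uncached = [x for x in items if x not in cache]
    if uncached:
        for x, score in zip(uncached, score_fn(uncached)):
            cache[x] = score
    return torch.stack([cache[x] for x in items])

# the repeated 'short' is served from the cache on the second lookup
print(call_model_cached(['short', 'a longer example', 'short']))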
@@ -1,3 +1,4 @@
+import lru
 import nltk
 
 from textattack.constraints import Constraint
@@ -10,13 +11,18 @@ class PartOfSpeech(Constraint):
     def __init__(self, tagset='universal', allow_verb_noun_swap=True):
         self.tagset = tagset
         self.allow_verb_noun_swap = allow_verb_noun_swap
+        self._pos_tag_cache = lru.LRU(2**12)
 
     def _can_replace_pos(self, pos_a, pos_b):
         return (pos_a == pos_b) or (self.allow_verb_noun_swap and set([pos_a,pos_b]) <= set(['NOUN','VERB']))
 
     def _get_pos(self, before_ctx, word, after_ctx):
-        _, pos_list = zip(*nltk.pos_tag(before_ctx + [word] + after_ctx, tagset=self.tagset))
-        return pos_list[len(before_ctx)]
+        context_words = before_ctx + [word] + after_ctx
+        context_key = ' '.join(context_words)
+        if context_key not in self._pos_tag_cache:
+            _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
+            self._pos_tag_cache[context_key] = pos_list
+        return self._pos_tag_cache[context_key]
 
     def __call__(self, x, x_adv, original_text=None):
        if not isinstance(x, TokenizedText):
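With this change, PartOfSpeech tags each context only once: the nltk tag sequence is cached under the space-joined context words, so repeated candidate swaps over the same sentence reuse it. A rough usage sketch of the same keying idea outside the class (assumes nltk plus its averaged_perceptron_tagger and universal_tagset data are installed; get_pos_tags is an illustrative name, not TextAttack API):

import lru
import nltk

pos_tag_cache = lru.LRU(2**12)

def get_pos_tags(before_ctx, word, after_ctx, tagset='universal'):
    # mirrors PartOfSpeech._get_pos above: key the cache on the joined context
    context_words = before_ctx + [word] + after_ctx
    context_key = ' '.join(context_words)
    if context_key not in pos_tag_cache:
        _, pos_list = zip(*nltk.pos_tag(context_words, tagset=tagset))
        pos_tag_cache[context_key] = pos_list
    return pos_tag_cache[context_key]

tags = get_pos_tags(['the'], 'cat', ['sat'])
# the second call hits the cache and returns the same tuple without re-tagging
print(tags, get_pos_tags(['the'], 'cat', ['sat']) is tags)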