mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

LRU caching for model and POS tagging

Jack Morris
2020-02-08 16:28:02 -05:00
parent 846f12f543
commit fdf3fa9a1e
3 changed files with 20 additions and 3 deletions

View File

@@ -1,5 +1,6 @@
 filelock
 language_check
+lru-dict
 nltk
 numpy<1.17
 pandas

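For context, the new lru-dict dependency provides a fixed-size, dict-like LRU mapping that evicts its least-recently-used entry once capacity is reached; this is the lru.LRU type used in both changed modules below. A minimal sketch of that behavior (the cache size here is arbitrary, for illustration only):

    from lru import LRU

    cache = LRU(2)       # holds at most two entries
    cache['a'] = 1
    cache['b'] = 2
    cache['c'] = 3       # evicts 'a', the least-recently-used key
    print('a' in cache)  # False
    print(cache['b'], cache['c'])  # 2 3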
View File

@@ -1,3 +1,4 @@
+import lru
 import math
 import numpy as np
 import os
@@ -39,6 +40,7 @@ class Attack:
         self.transformation = transformation
         self.constraints = constraints
         self.is_black_box = is_black_box
+        self._call_model_cache = lru.LRU(2**12)
 
     def get_transformations(self, text, original_text=None,
                             apply_constraints=True, **kwargs):
@@ -86,7 +88,7 @@
         """
         raise NotImplementedError()
 
-    def _call_model(self, tokenized_text_list, batch_size=8):
+    def _call_model_uncached(self, tokenized_text_list, batch_size=8):
         """
         Returns model predictions for a list of TokenizedText objects.
@@ -142,6 +144,14 @@
                 # error in the summation.
                 raise ValueError('Model scores do not add up to 1.')
         return scores
 
+    def _call_model(self, tokenized_text_list):
+        uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
+        scores = self._call_model_uncached(uncached_list)
+        for text, score in zip(uncached_list, scores):
+            self._call_model_cache[text] = score.cpu()
+        final_scores = [self._call_model_cache[text].to(utils.get_device()) for text in tokenized_text_list]
+        return torch.stack(final_scores)
+
     def _get_examples_from_dataset(self, dataset, num_examples=None, shuffle=False):
         examples = []

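In effect, the new _call_model is a memoizing wrapper around the renamed _call_model_uncached: only texts missing from self._call_model_cache are sent through the model, each resulting score tensor is cached on the CPU, and the full score list is reassembled in the original input order on the active device. A self-contained sketch of the same pattern, assuming a generic score_fn and device rather than TextAttack's actual model wiring:

    import torch
    from lru import LRU

    class CachedScorer:
        def __init__(self, score_fn, device, cache_size=2**12):
            self.score_fn = score_fn  # maps a list of inputs to a tensor of scores
            self.device = device
            self._cache = LRU(cache_size)

        def __call__(self, inputs):
            # Run the underlying scorer only on inputs we have not seen before.
            uncached = [x for x in inputs if x not in self._cache]
            if uncached:
                for x, score in zip(uncached, self.score_fn(uncached)):
                    self._cache[x] = score.cpu()  # keep cached tensors off the GPU
            # Rebuild the results in the original order, on the target device.
            return torch.stack([self._cache[x].to(self.device) for x in inputs])

Guarding the empty case avoids calling the scorer with no inputs when every text is already cached; inputs also have to be hashable for the cache lookups to work.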
View File

@@ -1,3 +1,4 @@
+import lru
 import nltk
 
 from textattack.constraints import Constraint
@@ -10,13 +11,18 @@ class PartOfSpeech(Constraint):
     def __init__(self, tagset='universal', allow_verb_noun_swap=True):
         self.tagset = tagset
         self.allow_verb_noun_swap = allow_verb_noun_swap
+        self._pos_tag_cache = lru.LRU(2**12)
 
     def _can_replace_pos(self, pos_a, pos_b):
         return (pos_a == pos_b) or (self.allow_verb_noun_swap and set([pos_a,pos_b]) <= set(['NOUN','VERB']))
 
     def _get_pos(self, before_ctx, word, after_ctx):
-        _, pos_list = zip(*nltk.pos_tag(before_ctx + [word] + after_ctx, tagset=self.tagset))
-        return pos_list[len(before_ctx)]
+        context_words = before_ctx + [word] + after_ctx
+        context_key = ' '.join(context_words)
+        if context_key not in self._pos_tag_cache:
+            _, pos_list = zip(*nltk.pos_tag(context_words, tagset=self.tagset))
+            self._pos_tag_cache[context_key] = pos_list
+        return self._pos_tag_cache[context_key]
 
     def __call__(self, x, x_adv, original_text=None):
         if not isinstance(x, TokenizedText):
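The same memoization idea backs the part-of-speech constraint: the tag sequence for an entire context window is cached under the space-joined words, so repeated perturbations of the same sentence reuse a single nltk.pos_tag call. A standalone sketch of that pattern, with illustrative names rather than the class's own attributes:

    import nltk
    from lru import LRU

    _pos_cache = LRU(2**12)

    def pos_tags(words, tagset='universal'):
        # Return the POS tag sequence for a list of words, memoized on the joined words.
        key = ' '.join(words)
        if key not in _pos_cache:
            _, tags = zip(*nltk.pos_tag(words, tagset=tagset))
            _pos_cache[key] = tags
        return _pos_cache[key]

Note that the new _get_pos returns the cached tag sequence for the whole window rather than the single tag at the word's position, presumably leaving the indexing to its caller (not shown in this diff).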