From cdd9061cb42a88576ace531e5b3d0315d8a57b5d Mon Sep 17 00:00:00 2001 From: Jin Yong Yoo Date: Wed, 22 Jul 2020 04:42:44 -0400 Subject: [PATCH] add cache clearing --- .../google_language_model/alzantot_goog_lm.py | 3 +++ textattack/constraints/grammaticality/part_of_speech.py | 3 +++ .../semantics/sentence_encoders/thought_vector.py | 3 +++ textattack/goal_functions/goal_function.py | 4 ++++ textattack/goal_functions/text/minimize_bleu.py | 5 +++++ textattack/goal_functions/text/non_overlapping_output.py | 6 ++++++ textattack/shared/attack.py | 9 +++++++++ 7 files changed, 33 insertions(+) diff --git a/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py b/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py index 071acfdc..5a12fbe8 100644 --- a/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py +++ b/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py @@ -48,6 +48,9 @@ class GoogLMHelper: self.lm_cache = lru.LRU(2 ** 18) + def clear_cache(self): + self.lm_cache.clear() + def get_words_probs_uncached(self, prefix_words, list_words): targets = np.zeros([self.BATCH_SIZE, self.NUM_TIMESTEPS], np.int32) weights = np.ones([self.BATCH_SIZE, self.NUM_TIMESTEPS], np.float32) diff --git a/textattack/constraints/grammaticality/part_of_speech.py b/textattack/constraints/grammaticality/part_of_speech.py index 2f43a345..ce471df5 100644 --- a/textattack/constraints/grammaticality/part_of_speech.py +++ b/textattack/constraints/grammaticality/part_of_speech.py @@ -42,6 +42,9 @@ class PartOfSpeech(Constraint): else: self._flair_pos_tagger = SequenceTagger.load("pos-fast") + def clear_cache(self): + self._pos_tag_cache.clear() + def _can_replace_pos(self, pos_a, pos_b): return (pos_a == pos_b) or ( self.allow_verb_noun_swap and set([pos_a, pos_b]) <= set(["NOUN", "VERB"]) diff --git a/textattack/constraints/semantics/sentence_encoders/thought_vector.py b/textattack/constraints/semantics/sentence_encoders/thought_vector.py index 23b5aa32..8a5a5d83 100644 --- a/textattack/constraints/semantics/sentence_encoders/thought_vector.py +++ b/textattack/constraints/semantics/sentence_encoders/thought_vector.py @@ -21,6 +21,9 @@ class ThoughtVector(SentenceEncoder): self.embedding_type = embedding_type super().__init__(**kwargs) + def clear_cache(self): + self._get_thought_vector.cache_clear() + @functools.lru_cache(maxsize=2 ** 10) def _get_thought_vector(self, text): """Sums the embeddings of all the words in ``text`` into a "thought diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index 19aba392..5c77387c 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -55,6 +55,10 @@ class GoalFunction(ABC): else: self._call_model_cache = None + def clear_cache(self): + if self.use_cache: + self._call_model_cache.clear() + def init_attack_example(self, attacked_text, ground_truth_output): """Called before attacking ``attacked_text`` to 'reset' the goal function and set properties for this example.""" diff --git a/textattack/goal_functions/text/minimize_bleu.py b/textattack/goal_functions/text/minimize_bleu.py index ac0c13f8..3ee2df1b 100644 --- a/textattack/goal_functions/text/minimize_bleu.py +++ b/textattack/goal_functions/text/minimize_bleu.py @@ -26,6 +26,11 @@ class MinimizeBleu(TextToTextGoalFunction): self.target_bleu = target_bleu super().__init__(*args, **kwargs) + def clear_cache(self): + if self.use_cache: + self._call_model_cache.clear() + get_bleu.cache_clear() + def _is_goal_complete(self, model_output, _): bleu_score = 1.0 - self._get_score(model_output, _) return bleu_score <= (self.target_bleu + MinimizeBleu.EPS) diff --git a/textattack/goal_functions/text/non_overlapping_output.py b/textattack/goal_functions/text/non_overlapping_output.py index d872edd5..fdc491a0 100644 --- a/textattack/goal_functions/text/non_overlapping_output.py +++ b/textattack/goal_functions/text/non_overlapping_output.py @@ -14,6 +14,12 @@ class NonOverlappingOutput(TextToTextGoalFunction): (3). """ + def clear_cache(self): + if self.use_cache: + self._call_model_cache.clear() + get_words_cached.cache_clear() + word_difference_score.cache_clear() + def _is_goal_complete(self, model_output, _): return self._get_score(model_output, self.ground_truth_output) == 1.0 diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py index d41849f1..326ee547 100644 --- a/textattack/shared/attack.py +++ b/textattack/shared/attack.py @@ -86,6 +86,14 @@ class Attack: ) self.search_method.filter_transformations = self.filter_transformations + def clear_cache(self, recursive=True): + self.constraints_cache.clear() + if recursive: + self.goal_function.clear_cache() + for constraint in self.constraints: + if hasattr(constraint, "clear_cache"): + constraint.clear_cache() + def get_transformations(self, current_text, original_text=None, **kwargs): """Applies ``self.transformation`` to ``text``, then filters the list of possible transformations through the applicable constraints. @@ -191,6 +199,7 @@ class Attack: or ``MaximizedAttackResult``. """ final_result = self.search_method(initial_result) + self.clear_cache() if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED: return SuccessfulAttackResult(initial_result, final_result,) elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING: