From 29dd7ee6ffae1650f9b86df26eb9baffe0d5567c Mon Sep 17 00:00:00 2001
From: sherlockyyc <994986061@qq.com>
Date: Tue, 7 Jul 2020 11:04:03 +0800
Subject: [PATCH] v1

---
Notes (ignored by `git am`):
  * GeneticAlgorithm.__init__ gains two keyword options,
    `improved_genetic_algorithm` and `max_replace_time_per_index`,
    and stores both on the instance.
  * GeneticAlgorithm gains `_replace_at_index(pop_member, idx,
    original_result)`, which picks the transformation at word index
    `idx` with the largest score gain over `pop_member.result.score`,
    applies it to `pop_member`, and returns whether one was found.
  * improved_genetic_algorithm.py only receives a comment and blank lines.

 .../search_methods/genetic_algorithm.py       | 45 +++++++++++++++++++
 .../improved_genetic_algorithm.py             |  3 ++
 2 files changed, 48 insertions(+)

diff --git a/textattack/search_methods/genetic_algorithm.py b/textattack/search_methods/genetic_algorithm.py
index c3fda273..ef1baf9e 100644
--- a/textattack/search_methods/genetic_algorithm.py
+++ b/textattack/search_methods/genetic_algorithm.py
@@ -36,16 +36,61 @@ class GeneticAlgorithm(SearchMethod):
         temp=0.3,
         give_up_if_no_improvement=False,
         max_crossover_retries=20,
+        improved_genetic_algorithm=False,
+        max_replace_time_per_index=5,
     ):
         self.max_iters = max_iters
         self.pop_size = pop_size
         self.temp = temp
         self.give_up_if_no_improvement = give_up_if_no_improvement
         self.max_crossover_retries = max_crossover_retries
+        self.improved_genetic_algorithm = improved_genetic_algorithm
+        # The maximum number of times the word at the same index can be replaced in the improved genetic algorithm
+        self.max_replace_time_per_index = max_replace_time_per_index
 
         # internal flag to indicate if search should end immediately
         self._search_over = False
 
+    def _replace_at_index(self, pop_member, idx, original_result):
+        """
+        Select the best replacement for word at position (idx)
+        in (pop_member) to maximize score.
+
+        Args:
+            pop_member: The population member being perturbed.
+            idx: The index at which to replace a word.
+            original_result: Result whose `attacked_text` is passed to
+                `get_transformations` as `original_text`.
+
+        Returns:
+            Whether a replacement which increased the score was found.
+        """
+        transformations = self.get_transformations(
+            pop_member.attacked_text,
+            original_text=original_result.attacked_text,
+            indices_to_modify=[idx],
+        )
+        if not len(transformations):
+            return False
+        new_results, self._search_over = self.get_goal_results(transformations)
+        if self._search_over:
+            return False
+        diff_scores = (
+            torch.Tensor([r.score for r in new_results]) - pop_member.result.score
+        )
+        if len(diff_scores) and diff_scores.max() > 0:
+            idx_with_max_score = diff_scores.argmax()
+            pop_member.attacked_text = transformations[idx_with_max_score]
+            # For the genetic algorithm, a word that has been replaced cannot be modified again.
+            # For the improved genetic algorithm, the word at the same index may be substituted multiple times, but not more than `max_replace_time_per_index`.
+            if self.improved_genetic_algorithm:
+                pop_member.num_candidates_per_word[idx] -= 1
+            else:
+                pop_member.num_candidates_per_word[idx] = 0
+            pop_member.result = new_results[idx_with_max_score]
+            return True
+        return False
+
     def _perturb(self, pop_member, original_result):
         """
         Replaces a word in pop_member that has not been modified in place.
diff --git a/textattack/search_methods/improved_genetic_algorithm.py b/textattack/search_methods/improved_genetic_algorithm.py
index dfc911aa..af767bed 100644
--- a/textattack/search_methods/improved_genetic_algorithm.py
+++ b/textattack/search_methods/improved_genetic_algorithm.py
@@ -41,10 +41,13 @@ class ImprovedGeneticAlgorithm(SearchMethod):
         self.max_replaced_times = max_replaced_times
         self.give_up_if_no_improvement = give_up_if_no_improvement
         self.max_crossover_retries = max_crossover_retries
+        # flag to indicate if it is the improved genetic algorithm proposed in Natural Language Adversarial Attacks and Defenses in Word Level by Wang et al.
 
         # internal flag to indicate if search should end immediately
         self._search_over = False
 
+
+
     def _replace_at_index(self, pop_member, idx):
         """
         Select the best replacement for word at position (idx)