1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00
This commit is contained in:
sherlockyyc
2020-07-07 11:04:03 +08:00
parent 528f2e2f55
commit 29dd7ee6ff
2 changed files with 48 additions and 0 deletions

View File

@@ -36,16 +36,61 @@ class GeneticAlgorithm(SearchMethod):
temp=0.3, temp=0.3,
give_up_if_no_improvement=False, give_up_if_no_improvement=False,
max_crossover_retries=20, max_crossover_retries=20,
improved_genetic_algorithm = False,
max_replace_time_per_index = 5,
): ):
self.max_iters = max_iters self.max_iters = max_iters
self.pop_size = pop_size self.pop_size = pop_size
self.temp = temp self.temp = temp
self.give_up_if_no_improvement = give_up_if_no_improvement self.give_up_if_no_improvement = give_up_if_no_improvement
self.max_crossover_retries = max_crossover_retries self.max_crossover_retries = max_crossover_retries
self.improved_genetic_algorithm = improved_genetic_algorithm
# The maximum times words at the same index can be replaced in improved genetic algorithm
self.max_replace_time_per_index = max_replace_time_per_index
# internal flag to indicate if search should end immediately # internal flag to indicate if search should end immediately
self._search_over = False self._search_over = False
def _replace_at_index(self, pop_member, idx, original_result):
"""
Select the best replacement for word at position (idx)
in (pop_member) to maximize score.
Args:
pop_member: The population member being perturbed.
idx: The index at which to replace a word.
Returns:
Whether a replacement which increased the score was found.
"""
transformations = self.get_transformations(
pop_member.attacked_text,
original_text=original_result.attacked_text,
indices_to_modify=[idx],
)
if not len(transformations):
return False
new_results, self._search_over = self.get_goal_results(
transformations
)
if self._search_over:
return False
diff_scores = (
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
)
if len(diff_scores) and diff_scores.max() > 0:
idx_with_max_score = diff_scores.argmax()
pop_member.attacked_text = transformations[idx_with_max_score]
# For genetic algorithms, the word has been replaced cannot be modified again.
# For improved genetic algorithm, it allows to substitute words at the same index multiple times but not more than `max_replace_time_per_index`.
if self.improved_genetic_algorithm:
pop_member.num_candidates_per_word[idx_with_max_score] -= -1
else:
pop_member.num_candidates_per_word[idx_with_max_score] = 0
pop_member.results = new_results[idx_with_max_score]
return True
return False
def _perturb(self, pop_member, original_result): def _perturb(self, pop_member, original_result):
""" """
Replaces a word in pop_member that has not been modified in place. Replaces a word in pop_member that has not been modified in place.

View File

@@ -41,10 +41,13 @@ class ImprovedGeneticAlgorithm(SearchMethod):
self.max_replaced_times = max_replaced_times self.max_replaced_times = max_replaced_times
self.give_up_if_no_improvement = give_up_if_no_improvement self.give_up_if_no_improvement = give_up_if_no_improvement
self.max_crossover_retries = max_crossover_retries self.max_crossover_retries = max_crossover_retries
# flag to indicate if it is the improved genetic algorithm proposed in Natural Language Adversarial Attacks and Defenses in Word Level by Wang et al.
# internal flag to indicate if search should end immediately # internal flag to indicate if search should end immediately
self._search_over = False self._search_over = False
def _replace_at_index(self, pop_member, idx): def _replace_at_index(self, pop_member, idx):
""" """
Select the best replacement for word at position (idx) Select the best replacement for word at position (idx)