mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
v1
This commit is contained in:
@@ -36,16 +36,61 @@ class GeneticAlgorithm(SearchMethod):
|
||||
temp=0.3,
|
||||
give_up_if_no_improvement=False,
|
||||
max_crossover_retries=20,
|
||||
improved_genetic_algorithm = False,
|
||||
max_replace_time_per_index = 5,
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.pop_size = pop_size
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.max_crossover_retries = max_crossover_retries
|
||||
self.improved_genetic_algorithm = improved_genetic_algorithm
|
||||
# The maximum times words at the same index can be replaced in improved genetic algorithm
|
||||
self.max_replace_time_per_index = max_replace_time_per_index
|
||||
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
|
||||
def _replace_at_index(self, pop_member, idx, original_result):
|
||||
"""
|
||||
Select the best replacement for word at position (idx)
|
||||
in (pop_member) to maximize score.
|
||||
|
||||
Args:
|
||||
pop_member: The population member being perturbed.
|
||||
idx: The index at which to replace a word.
|
||||
|
||||
Returns:
|
||||
Whether a replacement which increased the score was found.
|
||||
"""
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
)
|
||||
if not len(transformations):
|
||||
return False
|
||||
new_results, self._search_over = self.get_goal_results(
|
||||
transformations
|
||||
)
|
||||
if self._search_over:
|
||||
return False
|
||||
diff_scores = (
|
||||
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||
)
|
||||
if len(diff_scores) and diff_scores.max() > 0:
|
||||
idx_with_max_score = diff_scores.argmax()
|
||||
pop_member.attacked_text = transformations[idx_with_max_score]
|
||||
# For genetic algorithms, the word has been replaced cannot be modified again.
|
||||
# For improved genetic algorithm, it allows to substitute words at the same index multiple times but not more than `max_replace_time_per_index`.
|
||||
if self.improved_genetic_algorithm:
|
||||
pop_member.num_candidates_per_word[idx_with_max_score] -= -1
|
||||
else:
|
||||
pop_member.num_candidates_per_word[idx_with_max_score] = 0
|
||||
pop_member.results = new_results[idx_with_max_score]
|
||||
return True
|
||||
return False
|
||||
|
||||
def _perturb(self, pop_member, original_result):
|
||||
"""
|
||||
Replaces a word in pop_member that has not been modified in place.
|
||||
|
||||
@@ -41,10 +41,13 @@ class ImprovedGeneticAlgorithm(SearchMethod):
|
||||
self.max_replaced_times = max_replaced_times
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.max_crossover_retries = max_crossover_retries
|
||||
# flag to indicate if it is the improved genetic algorithm proposed in Natural Language Adversarial Attacks and Defenses in Word Level by Wang et al.
|
||||
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
|
||||
|
||||
|
||||
def _replace_at_index(self, pop_member, idx):
|
||||
"""
|
||||
Select the best replacement for word at position (idx)
|
||||
|
||||
Reference in New Issue
Block a user