mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
v1
This commit is contained in:
@@ -36,16 +36,61 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
temp=0.3,
|
temp=0.3,
|
||||||
give_up_if_no_improvement=False,
|
give_up_if_no_improvement=False,
|
||||||
max_crossover_retries=20,
|
max_crossover_retries=20,
|
||||||
|
improved_genetic_algorithm = False,
|
||||||
|
max_replace_time_per_index = 5,
|
||||||
):
|
):
|
||||||
self.max_iters = max_iters
|
self.max_iters = max_iters
|
||||||
self.pop_size = pop_size
|
self.pop_size = pop_size
|
||||||
self.temp = temp
|
self.temp = temp
|
||||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||||
self.max_crossover_retries = max_crossover_retries
|
self.max_crossover_retries = max_crossover_retries
|
||||||
|
self.improved_genetic_algorithm = improved_genetic_algorithm
|
||||||
|
# The maximum times words at the same index can be replaced in improved genetic algorithm
|
||||||
|
self.max_replace_time_per_index = max_replace_time_per_index
|
||||||
|
|
||||||
# internal flag to indicate if search should end immediately
|
# internal flag to indicate if search should end immediately
|
||||||
self._search_over = False
|
self._search_over = False
|
||||||
|
|
||||||
|
def _replace_at_index(self, pop_member, idx, original_result):
|
||||||
|
"""
|
||||||
|
Select the best replacement for word at position (idx)
|
||||||
|
in (pop_member) to maximize score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pop_member: The population member being perturbed.
|
||||||
|
idx: The index at which to replace a word.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether a replacement which increased the score was found.
|
||||||
|
"""
|
||||||
|
transformations = self.get_transformations(
|
||||||
|
pop_member.attacked_text,
|
||||||
|
original_text=original_result.attacked_text,
|
||||||
|
indices_to_modify=[idx],
|
||||||
|
)
|
||||||
|
if not len(transformations):
|
||||||
|
return False
|
||||||
|
new_results, self._search_over = self.get_goal_results(
|
||||||
|
transformations
|
||||||
|
)
|
||||||
|
if self._search_over:
|
||||||
|
return False
|
||||||
|
diff_scores = (
|
||||||
|
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||||
|
)
|
||||||
|
if len(diff_scores) and diff_scores.max() > 0:
|
||||||
|
idx_with_max_score = diff_scores.argmax()
|
||||||
|
pop_member.attacked_text = transformations[idx_with_max_score]
|
||||||
|
# For genetic algorithms, the word has been replaced cannot be modified again.
|
||||||
|
# For improved genetic algorithm, it allows to substitute words at the same index multiple times but not more than `max_replace_time_per_index`.
|
||||||
|
if self.improved_genetic_algorithm:
|
||||||
|
pop_member.num_candidates_per_word[idx_with_max_score] -= -1
|
||||||
|
else:
|
||||||
|
pop_member.num_candidates_per_word[idx_with_max_score] = 0
|
||||||
|
pop_member.results = new_results[idx_with_max_score]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _perturb(self, pop_member, original_result):
|
def _perturb(self, pop_member, original_result):
|
||||||
"""
|
"""
|
||||||
Replaces a word in pop_member that has not been modified in place.
|
Replaces a word in pop_member that has not been modified in place.
|
||||||
|
|||||||
@@ -41,10 +41,13 @@ class ImprovedGeneticAlgorithm(SearchMethod):
|
|||||||
self.max_replaced_times = max_replaced_times
|
self.max_replaced_times = max_replaced_times
|
||||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||||
self.max_crossover_retries = max_crossover_retries
|
self.max_crossover_retries = max_crossover_retries
|
||||||
|
# flag to indicate if it is the improved genetic algorithm proposed in Natural Language Adversarial Attacks and Defenses in Word Level by Wang et al.
|
||||||
|
|
||||||
# internal flag to indicate if search should end immediately
|
# internal flag to indicate if search should end immediately
|
||||||
self._search_over = False
|
self._search_over = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _replace_at_index(self, pop_member, idx):
|
def _replace_at_index(self, pop_member, idx):
|
||||||
"""
|
"""
|
||||||
Select the best replacement for word at position (idx)
|
Select the best replacement for word at position (idx)
|
||||||
|
|||||||
Reference in New Issue
Block a user