mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update genetic algorithm
This commit is contained in:
@@ -26,6 +26,8 @@ class GeneticAlgorithm(SearchMethod):
|
||||
give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
|
||||
max_crossover_retries (int): Maximum number of crossover retries if resulting child fails to pass the constraints.
|
||||
Setting it to 0 means we immediately take one of the parents at random as the child.
|
||||
compare_againt_original (bool): If True, the reference text for constraints is the original text.
|
||||
Else, the reference text is the most recent text from which the new text is generated.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,23 +37,18 @@ class GeneticAlgorithm(SearchMethod):
|
||||
temp=0.3,
|
||||
give_up_if_no_improvement=False,
|
||||
max_crossover_retries=20,
|
||||
compare_against_original=True,
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.pop_size = pop_size
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.max_crossover_retries = max_crossover_retries
|
||||
self.compare_to_original = True
|
||||
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
|
||||
def __call__(self, initial_result):
|
||||
if not hasattr(self, "filter_transformations"):
|
||||
raise AttributeError(
|
||||
"Search Method must have access to filter_transformations method"
|
||||
)
|
||||
return super(GeneticAlgorithm, self).__call__(initial_result)
|
||||
|
||||
def _perturb(self, pop_member, original_result):
|
||||
"""
|
||||
Replaces a word in pop_member that has not been modified in place.
|
||||
@@ -61,23 +58,29 @@ class GeneticAlgorithm(SearchMethod):
|
||||
|
||||
Returns: None
|
||||
"""
|
||||
num_words = pop_member.num_neighbors_list.shape[0]
|
||||
num_neighbors_list = np.copy(pop_member.num_neighbors_list)
|
||||
non_zero_indices = np.sum(np.sign(pop_member.num_neighbors_list))
|
||||
num_words = pop_member.num_candidates_per_word.shape[0]
|
||||
num_candidates_per_word = np.copy(pop_member.num_candidates_per_word)
|
||||
non_zero_indices = np.sum(np.sign(pop_member.num_candidates_per_word))
|
||||
if non_zero_indices == 0:
|
||||
return
|
||||
iterations = 0
|
||||
while iterations < non_zero_indices:
|
||||
w_select_probs = num_neighbors_list / np.sum(num_neighbors_list)
|
||||
w_select_probs = num_candidates_per_word / np.sum(num_candidates_per_word)
|
||||
rand_idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[rand_idx],
|
||||
)
|
||||
if self.compare_to_original:
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[rand_idx],
|
||||
)
|
||||
else:
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text, indices_to_modify=[rand_idx],
|
||||
)
|
||||
|
||||
if not len(transformations):
|
||||
iterations += 1
|
||||
continue
|
||||
|
||||
new_results, self._search_over = self.get_goal_results(
|
||||
@@ -88,16 +91,16 @@ class GeneticAlgorithm(SearchMethod):
|
||||
break
|
||||
|
||||
diff_scores = (
|
||||
torch.Tensor([r.score for r in new_results]) - original_result.score
|
||||
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||
)
|
||||
if len(diff_scores) and diff_scores.max() > 0:
|
||||
idx = diff_scores.argmax()
|
||||
pop_member.attacked_text = transformations[idx]
|
||||
pop_member.num_candidates_per_word[rand_idx] = 0
|
||||
pop_member.results = new_results[idx]
|
||||
pop_member.num_neighbors_list[rand_idx] = 0
|
||||
break
|
||||
|
||||
num_neighbors_list[rand_idx] = 0
|
||||
num_candidates_per_word[rand_idx] = 0
|
||||
iterations += 1
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2, original_result):
|
||||
@@ -119,12 +122,12 @@ class GeneticAlgorithm(SearchMethod):
|
||||
while num_tries < self.max_crossover_retries:
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
num_neighbors_list = np.copy(pop_member1.num_neighbors_list)
|
||||
num_candidates_per_word = np.copy(pop_member1.num_candidates_per_word)
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(x2_words[i])
|
||||
num_neighbors_list[i] = pop_member2.num_neighbors_list[i]
|
||||
num_candidates_per_word[i] = pop_member2.num_candidates_per_word[i]
|
||||
new_text = x1_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
)
|
||||
@@ -143,24 +146,30 @@ class GeneticAlgorithm(SearchMethod):
|
||||
num_tries += 1
|
||||
|
||||
if not passed_constraints:
|
||||
# If we cannot find a child that passes the constraints,
|
||||
# we just randomly pick one of the parents to be the child for the next iteration.
|
||||
new_text = (
|
||||
pop_member1.attacked_text
|
||||
if np.random.uniform() < 0.5
|
||||
else pop_member2.attacked_text
|
||||
)
|
||||
|
||||
return PopulationMember(new_text, num_neighbors_list)
|
||||
new_results, self._search_over = self.get_goal_results(
|
||||
[new_text], original_result.output
|
||||
)
|
||||
|
||||
return PopulationMember(new_text, num_candidates_per_word, new_results[0])
|
||||
|
||||
def _initialize_population(self, initial_result):
|
||||
"""
|
||||
Initialize a population of texts each with one word replaced
|
||||
Args:
|
||||
initial_result (GaolFunctionResult): The result to instantiate the population with
|
||||
initial_result (GoalFunctionResult): The result to instantiate the population with
|
||||
Returns:
|
||||
The population.
|
||||
"""
|
||||
words = initial_result.attacked_text.words
|
||||
num_neighbors_list = np.zeros(len(words))
|
||||
num_candidates_per_word = np.zeros(len(words))
|
||||
transformations = self.get_transformations(
|
||||
initial_result.attacked_text, original_text=initial_result.attacked_text
|
||||
)
|
||||
@@ -168,13 +177,20 @@ class GeneticAlgorithm(SearchMethod):
|
||||
diff_idx = initial_result.attacked_text.first_word_diff_index(
|
||||
transformed_text
|
||||
)
|
||||
num_neighbors_list[diff_idx] += 1
|
||||
num_candidates_per_word[diff_idx] += 1
|
||||
|
||||
total_candidates = np.sum(num_candidates_per_word)
|
||||
# Just b/c there are no candidates now doesn't mean we never want to select the word for perturbation
|
||||
# Therefore, we give small non-zero probability for words with no candidates
|
||||
for i in range(len(num_candidates_per_word)):
|
||||
if num_candidates_per_word[i] == 0:
|
||||
num_candidates_per_word[i] = int(total_candidates * 0.01)
|
||||
|
||||
population = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(
|
||||
initial_result.attacked_text,
|
||||
np.copy(num_neighbors_list),
|
||||
np.copy(num_candidates_per_word),
|
||||
initial_result,
|
||||
)
|
||||
# Perturb `pop_member` in-place
|
||||
@@ -214,16 +230,14 @@ class GeneticAlgorithm(SearchMethod):
|
||||
population[parent2_idx[idx]],
|
||||
initial_result,
|
||||
)
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
self._perturb(child, initial_result)
|
||||
if child.result is None:
|
||||
# If child.result is not computed for any reason, we compute it here.
|
||||
result, self._search_over = self.get_goal_results(
|
||||
[child.attacked_text], initial_result.output
|
||||
)
|
||||
child.result = result[0]
|
||||
children.append(child)
|
||||
|
||||
# We need two `search_over` checks b/c value might change both in
|
||||
# `crossover` method and `perturb` method.
|
||||
if self._search_over:
|
||||
break
|
||||
if self._search_over:
|
||||
@@ -249,10 +263,10 @@ class PopulationMember:
|
||||
|
||||
Args:
|
||||
attacked_text: The ``AttackedText`` of the population member.
|
||||
num_neighbors_list (numpy.array): A list of the number of candidate neighbors list for each word.
|
||||
num_candidates_per_word (numpy.array): A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
|
||||
def __init__(self, attacked_text, num_neighbors_list, result=None):
|
||||
def __init__(self, attacked_text, num_candidates_per_word, result):
|
||||
self.attacked_text = attacked_text
|
||||
self.num_neighbors_list = num_neighbors_list
|
||||
self.num_candidates_per_word = num_candidates_per_word
|
||||
self.result = result
|
||||
|
||||
Reference in New Issue
Block a user