1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

fix crossover constraint checking

This commit is contained in:
Jin Yong Yoo
2020-06-28 12:03:36 -04:00
parent 86dff2174a
commit 36b52608b2
4 changed files with 104 additions and 194 deletions

View File

@@ -19,17 +19,38 @@ class GeneticAlgorithm(SearchMethod):
Attacks a model with word substiutitions using a genetic algorithm.
Args:
pop_size (:obj:`int`, optional): The population size. Defauls to 20.
max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 50.
pop_size (int): The population size. Defaults to 20.
max_iters (int): The maximum number of iterations to use. Defaults to 50.
temp (float): Temperature for softmax function used to normalize probability dist when sampling parents.
Higher temperature increases the sensitivity to lower probability candidates.
give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
max_crossover_retries (int): Maximum number of crossover retries if resulting child fails to pass the constraints.
Setting it to 0 means we immediately take one of the parents at random as the child.
"""
def __init__(
self, pop_size=20, max_iters=50, temp=0.3, give_up_if_no_improvement=False
self,
pop_size=20,
max_iters=50,
temp=0.3,
give_up_if_no_improvement=False,
max_crossover_retries=20,
):
self.max_iters = max_iters
self.pop_size = pop_size
self.temp = temp
self.give_up_if_no_improvement = give_up_if_no_improvement
self.max_crossover_retries = max_crossover_retries
# internal flag to indicate if search should end immediately
self._search_over = False
def __call__(self, initial_result):
if not hasattr(self, "filter_transformations"):
raise AttributeError(
"Search Method must have access to filter_transformations method"
)
return super(GeneticAlgorithm, self).__call__(initial_result)
def _perturb(self, pop_member, original_result):
"""
@@ -59,49 +80,78 @@ class GeneticAlgorithm(SearchMethod):
if not len(transformations):
continue
new_results, search_over = self.get_goal_results(
ransformations, self.correct_output
new_results, self._search_over = self.get_goal_results(
transformations, original_result.output
)
if search_over:
if self._search_over:
break
diff_scores = (
torch.Tensor([r.score for r in new_results]) - original_result.score
)
if len(diff_scores) and diff_scores.max() > 0:
pop_member.attacked_text = transformations[diff_scores.argmax()]
idx = diff_scores.argmax()
pop_member.attacked_text = transformations[idx]
pop_member.results = new_results[idx]
pop_member.num_neighbors_list[rand_idx] = 0
break
num_neighbors_list[rand_idx] = 0
iterations += 1
def _crossover(self, pop_member1, pop_member2):
def _crossover(self, pop_member1, pop_member2, original_result):
"""
Generates a crossover between pop_member1 and pop_member2.
If the child fails to satisfy the constraits, we re-try crossover for a fix number of times,
before taking one of the parents at random as the resulting child.
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
A population member containing the crossover.
"""
indices_to_replace = []
words_to_replace = []
x1_text = pop_member1.attacked_text
x2_words = pop_member2.attacked_text.words
num_neighbors_list = np.copy(pop_member1.num_neighbors_list)
for i in range(len(x1_text.words)):
if np.random.uniform() < 0.5:
indices_to_replace.append(i)
words_to_replace.append(x2_words[i])
num_neighbors_list[i] = pop_member2.num_neighbors_list[i]
new_text = x1_text.replace_words_at_indices(
indices_to_replace, words_to_replace
)
num_tries = 0
passed_constraints = False
while num_tries < self.max_crossover_retries:
indices_to_replace = []
words_to_replace = []
num_neighbors_list = np.copy(pop_member1.num_neighbors_list)
for i in range(len(x1_text.words)):
if np.random.uniform() < 0.5:
indices_to_replace.append(i)
words_to_replace.append(x2_words[i])
num_neighbors_list[i] = pop_member2.num_neighbors_list[i]
new_text = x1_text.replace_words_at_indices(
indices_to_replace, words_to_replace
)
new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
"last_transformation"
]
filtered = self.filter_transformations(
[new_text], x1_text, original_text=original_result.attacked_text
)
if filtered:
new_text = filtered[0]
passed_constraints = True
break
num_tries += 1
if not passed_constraints:
new_text = (
pop_member1.attacked_text
if np.random.uniform() < 0.5
else pop_member2.attacked_text
)
return PopulationMember(new_text, num_neighbors_list)
def _initalize_population(self, initial_result):
def _initialize_population(self, initial_result):
"""
Initialize a population of texts each with one word replaced
Args:
@@ -115,7 +165,9 @@ class GeneticAlgorithm(SearchMethod):
initial_result.attacked_text, original_text=initial_result.attacked_text
)
for transformed_text in transformations:
diff_idx = attacked_text.first_word_diff_index(transformed_text)
diff_idx = initial_result.attacked_text.first_word_diff_index(
transformed_text
)
num_neighbors_list[diff_idx] += 1
population = []
@@ -131,34 +183,23 @@ class GeneticAlgorithm(SearchMethod):
return population
def _perform_search(self, initial_result):
self._search_over = False
population = self._initialize_population(initial_result)
current_score = initial_result.score
for i in range(self.max_iters):
pop_results, search_over = self.get_goal_results(
[pm.attacked_text for pm in pop], self.correct_output
)
if search_over:
if len(pop_results) == 0:
return population[0].result
return max(pop_results, key=lambda x: x.score)
for idx, result in enumerate(pop_results):
population[idx].result = result
population = sorted(population, key=lambda x: -x.result.score)
pop_scores = torch.Tensor([r.score for r in pop_results])
logits = ((-pop_scores) / self.temp).exp()
select_probs = (logits / logits.sum()).cpu().numpy()
if population[0].result.succeeded:
return population[0].result
population = sorted(population, key=lambda x: x.result.score, reverse=True)
if self._search_over or population[0].result.succeeded:
break
if population[0].result.score > current_score:
current_score = population[0].result.score
elif self.give_up_if_no_improvement:
break
best_member = population[0]
pop_scores = torch.Tensor([pm.result.score for pm in population])
logits = ((-pop_scores) / self.temp).exp()
select_probs = (logits / logits.sum()).cpu().numpy()
parent1_idx = np.random.choice(
self.pop_size, size=self.pop_size - 1, p=select_probs
)
@@ -169,12 +210,26 @@ class GeneticAlgorithm(SearchMethod):
children = []
for idx in range(self.pop_size - 1):
child = self._crossover(
population[parent1_idx[idx]], population[parent2_idx[idx]]
population[parent1_idx[idx]],
population[parent2_idx[idx]],
initial_result,
)
self._perturb(child, initial_result)
if child.result is None:
# If child.result is not computed for any reason, we compute it here.
result, self._search_over = self.get_goal_results(
[child.attacked_text], initial_result.output
)
child.result = result[0]
children.append(child)
population = [best_member] + children
if self._search_over:
break
if self._search_over:
break
population = [population[0]] + children
return population[0].result
@@ -194,7 +249,7 @@ class PopulationMember:
Args:
attacked_text: The ``AttackedText`` of the population member.
num_neighbors_list: A list of the number of candidate neighbors list for each word.
num_neighbors_list (numpy.array): A list of the number of candidate neighbors list for each word.
"""
def __init__(self, attacked_text, num_neighbors_list, result=None):