mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge in master and fix syntax errors
This commit is contained in:
@@ -10,6 +10,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
@@ -19,167 +20,209 @@ class GeneticAlgorithm(SearchMethod):
|
||||
Attacks a model with word substiutitions using a genetic algorithm.
|
||||
|
||||
Args:
|
||||
pop_size (:obj:`int`, optional): The population size. Defauls to 20.
|
||||
max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 50.
|
||||
pop_size (int): The population size. Defaults to 20.
|
||||
max_iters (int): The maximum number of iterations to use. Defaults to 50.
|
||||
temp (float): Temperature for softmax function used to normalize probability dist when sampling parents.
|
||||
Higher temperature increases the sensitivity to lower probability candidates.
|
||||
give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
|
||||
max_crossover_retries (int): Maximum number of crossover retries if resulting child fails to pass the constraints.
|
||||
Setting it to 0 means we immediately take one of the parents at random as the child.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, pop_size=20, max_iters=50, temp=0.3, give_up_if_no_improvement=False
|
||||
self,
|
||||
pop_size=20,
|
||||
max_iters=50,
|
||||
temp=0.3,
|
||||
give_up_if_no_improvement=False,
|
||||
max_crossover_retries=20,
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.pop_size = pop_size
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.search_over = False
|
||||
self.max_crossover_retries = max_crossover_retries
|
||||
|
||||
def _replace_at_index(self, pop_member, idx):
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
|
||||
def _perturb(self, pop_member, original_result):
|
||||
"""
|
||||
Select the best replacement for word at position (idx)
|
||||
in (pop_member) to maximize score.
|
||||
|
||||
Replaces a word in pop_member that has not been modified in place.
|
||||
Args:
|
||||
pop_member: The population member being perturbed.
|
||||
idx: The index at which to replace a word.
|
||||
|
||||
Returns:
|
||||
Whether a replacement which increased the score was found.
|
||||
pop_member (PopulationMember): The population member being perturbed.
|
||||
original_result (GoalFunctionResult): Result of original sample being attacked
|
||||
|
||||
Returns: None
|
||||
"""
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=self.original_attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
)
|
||||
if not len(transformations):
|
||||
return False
|
||||
orig_result, self.search_over = self.get_goal_results(
|
||||
[pop_member.attacked_text], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
return False
|
||||
new_x_results, self.search_over = self.get_goal_results(
|
||||
transformations, self.correct_output
|
||||
)
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
if len(new_x_scores) and new_x_scores.max() > 0:
|
||||
pop_member.attacked_text = transformations[new_x_scores.argmax()]
|
||||
return True
|
||||
return False
|
||||
|
||||
def _perturb(self, pop_member):
|
||||
"""
|
||||
Replaces a word in pop_member that has not been modified.
|
||||
Args:
|
||||
pop_member: The population member being perturbed.
|
||||
"""
|
||||
x_len = pop_member.neighbors_len.shape[0]
|
||||
neighbors_len = deepcopy(pop_member.neighbors_len)
|
||||
non_zero_indices = np.sum(np.sign(pop_member.neighbors_len))
|
||||
num_words = pop_member.num_candidates_per_word.shape[0]
|
||||
num_candidates_per_word = np.copy(pop_member.num_candidates_per_word)
|
||||
non_zero_indices = np.count_nonzero(num_candidates_per_word)
|
||||
if non_zero_indices == 0:
|
||||
return
|
||||
iterations = 0
|
||||
while iterations < non_zero_indices and not self.search_over:
|
||||
w_select_probs = neighbors_len / np.sum(neighbors_len)
|
||||
rand_idx = np.random.choice(x_len, 1, p=w_select_probs)[0]
|
||||
if self._replace_at_index(pop_member, rand_idx):
|
||||
pop_member.neighbors_len[rand_idx] = 0
|
||||
while iterations < non_zero_indices:
|
||||
w_select_probs = num_candidates_per_word / np.sum(num_candidates_per_word)
|
||||
rand_idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[rand_idx],
|
||||
)
|
||||
|
||||
if not len(transformations):
|
||||
iterations += 1
|
||||
continue
|
||||
|
||||
new_results, self._search_over = self.get_goal_results(transformations)
|
||||
|
||||
if self._search_over:
|
||||
break
|
||||
neighbors_len[rand_idx] = 0
|
||||
|
||||
diff_scores = (
|
||||
torch.Tensor([r.score for r in new_results]) - pop_member.result.score
|
||||
)
|
||||
if len(diff_scores) and diff_scores.max() > 0:
|
||||
idx = diff_scores.argmax()
|
||||
pop_member.attacked_text = transformations[idx]
|
||||
pop_member.num_candidates_per_word[rand_idx] = 0
|
||||
pop_member.results = new_results[idx]
|
||||
break
|
||||
|
||||
num_candidates_per_word[rand_idx] = 0
|
||||
iterations += 1
|
||||
|
||||
def _generate_population(self, neighbors_len, initial_result):
|
||||
"""
|
||||
Generates a population of texts each with one word replaced
|
||||
Args:
|
||||
neighbors_len: A list of the number of candidate neighbors for each word.
|
||||
initial_result: The result to instantiate the population with
|
||||
Returns:
|
||||
The population.
|
||||
"""
|
||||
pop = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(
|
||||
self.original_attacked_text, deepcopy(neighbors_len), initial_result
|
||||
)
|
||||
self._perturb(pop_member)
|
||||
pop.append(pop_member)
|
||||
return pop
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2):
|
||||
def _crossover(self, pop_member1, pop_member2, original_result):
|
||||
"""
|
||||
Generates a crossover between pop_member1 and pop_member2.
|
||||
If the child fails to satisfy the constraits, we re-try crossover for a fix number of times,
|
||||
before taking one of the parents at random as the resulting child.
|
||||
Args:
|
||||
pop_member1: The first population member.
|
||||
pop_member2: The second population member.
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
A population member containing the crossover.
|
||||
"""
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
x1_text = pop_member1.attacked_text
|
||||
x2_words = pop_member2.attacked_text.words
|
||||
new_neighbors_len = deepcopy(pop_member1.neighbors_len)
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(x2_words[i])
|
||||
new_neighbors_len[i] = pop_member2.neighbors_len[i]
|
||||
new_text = x1_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
)
|
||||
return PopulationMember(new_text, deepcopy(new_neighbors_len))
|
||||
x2_text = pop_member2.attacked_text
|
||||
x2_words = x2_text.words
|
||||
|
||||
def _get_neighbors_len(self, attacked_text):
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
while num_tries < self.max_crossover_retries + 1:
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
num_candidates_per_word = np.copy(pop_member1.num_candidates_per_word)
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(x2_words[i])
|
||||
num_candidates_per_word[i] = pop_member2.num_candidates_per_word[i]
|
||||
new_text = x1_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
)
|
||||
if "last_transformation" in x1_text.attack_attrs:
|
||||
new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
|
||||
"last_transformation"
|
||||
]
|
||||
filtered = self.filter_transformations(
|
||||
[new_text], x1_text, original_text=original_result.attacked_text
|
||||
)
|
||||
elif "last_transformation" in x2_text.attack_attrs:
|
||||
new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
|
||||
"last_transformation"
|
||||
]
|
||||
filtered = self.filter_transformations(
|
||||
[new_text], x1_text, original_text=original_result.attacked_text
|
||||
)
|
||||
else:
|
||||
# In this case, neither x_1 nor x_2 has been transformed,
|
||||
# meaning that new_text == original_text
|
||||
filtered = [new_text]
|
||||
|
||||
if filtered:
|
||||
new_text = filtered[0]
|
||||
passed_constraints = True
|
||||
break
|
||||
|
||||
num_tries += 1
|
||||
|
||||
if not passed_constraints:
|
||||
# If we cannot find a child that passes the constraints,
|
||||
# we just randomly pick one of the parents to be the child for the next iteration.
|
||||
new_text = (
|
||||
pop_member1.attacked_text
|
||||
if np.random.uniform() < 0.5
|
||||
else pop_member2.attacked_text
|
||||
)
|
||||
|
||||
new_results, self._search_over = self.get_goal_results([new_text])
|
||||
|
||||
return PopulationMember(new_text, num_candidates_per_word, new_results[0])
|
||||
|
||||
def _initialize_population(self, initial_result):
|
||||
"""
|
||||
Generates this neighbors_len list
|
||||
Initialize a population of texts each with one word replaced
|
||||
Args:
|
||||
attacked_text: The original text
|
||||
initial_result (GoalFunctionResult): The result to instantiate the population with
|
||||
Returns:
|
||||
A list of number of candidate neighbors for each word
|
||||
The population.
|
||||
"""
|
||||
words = attacked_text.words
|
||||
neighbors_list = [[] for _ in range(len(words))]
|
||||
words = initial_result.attacked_text.words
|
||||
num_candidates_per_word = np.zeros(len(words))
|
||||
transformations = self.get_transformations(
|
||||
attacked_text, original_text=self.original_attacked_text
|
||||
initial_result.attacked_text, original_text=initial_result.attacked_text
|
||||
)
|
||||
for transformed_text in transformations:
|
||||
diff_idx = attacked_text.first_word_diff_index(transformed_text)
|
||||
neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
|
||||
neighbors_list = [np.array(x) for x in neighbors_list]
|
||||
neighbors_len = np.array([len(x) for x in neighbors_list])
|
||||
return neighbors_len
|
||||
diff_idx = initial_result.attacked_text.first_word_diff_index(
|
||||
transformed_text
|
||||
)
|
||||
num_candidates_per_word[diff_idx] += 1
|
||||
|
||||
# Just b/c there are no candidates now doesn't mean we never want to select the word for perturbation
|
||||
# Therefore, we give small non-zero probability for words with no candidates
|
||||
# Epsilon is some small number to approximately assign 1% probability
|
||||
num_total_candidates = np.sum(num_candidates_per_word)
|
||||
epsilon = max(1, int(num_total_candidates * 0.01))
|
||||
for i in range(len(num_candidates_per_word)):
|
||||
if num_candidates_per_word[i] == 0:
|
||||
num_candidates_per_word[i] = epsilon
|
||||
|
||||
population = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(
|
||||
initial_result.attacked_text,
|
||||
np.copy(num_candidates_per_word),
|
||||
initial_result,
|
||||
)
|
||||
# Perturb `pop_member` in-place
|
||||
self._perturb(pop_member, initial_result)
|
||||
population.append(pop_member)
|
||||
return population
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
self.original_attacked_text = initial_result.attacked_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_attacked_text)
|
||||
pop = self._generate_population(neighbors_len, initial_result)
|
||||
cur_score = initial_result.score
|
||||
self._search_over = False
|
||||
population = self._initialize_population(initial_result)
|
||||
current_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results(
|
||||
[pm.attacked_text for pm in pop], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
return pop[0].result
|
||||
return max(pop_results, key=lambda x: x.score)
|
||||
for idx, result in enumerate(pop_results):
|
||||
pop[idx].result = pop_results[idx]
|
||||
pop = sorted(pop, key=lambda x: -x.result.score)
|
||||
population = sorted(population, key=lambda x: x.result.score, reverse=True)
|
||||
if (
|
||||
self._search_over
|
||||
or population[0].result.goal_status
|
||||
== GoalFunctionResultStatus.SUCCEEDED
|
||||
):
|
||||
break
|
||||
|
||||
pop_scores = torch.Tensor([r.score for r in pop_results])
|
||||
logits = ((-pop_scores) / self.temp).exp()
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
if pop[0].result.succeeded:
|
||||
return pop[0].result
|
||||
|
||||
if pop[0].result.score > cur_score:
|
||||
cur_score = pop[0].result.score
|
||||
if population[0].result.score > current_score:
|
||||
current_score = population[0].result.score
|
||||
elif self.give_up_if_no_improvement:
|
||||
break
|
||||
|
||||
elite = [pop[0]]
|
||||
pop_scores = torch.Tensor([pm.result.score for pm in population])
|
||||
logits = ((-pop_scores) / self.temp).exp()
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
parent1_idx = np.random.choice(
|
||||
self.pop_size, size=self.pop_size - 1, p=select_probs
|
||||
)
|
||||
@@ -187,16 +230,27 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.pop_size, size=self.pop_size - 1, p=select_probs
|
||||
)
|
||||
|
||||
children = [
|
||||
self._crossover(pop[parent1_idx[idx]], pop[parent2_idx[idx]])
|
||||
for idx in range(self.pop_size - 1)
|
||||
]
|
||||
for c in children:
|
||||
self._perturb(c)
|
||||
children = []
|
||||
for idx in range(self.pop_size - 1):
|
||||
child = self._crossover(
|
||||
population[parent1_idx[idx]],
|
||||
population[parent2_idx[idx]],
|
||||
initial_result,
|
||||
)
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
pop = elite + children
|
||||
self._perturb(child, initial_result)
|
||||
children.append(child)
|
||||
|
||||
return pop[0].result
|
||||
# We need two `search_over` checks b/c value might change both in
|
||||
# `crossover` method and `perturb` method.
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
population = [population[0]] + children
|
||||
|
||||
return population[0].result
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
"""
|
||||
@@ -214,10 +268,10 @@ class PopulationMember:
|
||||
|
||||
Args:
|
||||
attacked_text: The ``AttackedText`` of the population member.
|
||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||
num_candidates_per_word (numpy.array): A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
|
||||
def __init__(self, attacked_text, neighbors_len, result=None):
|
||||
def __init__(self, attacked_text, num_candidates_per_word, result):
|
||||
self.attacked_text = attacked_text
|
||||
self.neighbors_len = neighbors_len
|
||||
self.num_candidates_per_word = num_candidates_per_word
|
||||
self.result = result
|
||||
|
||||
Reference in New Issue
Block a user