mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
refactor genetic algorithms to extend from same class
This commit is contained in:
@@ -1,11 +1,4 @@
|
||||
"""Reimplementation of search method from Generating Natural Language
|
||||
Adversarial Examples by Alzantot et.
|
||||
|
||||
al `<arxiv.org/abs/1804.07998>`_
|
||||
`<github.com/nesl/nlp_adversarial_examples>`_
|
||||
"""
|
||||
|
||||
# from copy import deepcopy
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -15,8 +8,9 @@ from textattack.search_methods import PopulationBasedSearch, PopulationMember
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
|
||||
class GeneticAlgorithm(PopulationBasedSearch):
|
||||
"""Attacks a model with word substiutitions using a genetic algorithm.
|
||||
class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
"""Base class for attacking a model with word substiutitions using a
|
||||
genetic algorithm.
|
||||
|
||||
Args:
|
||||
pop_size (int): The population size. Defaults to 20.
|
||||
@@ -49,15 +43,23 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
|
||||
def _perturb(self, pop_member, original_result):
|
||||
"""Perturb `pop_member` in-place.
|
||||
@abstractmethod
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `num_replacements_per_word` altered appropriately for
|
||||
given `word_idx`"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _perturb(self, pop_member, original_result, index=None):
|
||||
"""Perturb `pop_member` and return it. Replaces a word at a random
|
||||
(unless `index` is specified) in `pop_member`.
|
||||
|
||||
Replaces a word at a random in `pop_member` with replacement word that maximizes increase in score.
|
||||
Args:
|
||||
pop_member (PopulationMember): The population member being perturbed.
|
||||
original_result (GoalFunctionResult): Result of original sample being attacked
|
||||
index (int): Index of word to perturb.
|
||||
Returns:
|
||||
`True` if perturbation occurred. `False` if not.
|
||||
Perturbed `PopulationMember`
|
||||
"""
|
||||
num_words = pop_member.num_replacements_per_word.shape[0]
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
@@ -66,10 +68,13 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
return False
|
||||
iterations = 0
|
||||
while iterations < non_zero_indices:
|
||||
w_select_probs = num_replacements_per_word / np.sum(
|
||||
num_replacements_per_word
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
if index:
|
||||
idx = index
|
||||
else:
|
||||
w_select_probs = num_replacements_per_word / np.sum(
|
||||
num_replacements_per_word
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
transformed_texts = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
@@ -91,19 +96,35 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
)
|
||||
if len(diff_scores) and diff_scores.max() > 0:
|
||||
idx_with_max_score = diff_scores.argmax()
|
||||
pop_member.attacked_text = transformed_texts[idx_with_max_score]
|
||||
pop_member.results = new_results[idx_with_max_score]
|
||||
pop_member.num_replacements_per_word[idx] = 0
|
||||
return True
|
||||
pop_member = self._modify_population_member(
|
||||
pop_member,
|
||||
transformed_texts[idx_with_max_score],
|
||||
new_results[idx_with_max_score],
|
||||
idx,
|
||||
)
|
||||
return pop_member
|
||||
|
||||
num_replacements_per_word[idx] = 0
|
||||
iterations += 1
|
||||
return False
|
||||
return pop_member
|
||||
|
||||
@abstractmethod
|
||||
def _crossover_operation(self, pop_member1, pop_member2):
|
||||
"""Actual operation for generating crossover between pop_member1 and
|
||||
pop_member2.
|
||||
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2, original_text):
|
||||
"""Generates a crossover between pop_member1 and pop_member2.
|
||||
|
||||
If the child fails to satisfy the constraits, we re-try crossover for a fix number of times,
|
||||
If the child fails to satisfy the constraints, we re-try crossover for a fix number of times,
|
||||
before taking one of the parents at random as the resulting child.
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
@@ -114,30 +135,18 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
"""
|
||||
x1_text = pop_member1.attacked_text
|
||||
x2_text = pop_member2.attacked_text
|
||||
x2_words = x2_text.words
|
||||
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
while num_tries < self.max_crossover_retries + 1:
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
num_replacements_per_word = np.copy(pop_member1.num_replacements_per_word)
|
||||
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(x2_words[i])
|
||||
num_replacements_per_word[
|
||||
i
|
||||
] = pop_member2.num_replacements_per_word[i]
|
||||
|
||||
new_text = x1_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
new_text, num_replacements_per_word = self._crossover_operation(
|
||||
pop_member1, pop_member2
|
||||
)
|
||||
indices_to_replace = set(indices_to_replace)
|
||||
|
||||
replaced_indices = new_text.attack_attrs["newly_modified_indices"]
|
||||
new_text.attack_attrs["modified_indices"] = (
|
||||
x1_text.attack_attrs["modified_indices"] - indices_to_replace
|
||||
) | (x2_text.attack_attrs["modified_indices"] & indices_to_replace)
|
||||
x1_text.attack_attrs["modified_indices"] - replaced_indices
|
||||
) | (x2_text.attack_attrs["modified_indices"] & replaced_indices)
|
||||
|
||||
if "last_transformation" in x1_text.attack_attrs:
|
||||
new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
|
||||
@@ -179,10 +188,11 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
new_results, self._search_over = self.get_goal_results([new_text])
|
||||
return PopulationMember(
|
||||
new_text,
|
||||
new_results[0],
|
||||
result=new_results[0],
|
||||
num_replacements_per_word=num_replacements_per_word,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_population(self, initial_result, pop_size):
|
||||
"""
|
||||
Initialize a population of size `pop_size` with `initial_result`
|
||||
@@ -192,45 +202,17 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
Returns:
|
||||
population as `list[PopulationMember]`
|
||||
"""
|
||||
words = initial_result.attacked_text.words
|
||||
num_replacements_per_word = np.zeros(len(words))
|
||||
transformed_texts = self.get_transformations(
|
||||
initial_result.attacked_text, original_text=initial_result.attacked_text
|
||||
)
|
||||
for transformed_text in transformed_texts:
|
||||
diff_idx = next(
|
||||
iter(transformed_text.attack_attrs["newly_modified_indices"])
|
||||
)
|
||||
num_replacements_per_word[diff_idx] += 1
|
||||
|
||||
# Just b/c there are no replacements now doesn't mean we never want to select the word for perturbation
|
||||
# Therefore, we give small non-zero probability for words with no replacements
|
||||
# Epsilon is some small number to approximately assign small probability
|
||||
min_num_candidates = np.amin(num_replacements_per_word)
|
||||
epsilon = max(1, int(min_num_candidates * 0.1))
|
||||
for i in range(len(num_replacements_per_word)):
|
||||
num_replacements_per_word[i] = max(num_replacements_per_word[i], epsilon)
|
||||
|
||||
population = []
|
||||
for _ in range(pop_size):
|
||||
pop_member = PopulationMember(
|
||||
initial_result.attacked_text,
|
||||
initial_result,
|
||||
num_replacements_per_word=np.copy(num_replacements_per_word),
|
||||
)
|
||||
# Perturb `pop_member` in-place
|
||||
self._perturb(pop_member, initial_result)
|
||||
population.append(pop_member)
|
||||
|
||||
return population
|
||||
raise NotImplementedError()
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
self._search_over = False
|
||||
population = self._initialize_population(initial_result, self.pop_size)
|
||||
pop_size = len(population)
|
||||
current_score = initial_result.score
|
||||
|
||||
for i in range(self.max_iters):
|
||||
population = sorted(population, key=lambda x: x.result.score, reverse=True)
|
||||
|
||||
if (
|
||||
self._search_over
|
||||
or population[0].result.goal_status
|
||||
@@ -247,24 +229,20 @@ class GeneticAlgorithm(PopulationBasedSearch):
|
||||
logits = ((-pop_scores) / self.temp).exp()
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
parent1_idx = np.random.choice(
|
||||
self.pop_size, size=self.pop_size - 1, p=select_probs
|
||||
)
|
||||
parent2_idx = np.random.choice(
|
||||
self.pop_size, size=self.pop_size - 1, p=select_probs
|
||||
)
|
||||
parent1_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)
|
||||
parent2_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)
|
||||
|
||||
children = []
|
||||
for idx in range(self.pop_size - 1):
|
||||
for idx in range(pop_size - 1):
|
||||
child = self._crossover(
|
||||
population[parent1_idx[idx]],
|
||||
population[parent2_idx[idx]],
|
||||
initial_result.attacked_text,
|
||||
initial_result,
|
||||
)
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
self._perturb(child, initial_result)
|
||||
child = self._perturb(child, initial_result)
|
||||
children.append(child)
|
||||
|
||||
# We need two `search_over` checks b/c value might change both in
|
||||
|
||||
Reference in New Issue
Block a user