mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
split variable names for IGA and GA
This commit is contained in:
@@ -42,43 +42,44 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
|
||||
post_crossover_check=post_crossover_check,
|
||||
max_crossover_retries=max_crossover_retries,
|
||||
)
|
||||
self._attr_name = "num_candidate_transformations"
|
||||
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `num_replacements_per_word` altered appropriately for
|
||||
`new_result`, and `num_candidate_transformations` altered appropriately for
|
||||
given `word_idx`"""
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
num_replacements_per_word[word_idx] = 0
|
||||
num_candidate_transformations = np.copy(pop_member.num_candidate_transformations)
|
||||
num_candidate_transformations[word_idx] = 0
|
||||
return PopulationMember(
|
||||
new_text,
|
||||
result=new_result,
|
||||
num_replacements_per_word=num_replacements_per_word,
|
||||
num_candidate_transformations=num_candidate_transformations,
|
||||
)
|
||||
|
||||
def _crossover_operation(self, pop_member1, pop_member2):
|
||||
"""Actual operation for generating crossover between pop_member1 and
|
||||
pop_member2.
|
||||
"""Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
|
||||
to generate crossover between `pop_member1` and `pop_member2`.
|
||||
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_candidate_transformations`.
|
||||
"""
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
num_replacements_per_word = np.copy(pop_member1.num_replacements_per_word)
|
||||
num_candidate_transformations = np.copy(pop_member1.num_candidate_transformations)
|
||||
|
||||
for i in range(pop_member1.num_words):
|
||||
if np.random.uniform() < 0.5:
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(pop_member2.words[i])
|
||||
num_replacements_per_word[i] = pop_member2.num_replacements_per_word[i]
|
||||
num_candidate_transformations[i] = pop_member2.num_candidate_transformations[i]
|
||||
|
||||
new_text = pop_member1.attacked_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
)
|
||||
return new_text, num_replacements_per_word
|
||||
return new_text, num_candidate_transformations
|
||||
|
||||
def _initialize_population(self, initial_result, pop_size):
|
||||
"""
|
||||
@@ -90,7 +91,7 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
|
||||
population as `list[PopulationMember]`
|
||||
"""
|
||||
words = initial_result.attacked_text.words
|
||||
num_replacements_per_word = np.zeros(len(words))
|
||||
num_candidate_transformations = np.zeros(len(words))
|
||||
transformed_texts = self.get_transformations(
|
||||
initial_result.attacked_text, original_text=initial_result.attacked_text
|
||||
)
|
||||
@@ -98,22 +99,22 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
|
||||
diff_idx = next(
|
||||
iter(transformed_text.attack_attrs["newly_modified_indices"])
|
||||
)
|
||||
num_replacements_per_word[diff_idx] += 1
|
||||
num_candidate_transformations[diff_idx] += 1
|
||||
|
||||
# Just because a word has no replacements now doesn't mean we never want to select it for perturbation.
|
||||
# Therefore, we give a small non-zero probability to words with no replacements.
|
||||
# Epsilon is a small number used as the minimum count, so such words get approximately that small probability.
|
||||
min_num_candidates = np.amin(num_replacements_per_word)
|
||||
min_num_candidates = np.amin(num_candidate_transformations)
|
||||
epsilon = max(1, int(min_num_candidates * 0.1))
|
||||
for i in range(len(num_replacements_per_word)):
|
||||
num_replacements_per_word[i] = max(num_replacements_per_word[i], epsilon)
|
||||
for i in range(len(num_candidate_transformations)):
|
||||
num_candidate_transformations[i] = max(num_candidate_transformations[i], epsilon)
|
||||
|
||||
population = []
|
||||
for _ in range(pop_size):
|
||||
pop_member = PopulationMember(
|
||||
initial_result.attacked_text,
|
||||
initial_result,
|
||||
num_replacements_per_word=np.copy(num_replacements_per_word),
|
||||
num_candidate_transformations=np.copy(num_candidate_transformations),
|
||||
)
|
||||
# Perturb `pop_member` (the name is rebound to the member returned by `_perturb`)
|
||||
pop_member = self._perturb(pop_member, initial_result)
|
||||
|
||||
@@ -46,7 +46,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
@abstractmethod
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `num_replacements_per_word` altered appropriately for
|
||||
`new_result`, and `word_attr_list` altered appropriately for
|
||||
given `word_idx`"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -61,9 +61,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
Returns:
|
||||
Perturbed `PopulationMember`
|
||||
"""
|
||||
num_words = pop_member.num_replacements_per_word.shape[0]
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
non_zero_indices = np.count_nonzero(num_replacements_per_word)
|
||||
num_words = pop_member.attacked_text.num_words
|
||||
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
|
||||
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
|
||||
non_zero_indices = np.count_nonzero(word_attr_list)
|
||||
if non_zero_indices == 0:
|
||||
return pop_member
|
||||
iterations = 0
|
||||
@@ -71,8 +72,8 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
if index:
|
||||
idx = index
|
||||
else:
|
||||
w_select_probs = num_replacements_per_word / np.sum(
|
||||
num_replacements_per_word
|
||||
w_select_probs = word_attr_list / np.sum(
|
||||
word_attr_list
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
@@ -104,22 +105,48 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
)
|
||||
return pop_member
|
||||
|
||||
num_replacements_per_word[idx] = 0
|
||||
word_attr_list[idx] = 0
|
||||
iterations += 1
|
||||
return pop_member
|
||||
|
||||
@abstractmethod
|
||||
def _crossover_operation(self, pop_member1, pop_member2):
|
||||
"""Actual operation for generating crossover between pop_member1 and
|
||||
pop_member2.
|
||||
"""
|
||||
Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
|
||||
to generate crossover between `pop_member1` and `pop_member2`.
|
||||
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _post_crossover_check(self, new_text, parent_text1, parent_text2, original_text):
|
||||
"""Check if `new_text` that has been produced by performing crossover between
|
||||
`parent_text1` and `parent_text2` aligns with the constraints.
|
||||
Args:
|
||||
new_text (AttackedText): Text produced by crossover operation
|
||||
parent_text1 (AttackedText): Parent text of `new_text`
|
||||
parent_text2 (AttackedText): Second parent text of `new_text`
|
||||
original_text (AttackedText): Original text
|
||||
Returns:
|
||||
`True` if `new_text` meets the constraints, `False` otherwise.
|
||||
"""
|
||||
if "last_transformation" in new_text.attack_attrs:
|
||||
previous_text = (
|
||||
parent_text1
|
||||
if "last_transformation" in parent_text1.attack_attrs
|
||||
else parent_text2
|
||||
)
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, previous_text, original_text=original_text
|
||||
)
|
||||
return passed_constraints
|
||||
else:
|
||||
# `new_text` has not been actually transformed, so return True
|
||||
return True
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2, original_text):
|
||||
"""Generates a crossover between pop_member1 and pop_member2.
|
||||
@@ -139,7 +166,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
while num_tries < self.max_crossover_retries + 1:
|
||||
new_text, num_replacements_per_word = self._crossover_operation(
|
||||
new_text, word_attr_list = self._crossover_operation(
|
||||
pop_member1, pop_member2
|
||||
)
|
||||
|
||||
@@ -157,24 +184,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
"last_transformation"
|
||||
]
|
||||
|
||||
if not self.post_crossover_check or (
|
||||
new_text.text == x1_text.text or new_text.text == x2_text.text
|
||||
):
|
||||
break
|
||||
if self.post_crossover_check:
|
||||
passed_constraints = self._post_crossover_check(new_text, x1_text, x2_text, original_text)
|
||||
|
||||
if "last_transformation" in new_text.attack_attrs:
|
||||
previous_text = (
|
||||
x1_text
|
||||
if "last_transformation" in x1_text.attack_attrs
|
||||
else x2_text
|
||||
)
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, previous_text, original_text=original_text
|
||||
)
|
||||
else:
|
||||
passed_constraints = True
|
||||
|
||||
if passed_constraints:
|
||||
if not self.post_crossover_check or passed_constraints:
|
||||
break
|
||||
|
||||
num_tries += 1
|
||||
@@ -189,7 +202,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
return PopulationMember(
|
||||
new_text,
|
||||
result=new_results[0],
|
||||
num_replacements_per_word=num_replacements_per_word,
|
||||
**{self._attr_name: word_attr_list},
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -45,32 +45,33 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
|
||||
)
|
||||
|
||||
self.max_replace_times_per_index = max_replace_times_per_index
|
||||
self._attr_name = "num_replacements_left"
|
||||
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `num_replacements_per_word` altered appropriately for
|
||||
`new_result`, and `num_replacements_left` altered appropriately for
|
||||
given `word_idx`"""
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
num_replacements_per_word[word_idx] -= 1
|
||||
num_replacements_left = np.copy(pop_member.num_replacements_left)
|
||||
num_replacements_left[word_idx] -= 1
|
||||
return PopulationMember(
|
||||
new_text,
|
||||
result=new_result,
|
||||
num_replacements_per_word=num_replacements_per_word,
|
||||
num_replacements_left=num_replacements_left,
|
||||
)
|
||||
|
||||
def _crossover_operation(self, pop_member1, pop_member2):
|
||||
"""Actual operation for generating crossover between pop_member1 and
|
||||
pop_member2.
|
||||
"""Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
|
||||
to generate crossover between `pop_member1` and `pop_member2`.
|
||||
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_left`.
|
||||
"""
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
num_replacements_per_word = np.copy(pop_member1.num_replacements_per_word)
|
||||
num_replacements_left = np.copy(pop_member1.num_replacements_left)
|
||||
|
||||
# To better simulate the reproduction and biological crossover,
|
||||
# IGA randomly cuts the text of the two parents and concatenates the two fragments into a new text
|
||||
@@ -79,12 +80,12 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
|
||||
for i in range(crossover_point, pop_member1.num_words):
|
||||
indices_to_replace.append(i)
|
||||
words_to_replace.append(pop_member2.words[i])
|
||||
num_replacements_per_word[i] = pop_member2.num_replacements_per_word[i]
|
||||
num_replacements_left[i] = pop_member2.num_replacements_left[i]
|
||||
|
||||
new_text = pop_member1.attacked_text.replace_words_at_indices(
|
||||
indices_to_replace, words_to_replace
|
||||
)
|
||||
return new_text, num_replacements_per_word
|
||||
return new_text, num_replacements_left
|
||||
|
||||
def _initialize_population(self, initial_result, pop_size):
|
||||
"""
|
||||
@@ -96,8 +97,8 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
|
||||
population as `list[PopulationMember]`
|
||||
"""
|
||||
words = initial_result.attacked_text.words
|
||||
# For IGA, `num_replacements_per_word` represents the number of times the word at each index can be modified
|
||||
num_replacements_per_word = np.array(
|
||||
# For IGA, `num_replacements_left` represents the number of times the word at each index can be modified
|
||||
num_replacements_left = np.array(
|
||||
[self.max_replace_times_per_index] * len(words)
|
||||
)
|
||||
population = []
|
||||
@@ -107,7 +108,7 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
|
||||
pop_member = PopulationMember(
|
||||
initial_result.attacked_text,
|
||||
initial_result,
|
||||
num_replacements_per_word=np.copy(num_replacements_per_word),
|
||||
num_replacements_left=np.copy(num_replacements_left),
|
||||
)
|
||||
pop_member = self._perturb(pop_member, initial_result, index=idx)
|
||||
population.append(pop_member)
|
||||
|
||||
@@ -344,7 +344,10 @@ class AttackedText:
|
||||
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
|
||||
|
||||
def words_diff_ratio(self, x):
|
||||
"""Get the ratio of words difference between current text and x."""
|
||||
"""Get the ratio of words that differ between the current text and `x`.
|
||||
Note that the current text and `x` must have the same number of words.
|
||||
"""
|
||||
assert self.num_words == x.num_words
|
||||
return float(np.sum(self.words != x.words)) / self.num_words
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user