1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

split variable names for IGA and GA

This commit is contained in:
Jin Yong Yoo
2020-07-31 09:34:57 -04:00
parent 4ba7fac789
commit 97b4aa2886
4 changed files with 77 additions and 59 deletions

View File

@@ -42,43 +42,44 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
post_crossover_check=post_crossover_check,
max_crossover_retries=max_crossover_retries,
)
self._attr_name = "num_candidate_transformations"
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
    """Modify `pop_member` by returning a new copy with `new_text`,
    `new_result`, and `num_candidate_transformations` altered appropriately for
    given `word_idx`.

    The entry at `word_idx` is set to 0 in a copy of
    `pop_member.num_candidate_transformations`; `pop_member` itself is not
    mutated.

    Args:
        pop_member (PopulationMember): Member whose per-word counts are copied.
        new_text: Text stored on the returned member.
        new_result: Result stored on the returned member.
        word_idx (int): Index of the word whose count is reset to 0.
    Returns:
        A new `PopulationMember`.
    """
    # Work on a copy so the parent member keeps its own counts.
    num_candidate_transformations = np.copy(
        pop_member.num_candidate_transformations
    )
    num_candidate_transformations[word_idx] = 0
    return PopulationMember(
        new_text,
        result=new_result,
        num_candidate_transformations=num_candidate_transformations,
    )
def _crossover_operation(self, pop_member1, pop_member2):
"""Actual operation for generating crossover between pop_member1 and
pop_member2.
"""Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
to generate crossover between `pop_member1` and `pop_member2`.
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_candidate_transformations`.
"""
indices_to_replace = []
words_to_replace = []
num_replacements_per_word = np.copy(pop_member1.num_replacements_per_word)
num_candidate_transformations = np.copy(pop_member1.num_candidate_transformations)
for i in range(pop_member1.num_words):
if np.random.uniform() < 0.5:
indices_to_replace.append(i)
words_to_replace.append(pop_member2.words[i])
num_replacements_per_word[i] = pop_member2.num_replacements_per_word[i]
num_candidate_transformations[i] = pop_member2.num_candidate_transformations[i]
new_text = pop_member1.attacked_text.replace_words_at_indices(
indices_to_replace, words_to_replace
)
return new_text, num_replacements_per_word
return new_text, num_candidate_transformations
def _initialize_population(self, initial_result, pop_size):
"""
@@ -90,7 +91,7 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
population as `list[PopulationMember]`
"""
words = initial_result.attacked_text.words
num_replacements_per_word = np.zeros(len(words))
num_candidate_transformations = np.zeros(len(words))
transformed_texts = self.get_transformations(
initial_result.attacked_text, original_text=initial_result.attacked_text
)
@@ -98,22 +99,22 @@ class AlzantotGeneticAlgorithm(GeneticAlgorithm):
diff_idx = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
num_replacements_per_word[diff_idx] += 1
num_candidate_transformations[diff_idx] += 1
# Just b/c there are no replacements now doesn't mean we never want to select the word for perturbation
# Therefore, we give small non-zero probability for words with no replacements
# Epsilon is some small number to approximately assign small probability
min_num_candidates = np.amin(num_replacements_per_word)
min_num_candidates = np.amin(num_candidate_transformations)
epsilon = max(1, int(min_num_candidates * 0.1))
for i in range(len(num_replacements_per_word)):
num_replacements_per_word[i] = max(num_replacements_per_word[i], epsilon)
for i in range(len(num_candidate_transformations)):
num_candidate_transformations[i] = max(num_candidate_transformations[i], epsilon)
population = []
for _ in range(pop_size):
pop_member = PopulationMember(
initial_result.attacked_text,
initial_result,
num_replacements_per_word=np.copy(num_replacements_per_word),
num_candidate_transformations=np.copy(num_candidate_transformations),
)
# Perturb `pop_member` in-place
pop_member = self._perturb(pop_member, initial_result)

View File

@@ -46,7 +46,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
@abstractmethod
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
"""Modify `pop_member` by returning a new copy with `new_text`,
`new_result`, and `num_replacements_per_word` altered appropriately for
`new_result`, and `word_attr_list` altered appropriately for
given `word_idx`"""
raise NotImplementedError()
@@ -61,9 +61,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
Returns:
Perturbed `PopulationMember`
"""
num_words = pop_member.num_replacements_per_word.shape[0]
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
non_zero_indices = np.count_nonzero(num_replacements_per_word)
num_words = pop_member.attacked_text.num_words
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
non_zero_indices = np.count_nonzero(word_attr_list)
if non_zero_indices == 0:
return pop_member
iterations = 0
@@ -71,8 +72,8 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
if index:
idx = index
else:
w_select_probs = num_replacements_per_word / np.sum(
num_replacements_per_word
w_select_probs = word_attr_list / np.sum(
word_attr_list
)
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
@@ -104,22 +105,48 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
)
return pop_member
num_replacements_per_word[idx] = 0
word_attr_list[idx] = 0
iterations += 1
return pop_member
@abstractmethod
def _crossover_operation(self, pop_member1, pop_member2):
"""Actual operation for generating crossover between pop_member1 and
pop_member2.
"""
Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
to generate crossover between `pop_member1` and `pop_member2`.
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
"""
raise NotImplementedError()
def _post_crossover_check(self, new_text, parent_text1, parent_text2, original_text):
"""Check if `new_text` that has been produced by performing crossover between
`parent_text1` and `parent_text2` aligns with the constraints.
Args:
new_text (AttackedText): Text produced by crossover operation
parent_text1 (AttackedText): Parent text of `new_text`
parent_text2 (AttackedText): Second parent text of `new_text`
original_text (AttackedText): Original text
Returns:
`True` if `new_text` meets the constraints. If otherwise, return `False`.
"""
if "last_transformation" in new_text.attack_attrs:
previous_text = (
parent_text1
if "last_transformation" in parent_text1.attack_attrs
else parent_text2
)
passed_constraints = self._check_constraints(
new_text, previous_text, original_text=original_text
)
return passed_constraints
else:
# `new_text` has not been actually transformed, so return True
return True
def _crossover(self, pop_member1, pop_member2, original_text):
"""Generates a crossover between pop_member1 and pop_member2.
@@ -139,7 +166,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
num_tries = 0
passed_constraints = False
while num_tries < self.max_crossover_retries + 1:
new_text, num_replacements_per_word = self._crossover_operation(
new_text, word_attr_list = self._crossover_operation(
pop_member1, pop_member2
)
@@ -157,24 +184,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
"last_transformation"
]
if not self.post_crossover_check or (
new_text.text == x1_text.text or new_text.text == x2_text.text
):
break
if self.post_crossover_check:
passed_constraints = self._post_crossover_check(new_text, x1_text, x2_text, original_text)
if "last_transformation" in new_text.attack_attrs:
previous_text = (
x1_text
if "last_transformation" in x1_text.attack_attrs
else x2_text
)
passed_constraints = self._check_constraints(
new_text, previous_text, original_text=original_text
)
else:
passed_constraints = True
if passed_constraints:
if not self.post_crossover_check or passed_constraints:
break
num_tries += 1
@@ -189,7 +202,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
return PopulationMember(
new_text,
result=new_results[0],
num_replacements_per_word=num_replacements_per_word,
**{self._attr_name: word_attr_list},
)
@abstractmethod

View File

@@ -45,32 +45,33 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
)
self.max_replace_times_per_index = max_replace_times_per_index
self._attr_name = "num_replacements_left"
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
    """Modify `pop_member` by returning a new copy with `new_text`,
    `new_result`, and `num_replacements_left` altered appropriately for
    given `word_idx`.

    The entry at `word_idx` is decremented by one in a copy of
    `pop_member.num_replacements_left`; `pop_member` itself is not mutated.

    Args:
        pop_member (PopulationMember): Member whose per-word counts are copied.
        new_text: Text stored on the returned member.
        new_result: Result stored on the returned member.
        word_idx (int): Index of the word that has just been replaced.
    Returns:
        A new `PopulationMember`.
    """
    # Work on a copy so the parent member keeps its own counts.
    num_replacements_left = np.copy(pop_member.num_replacements_left)
    num_replacements_left[word_idx] -= 1
    return PopulationMember(
        new_text,
        result=new_result,
        num_replacements_left=num_replacements_left,
    )
def _crossover_operation(self, pop_member1, pop_member2):
"""Actual operation for generating crossover between pop_member1 and
pop_member2.
"""Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
to generate crossover between `pop_member1` and `pop_member2`.
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_left`.
"""
indices_to_replace = []
words_to_replace = []
num_replacements_per_word = np.copy(pop_member1.num_replacements_per_word)
num_replacements_left = np.copy(pop_member1.num_replacements_left)
# To better simulate the reproduction and biological crossover,
# IGA randomly cut the text from two parents and concat two fragments into a new text
@@ -79,12 +80,12 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
for i in range(crossover_point, pop_member1.num_words):
indices_to_replace.append(i)
words_to_replace.append(pop_member2.words[i])
num_replacements_per_word[i] = pop_member2.num_replacements_per_word[i]
num_replacements_left[i] = pop_member2.num_replacements_left[i]
new_text = pop_member1.attacked_text.replace_words_at_indices(
indices_to_replace, words_to_replace
)
return new_text, num_replacements_per_word
return new_text, num_replacements_left
def _initialize_population(self, initial_result, pop_size):
"""
@@ -96,8 +97,8 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
population as `list[PopulationMember]`
"""
words = initial_result.attacked_text.words
# For IGA, `num_replacements_per_word` represents the number of times the word at each index can be modified
num_replacements_per_word = np.array(
# For IGA, `num_replacements_left` represents the number of times the word at each index can be modified
num_replacements_left = np.array(
[self.max_replace_times_per_index] * len(words)
)
population = []
@@ -107,7 +108,7 @@ class ImprovedGeneticAlgorithm(GeneticAlgorithm):
pop_member = PopulationMember(
initial_result.attacked_text,
initial_result,
num_replacements_per_word=np.copy(num_replacements_per_word),
num_replacements_left=np.copy(num_replacements_left),
)
pop_member = self._perturb(pop_member, initial_result, index=idx)
population.append(pop_member)

View File

@@ -344,7 +344,10 @@ class AttackedText:
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)
def words_diff_ratio(self, x):
    """Get the ratio of words difference between current text and `x`.

    Note that current text and `x` must have same number of words.

    Words are compared position by position, so the result is the fraction
    of indices at which the two texts hold different words.

    Args:
        x: Text to compare against; must expose `words` and `num_words`.
    Returns:
        float: Number of differing positions divided by `num_words`, or
        0.0 when both texts have no words.
    """
    assert self.num_words == x.num_words
    if self.num_words == 0:
        # Two empty texts have no differing words; avoid dividing by zero.
        return 0.0
    # `self.words != x.words` is element-wise only for NumPy arrays; for plain
    # lists it collapses to a single bool. Pairing the words explicitly gives
    # the per-position count for either container type.
    num_different = sum(
        1 for own_word, other_word in zip(self.words, x.words) if own_word != other_word
    )
    return float(num_different) / self.num_words
@property