1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

split variable names for IGA and GA

This commit is contained in:
Jin Yong Yoo
2020-07-31 09:34:57 -04:00
parent 4ba7fac789
commit 97b4aa2886
4 changed files with 77 additions and 59 deletions

View File

@@ -46,7 +46,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
@abstractmethod
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
"""Modify `pop_member` by returning a new copy with `new_text`,
`new_result`, and `num_replacements_per_word` altered appropriately for
`new_result`, and `word_attr_list` altered appropriately for
given `word_idx`"""
raise NotImplementedError()
@@ -61,9 +61,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
Returns:
Perturbed `PopulationMember`
"""
num_words = pop_member.num_replacements_per_word.shape[0]
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
non_zero_indices = np.count_nonzero(num_replacements_per_word)
num_words = pop_member.attacked_text.num_words
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
non_zero_indices = np.count_nonzero(word_attr_list)
if non_zero_indices == 0:
return pop_member
iterations = 0
@@ -71,8 +72,8 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
if index:
idx = index
else:
w_select_probs = num_replacements_per_word / np.sum(
num_replacements_per_word
w_select_probs = word_attr_list / np.sum(
word_attr_list
)
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
@@ -104,22 +105,48 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
)
return pop_member
num_replacements_per_word[idx] = 0
word_attr_list[idx] = 0
iterations += 1
return pop_member
@abstractmethod
def _crossover_operation(self, pop_member1, pop_member2):
"""Actual operation for generating crossover between pop_member1 and
pop_member2.
"""
Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
to generate crossover between `pop_member1` and `pop_member2`.
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
"""
raise NotImplementedError()
def _post_crossover_check(self, new_text, parent_text1, parent_text2, original_text):
"""Check if `new_text` that has been produced by performing crossover between
`parent_text1` and `parent_text2` aligns with the constraints.
Args:
new_text (AttackedText): Text produced by crossover operation
parent_text1 (AttackedText): Parent text of `new_text`
parent_text2 (AttackedText): Second parent text of `new_text`
original_text (AttackedText): Original text
Returns:
`True` if `new_text` meets the constraints. If otherwise, return `False`.
"""
if "last_transformation" in new_text.attack_attrs:
previous_text = (
parent_text1
if "last_transformation" in parent_text1.attack_attrs
else parent_text2
)
passed_constraints = self._check_constraints(
new_text, previous_text, original_text=original_text
)
return passed_constraints
else:
# `new_text` has not been actually transformed, so return True
return True
def _crossover(self, pop_member1, pop_member2, original_text):
"""Generates a crossover between pop_member1 and pop_member2.
@@ -139,7 +166,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
num_tries = 0
passed_constraints = False
while num_tries < self.max_crossover_retries + 1:
new_text, num_replacements_per_word = self._crossover_operation(
new_text, word_attr_list = self._crossover_operation(
pop_member1, pop_member2
)
@@ -157,24 +184,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
"last_transformation"
]
if not self.post_crossover_check or (
new_text.text == x1_text.text or new_text.text == x2_text.text
):
break
if self.post_crossover_check:
passed_constraints = self._post_crossover_check(new_text, x1_text, x2_text, original_text)
if "last_transformation" in new_text.attack_attrs:
previous_text = (
x1_text
if "last_transformation" in x1_text.attack_attrs
else x2_text
)
passed_constraints = self._check_constraints(
new_text, previous_text, original_text=original_text
)
else:
passed_constraints = True
if passed_constraints:
if not self.post_crossover_check or passed_constraints:
break
num_tries += 1
@@ -189,7 +202,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
return PopulationMember(
new_text,
result=new_results[0],
num_replacements_per_word=num_replacements_per_word,
**{self._attr_name: word_attr_list},
)
@abstractmethod