mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
split variable names for IGA and GA
This commit is contained in:
@@ -46,7 +46,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
@abstractmethod
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `num_replacements_per_word` altered appropriately for
|
||||
`new_result`, and `word_attr_list` altered appropriately for
|
||||
given `word_idx`"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -61,9 +61,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
Returns:
|
||||
Perturbed `PopulationMember`
|
||||
"""
|
||||
num_words = pop_member.num_replacements_per_word.shape[0]
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
non_zero_indices = np.count_nonzero(num_replacements_per_word)
|
||||
num_words = pop_member.attacked_text.num_words
|
||||
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
|
||||
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
|
||||
non_zero_indices = np.count_nonzero(word_attr_list)
|
||||
if non_zero_indices == 0:
|
||||
return pop_member
|
||||
iterations = 0
|
||||
@@ -71,8 +72,8 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
if index:
|
||||
idx = index
|
||||
else:
|
||||
w_select_probs = num_replacements_per_word / np.sum(
|
||||
num_replacements_per_word
|
||||
w_select_probs = word_attr_list / np.sum(
|
||||
word_attr_list
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
@@ -104,22 +105,48 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
)
|
||||
return pop_member
|
||||
|
||||
num_replacements_per_word[idx] = 0
|
||||
word_attr_list[idx] = 0
|
||||
iterations += 1
|
||||
return pop_member
|
||||
|
||||
@abstractmethod
|
||||
def _crossover_operation(self, pop_member1, pop_member2):
|
||||
"""Actual operation for generating crossover between pop_member1 and
|
||||
pop_member2.
|
||||
"""
|
||||
Actual operation that takes `pop_member1` text and `pop_member2` text and mixes the two
|
||||
to generate crossover between `pop_member1` and `pop_member2`.
|
||||
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `num_replacements_per_word`.
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _post_crossover_check(self, new_text, parent_text1, parent_text2, original_text):
|
||||
"""Check if `new_text` that has been produced by performing crossover between
|
||||
`parent_text1` and `parent_text2` aligns with the constraints.
|
||||
Args:
|
||||
new_text (AttackedText): Text produced by crossover operation
|
||||
parent_text1 (AttackedText): Parent text of `new_text`
|
||||
parent_text2 (AttackedText): Second parent text of `new_text`
|
||||
original_text (AttackedText): Original text
|
||||
Returns:
|
||||
`True` if `new_text` meets the constraints. If otherwise, return `False`.
|
||||
"""
|
||||
if "last_transformation" in new_text.attack_attrs:
|
||||
previous_text = (
|
||||
parent_text1
|
||||
if "last_transformation" in parent_text1.attack_attrs
|
||||
else parent_text2
|
||||
)
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, previous_text, original_text=original_text
|
||||
)
|
||||
return passed_constraints
|
||||
else:
|
||||
# `new_text` has not been actually transformed, so return True
|
||||
return True
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2, original_text):
|
||||
"""Generates a crossover between pop_member1 and pop_member2.
|
||||
@@ -139,7 +166,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
while num_tries < self.max_crossover_retries + 1:
|
||||
new_text, num_replacements_per_word = self._crossover_operation(
|
||||
new_text, word_attr_list = self._crossover_operation(
|
||||
pop_member1, pop_member2
|
||||
)
|
||||
|
||||
@@ -157,24 +184,10 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
"last_transformation"
|
||||
]
|
||||
|
||||
if not self.post_crossover_check or (
|
||||
new_text.text == x1_text.text or new_text.text == x2_text.text
|
||||
):
|
||||
break
|
||||
if self.post_crossover_check:
|
||||
passed_constraints = self._post_crossover_check(new_text, x1_text, x2_text, original_text)
|
||||
|
||||
if "last_transformation" in new_text.attack_attrs:
|
||||
previous_text = (
|
||||
x1_text
|
||||
if "last_transformation" in x1_text.attack_attrs
|
||||
else x2_text
|
||||
)
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, previous_text, original_text=original_text
|
||||
)
|
||||
else:
|
||||
passed_constraints = True
|
||||
|
||||
if passed_constraints:
|
||||
if not self.post_crossover_check or passed_constraints:
|
||||
break
|
||||
|
||||
num_tries += 1
|
||||
@@ -189,7 +202,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
return PopulationMember(
|
||||
new_text,
|
||||
result=new_results[0],
|
||||
num_replacements_per_word=num_replacements_per_word,
|
||||
**{self._attr_name: word_attr_list},
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
Reference in New Issue
Block a user