1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add attributes to population member

This commit is contained in:
Jin Yong Yoo
2020-08-21 18:01:47 -04:00
parent 31a69a839e
commit 70b035aeb3
5 changed files with 56 additions and 30 deletions

View File

@@ -46,10 +46,16 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
@abstractmethod
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
"""Modify `pop_member` by returning a new copy with `new_text`,
`new_result`, and `word_attr_list` altered appropriately for given
`new_result`, and, `attributes` altered appropriately for given
`word_idx`"""
raise NotImplementedError()
@abstractmethod
def _get_word_select_prob_weights(self, pop_member):
"""Get the attribute of `pop_member` that is used for determining
probability of each word being selected for perturbation."""
raise NotImplementedError
def _perturb(self, pop_member, original_result, index=None):
"""Perturb `pop_member` and return it. Replaces a word at a random
(unless `index` is specified) in `pop_member`.
@@ -62,9 +68,11 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
Perturbed `PopulationMember`
"""
num_words = pop_member.attacked_text.num_words
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
non_zero_indices = np.count_nonzero(word_attr_list)
# `word_select_prob_weights` is a list of values used for sampling one word to transform
word_select_prob_weights = np.copy(
self._get_word_select_prob_weights(pop_member)
)
non_zero_indices = np.count_nonzero(word_select_prob_weights)
if non_zero_indices == 0:
return pop_member
iterations = 0
@@ -72,7 +80,9 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
if index:
idx = index
else:
w_select_probs = word_attr_list / np.sum(word_attr_list)
w_select_probs = word_select_prob_weights / np.sum(
word_select_prob_weights
)
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
transformed_texts = self.get_transformations(
@@ -103,7 +113,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
)
return pop_member
word_attr_list[idx] = 0
word_select_prob_weights[idx] = 0
iterations += 1
return pop_member
@@ -117,7 +127,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
Returns:
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
Tuple of `AttackedText` and a dictionary of attributes.
"""
raise NotImplementedError()
@@ -167,9 +177,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
num_tries = 0
passed_constraints = False
while num_tries < self.max_crossover_retries + 1:
new_text, word_attr_list = self._crossover_operation(
pop_member1, pop_member2
)
new_text, attributes = self._crossover_operation(pop_member1, pop_member2)
replaced_indices = new_text.attack_attrs["newly_modified_indices"]
new_text.attack_attrs["modified_indices"] = (
@@ -203,7 +211,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
else:
new_results, self._search_over = self.get_goal_results([new_text])
return PopulationMember(
new_text, result=new_results[0], **{self._attr_name: word_attr_list},
new_text, result=new_results[0], attributes=attributes
)
@abstractmethod