mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
add attributes to population member
This commit is contained in:
@@ -46,10 +46,16 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
@abstractmethod
|
||||
def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
|
||||
"""Modify `pop_member` by returning a new copy with `new_text`,
|
||||
`new_result`, and `word_attr_list` altered appropriately for given
|
||||
`new_result`, and, `attributes` altered appropriately for given
|
||||
`word_idx`"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def _get_word_select_prob_weights(self, pop_member):
|
||||
"""Get the attribute of `pop_member` that is used for determining
|
||||
probability of each word being selected for perturbation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _perturb(self, pop_member, original_result, index=None):
|
||||
"""Perturb `pop_member` and return it. Replaces a word at a random
|
||||
(unless `index` is specified) in `pop_member`.
|
||||
@@ -62,9 +68,11 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
Perturbed `PopulationMember`
|
||||
"""
|
||||
num_words = pop_member.attacked_text.num_words
|
||||
# `word_attr_list` contains some attribute of each word that helps us choose which word to transform.
|
||||
word_attr_list = np.copy(getattr(pop_member, self._attr_name))
|
||||
non_zero_indices = np.count_nonzero(word_attr_list)
|
||||
# `word_select_prob_weights` is a list of values used for sampling one word to transform
|
||||
word_select_prob_weights = np.copy(
|
||||
self._get_word_select_prob_weights(pop_member)
|
||||
)
|
||||
non_zero_indices = np.count_nonzero(word_select_prob_weights)
|
||||
if non_zero_indices == 0:
|
||||
return pop_member
|
||||
iterations = 0
|
||||
@@ -72,7 +80,9 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
if index:
|
||||
idx = index
|
||||
else:
|
||||
w_select_probs = word_attr_list / np.sum(word_attr_list)
|
||||
w_select_probs = word_select_prob_weights / np.sum(
|
||||
word_select_prob_weights
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
transformed_texts = self.get_transformations(
|
||||
@@ -103,7 +113,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
)
|
||||
return pop_member
|
||||
|
||||
word_attr_list[idx] = 0
|
||||
word_select_prob_weights[idx] = 0
|
||||
iterations += 1
|
||||
return pop_member
|
||||
|
||||
@@ -117,7 +127,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
Returns:
|
||||
Tuple of `AttackedText` and `np.array` for new text and its corresponding `word_attr_list`.
|
||||
Tuple of `AttackedText` and a dictionary of attributes.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -167,9 +177,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
num_tries = 0
|
||||
passed_constraints = False
|
||||
while num_tries < self.max_crossover_retries + 1:
|
||||
new_text, word_attr_list = self._crossover_operation(
|
||||
pop_member1, pop_member2
|
||||
)
|
||||
new_text, attributes = self._crossover_operation(pop_member1, pop_member2)
|
||||
|
||||
replaced_indices = new_text.attack_attrs["newly_modified_indices"]
|
||||
new_text.attack_attrs["modified_indices"] = (
|
||||
@@ -203,7 +211,7 @@ class GeneticAlgorithm(PopulationBasedSearch, ABC):
|
||||
else:
|
||||
new_results, self._search_over = self.get_goal_results([new_text])
|
||||
return PopulationMember(
|
||||
new_text, result=new_results[0], **{self._attr_name: word_attr_list},
|
||||
new_text, result=new_results[0], attributes=attributes
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
Reference in New Issue
Block a user