mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
refactor pso and transformation caching
This commit is contained in:
@@ -11,11 +11,11 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import PopulationBasedMethod, PopulationMember
|
||||
from textattack.search_methods import PopulationBasedSearch, PopulationMember
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
|
||||
class GeneticAlgorithm(PopulationBasedMethod):
|
||||
class GeneticAlgorithm(PopulationBasedSearch):
|
||||
"""Attacks a model with word substiutitions using a genetic algorithm.
|
||||
|
||||
Args:
|
||||
@@ -57,7 +57,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
pop_member (PopulationMember): The population member being perturbed.
|
||||
original_result (GoalFunctionResult): Result of original sample being attacked
|
||||
Returns:
|
||||
`True` if perturbation occured. `False` if not.
|
||||
`True` if perturbation occurred. `False` if not.
|
||||
"""
|
||||
num_words = pop_member.num_replacements_per_word.shape[0]
|
||||
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
|
||||
@@ -71,17 +71,17 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
)
|
||||
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
transformations = self.get_transformations(
|
||||
transformed_texts = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
)
|
||||
|
||||
if not len(transformations):
|
||||
if not len(transformed_texts):
|
||||
iterations += 1
|
||||
continue
|
||||
|
||||
new_results, self._search_over = self.get_goal_results(transformations)
|
||||
new_results, self._search_over = self.get_goal_results(transformed_texts)
|
||||
|
||||
if self._search_over:
|
||||
break
|
||||
@@ -91,7 +91,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
)
|
||||
if len(diff_scores) and diff_scores.max() > 0:
|
||||
idx_with_max_score = diff_scores.argmax()
|
||||
pop_member.attacked_text = transformations[idx_with_max_score]
|
||||
pop_member.attacked_text = transformed_texts[idx_with_max_score]
|
||||
pop_member.results = new_results[idx_with_max_score]
|
||||
pop_member.num_replacements_per_word[idx] = 0
|
||||
return True
|
||||
@@ -100,7 +100,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
iterations += 1
|
||||
return False
|
||||
|
||||
def _crossover(self, pop_member1, pop_member2, original_result):
|
||||
def _crossover(self, pop_member1, pop_member2, original_text):
|
||||
"""Generates a crossover between pop_member1 and pop_member2.
|
||||
|
||||
If the child fails to satisfy the constraits, we re-try crossover for a fix number of times,
|
||||
@@ -108,6 +108,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
Args:
|
||||
pop_member1 (PopulationMember): The first population member.
|
||||
pop_member2 (PopulationMember): The second population member.
|
||||
original_text (AttackedText): Original text
|
||||
Returns:
|
||||
A population member containing the crossover.
|
||||
"""
|
||||
@@ -138,18 +139,28 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
x1_text.attack_attrs["modified_indices"] - indices_to_replace
|
||||
) | (x2_text.attack_attrs["modified_indices"] & indices_to_replace)
|
||||
|
||||
if "last_transformation" in x1_text.attack_attrs:
|
||||
new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
|
||||
"last_transformation"
|
||||
]
|
||||
elif "last_transformation" in x2_text.attack_attrs:
|
||||
new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
|
||||
"last_transformation"
|
||||
]
|
||||
|
||||
if not self.post_crossover_check or (
|
||||
new_text.text == x1_text.text or new_text.text == x2_text.text
|
||||
):
|
||||
break
|
||||
|
||||
if "last_transformation" in x1_text.attack_attrs:
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, x1_text, original_text=original_result.attacked_text
|
||||
if "last_transformation" in new_text.attack_attrs:
|
||||
previous_text = (
|
||||
x1_text
|
||||
if "last_transformation" in x1_text.attack_attrs
|
||||
else x2_text
|
||||
)
|
||||
elif "last_transformation" in x2_text.attack_attrs:
|
||||
passed_constraints = self._check_constraints(
|
||||
new_text, x2_text, original_text=original_result.attacked_text
|
||||
new_text, previous_text, original_text=original_text
|
||||
)
|
||||
else:
|
||||
passed_constraints = True
|
||||
@@ -183,10 +194,10 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
"""
|
||||
words = initial_result.attacked_text.words
|
||||
num_replacements_per_word = np.zeros(len(words))
|
||||
transformations = self.get_transformations(
|
||||
transformed_texts = self.get_transformations(
|
||||
initial_result.attacked_text, original_text=initial_result.attacked_text
|
||||
)
|
||||
for transformed_text in transformations:
|
||||
for transformed_text in transformed_texts:
|
||||
diff_idx = next(
|
||||
iter(transformed_text.attack_attrs["newly_modified_indices"])
|
||||
)
|
||||
@@ -194,12 +205,11 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
|
||||
# Just b/c there are no replacements now doesn't mean we never want to select the word for perturbation
|
||||
# Therefore, we give small non-zero probability for words with no replacements
|
||||
# Epsilon is some small number to approximately assign 1% probability
|
||||
num_total_candidates = np.sum(num_replacements_per_word)
|
||||
epsilon = max(1, int(num_total_candidates * 0.01))
|
||||
# Epsilon is some small number to approximately assign small probability
|
||||
min_num_candidates = np.amin(num_replacements_per_word)
|
||||
epsilon = max(1, int(min_num_candidates * 0.1))
|
||||
for i in range(len(num_replacements_per_word)):
|
||||
if num_replacements_per_word[i] == 0:
|
||||
num_replacements_per_word[i] = epsilon
|
||||
num_replacements_per_word[i] = max(num_replacements_per_word[i], epsilon)
|
||||
|
||||
population = []
|
||||
for _ in range(pop_size):
|
||||
@@ -249,7 +259,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
child = self._crossover(
|
||||
population[parent1_idx[idx]],
|
||||
population[parent2_idx[idx]],
|
||||
initial_result,
|
||||
initial_result.attacked_text,
|
||||
)
|
||||
if self._search_over:
|
||||
break
|
||||
@@ -261,8 +271,6 @@ class GeneticAlgorithm(PopulationBasedMethod):
|
||||
# `crossover` method and `perturb` method.
|
||||
if self._search_over:
|
||||
break
|
||||
if self._search_over:
|
||||
break
|
||||
|
||||
population = [population[0]] + children
|
||||
|
||||
|
||||
Reference in New Issue
Block a user