1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

refactor pso and transformation caching

This commit is contained in:
Jin Yong Yoo
2020-07-25 08:57:52 -04:00
parent 613bbf0b88
commit 5a8d74d288
22 changed files with 166 additions and 116 deletions

View File

@@ -11,11 +11,11 @@ import numpy as np
import torch
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import PopulationBasedMethod, PopulationMember
from textattack.search_methods import PopulationBasedSearch, PopulationMember
from textattack.shared.validators import transformation_consists_of_word_swaps
class GeneticAlgorithm(PopulationBasedMethod):
class GeneticAlgorithm(PopulationBasedSearch):
"""Attacks a model with word substiutitions using a genetic algorithm.
Args:
@@ -57,7 +57,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
pop_member (PopulationMember): The population member being perturbed.
original_result (GoalFunctionResult): Result of original sample being attacked
Returns:
`True` if perturbation occured. `False` if not.
`True` if perturbation occurred. `False` if not.
"""
num_words = pop_member.num_replacements_per_word.shape[0]
num_replacements_per_word = np.copy(pop_member.num_replacements_per_word)
@@ -71,17 +71,17 @@ class GeneticAlgorithm(PopulationBasedMethod):
)
idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
transformations = self.get_transformations(
transformed_texts = self.get_transformations(
pop_member.attacked_text,
original_text=original_result.attacked_text,
indices_to_modify=[idx],
)
if not len(transformations):
if not len(transformed_texts):
iterations += 1
continue
new_results, self._search_over = self.get_goal_results(transformations)
new_results, self._search_over = self.get_goal_results(transformed_texts)
if self._search_over:
break
@@ -91,7 +91,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
)
if len(diff_scores) and diff_scores.max() > 0:
idx_with_max_score = diff_scores.argmax()
pop_member.attacked_text = transformations[idx_with_max_score]
pop_member.attacked_text = transformed_texts[idx_with_max_score]
pop_member.results = new_results[idx_with_max_score]
pop_member.num_replacements_per_word[idx] = 0
return True
@@ -100,7 +100,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
iterations += 1
return False
def _crossover(self, pop_member1, pop_member2, original_result):
def _crossover(self, pop_member1, pop_member2, original_text):
"""Generates a crossover between pop_member1 and pop_member2.
If the child fails to satisfy the constraints, we re-try crossover for a fixed number of times,
@@ -108,6 +108,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
Args:
pop_member1 (PopulationMember): The first population member.
pop_member2 (PopulationMember): The second population member.
original_text (AttackedText): Original text
Returns:
A population member containing the crossover.
"""
@@ -138,18 +139,28 @@ class GeneticAlgorithm(PopulationBasedMethod):
x1_text.attack_attrs["modified_indices"] - indices_to_replace
) | (x2_text.attack_attrs["modified_indices"] & indices_to_replace)
if "last_transformation" in x1_text.attack_attrs:
new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
"last_transformation"
]
elif "last_transformation" in x2_text.attack_attrs:
new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
"last_transformation"
]
if not self.post_crossover_check or (
new_text.text == x1_text.text or new_text.text == x2_text.text
):
break
if "last_transformation" in x1_text.attack_attrs:
passed_constraints = self._check_constraints(
new_text, x1_text, original_text=original_result.attacked_text
if "last_transformation" in new_text.attack_attrs:
previous_text = (
x1_text
if "last_transformation" in x1_text.attack_attrs
else x2_text
)
elif "last_transformation" in x2_text.attack_attrs:
passed_constraints = self._check_constraints(
new_text, x2_text, original_text=original_result.attacked_text
new_text, previous_text, original_text=original_text
)
else:
passed_constraints = True
@@ -183,10 +194,10 @@ class GeneticAlgorithm(PopulationBasedMethod):
"""
words = initial_result.attacked_text.words
num_replacements_per_word = np.zeros(len(words))
transformations = self.get_transformations(
transformed_texts = self.get_transformations(
initial_result.attacked_text, original_text=initial_result.attacked_text
)
for transformed_text in transformations:
for transformed_text in transformed_texts:
diff_idx = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
@@ -194,12 +205,11 @@ class GeneticAlgorithm(PopulationBasedMethod):
# Just b/c there are no replacements now doesn't mean we never want to select the word for perturbation
# Therefore, we give small non-zero probability for words with no replacements
# Epsilon is some small number to approximately assign 1% probability
num_total_candidates = np.sum(num_replacements_per_word)
epsilon = max(1, int(num_total_candidates * 0.01))
# Epsilon is some small number to approximately assign small probability
min_num_candidates = np.amin(num_replacements_per_word)
epsilon = max(1, int(min_num_candidates * 0.1))
for i in range(len(num_replacements_per_word)):
if num_replacements_per_word[i] == 0:
num_replacements_per_word[i] = epsilon
num_replacements_per_word[i] = max(num_replacements_per_word[i], epsilon)
population = []
for _ in range(pop_size):
@@ -249,7 +259,7 @@ class GeneticAlgorithm(PopulationBasedMethod):
child = self._crossover(
population[parent1_idx[idx]],
population[parent2_idx[idx]],
initial_result,
initial_result.attacked_text,
)
if self._search_over:
break
@@ -261,8 +271,6 @@ class GeneticAlgorithm(PopulationBasedMethod):
# `crossover` method and `perturb` method.
if self._search_over:
break
if self._search_over:
break
population = [population[0]] + children