1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

refactor PSO

This commit is contained in:
Jin Yong Yoo
2020-07-13 13:14:59 -04:00
parent 39e4a8e2c3
commit c351f537cd
3 changed files with 109 additions and 71 deletions

View File

@@ -51,6 +51,6 @@ def PSOZang2020(model):
#
# Perform word substitution with a Particle Swarm Optimization (PSO) algorithm.
#
search_method = ParticleSwarmOptimization(pop_size=60, max_iters=20)
search_method = ParticleSwarmOptimization(pop_size=60, max_iters=20, post_turn_check=False)
return Attack(goal_function, constraints, transformation, search_method)

View File

@@ -190,11 +190,7 @@ class GeneticAlgorithm(SearchMethod):
if self.post_crossover_check and not passed_constraints:
# If we cannot find a child that passes the constraints,
# we just randomly pick one of the parents to be the child for the next iteration.
pop_mem= (
pop_member1
if np.random.uniform() < 0.5
else pop_member2
)
pop_mem = pop_member1 if np.random.uniform() < 0.5 else pop_member2
return pop_mem
else:
new_results, self._search_over = self.get_goal_results([new_text])

View File

@@ -48,7 +48,6 @@ class ParticleSwarmOptimization(SearchMethod):
self.C2_origin = 0.2
self.V_max = 3.0
def _norm(self, n):
n = [max(0, i) for i in n]
s = sum(n)
@@ -77,34 +76,40 @@ class ParticleSwarmOptimization(SearchMethod):
Returns:
New `Position` that we moved to (or if we fail to move, same as `source_x`)
"""
assert len(source_x.words) == len(target_x.words), "Word length mismatch for turn operation."
assert len(source_x.words) == len(prob), "Length mistmatch for words and probability list."
assert len(source_x.words) == len(
target_x.words
), "Word length mismatch for turn operation."
assert len(source_x.words) == len(
prob
), "Length mistmatch for words and probability list."
len_x = len(source_x.words)
num_tries = 0
passed_constraints = False
while num_tries < self.max_crossover_retries + 1:
while num_tries < self.max_turn_retries + 1:
indices_to_replace = []
words_to_replace = []
for i in range(len_x):
if np.random.uniform() < prob[i]:
indices_to_replace.append(i)
words_to_replace.append(target_x.words[i])
new_text = source_x.attacked_text.replace_words_at_indices(indices_to_replace, words_to_replace)
new_text = source_x.attacked_text.replace_words_at_indices(
indices_to_replace, words_to_replace
)
if not self.post_turn_check or (new_text.words == source_x.words):
break
if "last_transformation" in source_x.attacked_text.attack_attrs:
new_text.attack_attrs["last_transformation"] = source_x.attacked_text.attack_attrs[
new_text.attack_attrs[
"last_transformation"
]
] = source_x.attacked_text.attack_attrs["last_transformation"]
filtered = self.filter_transformations(
[new_text], source_x.attacked_text, original_text=original_text
)
else:
# In this case, source_x has not been transformed,
# meaning that new_text = source_x = original_text
# meaning that new_text = source_x = original_text
filtered = [new_text]
if filtered:
@@ -120,13 +125,13 @@ class ParticleSwarmOptimization(SearchMethod):
else:
return Position(new_text)
def _get_best_neighbors(self, current_position, neighbors_list):
def _get_best_neighbors(self, neighbors_list, current_position):
"""
For given `current_position`, find the neighboring position that yields
For given `current_position`, find the neighboring position that yields
maximum improvement (in goal function score) for each word.
Args:
current_text (AttackedText): Current position
neighbors_list (list[list[AttackedText]]): List of "neighboring" AttackedText for each word in `current_text`.
neighbors_list (list[list[AttackedText]]): List of "neighboring" AttackedText for each word in `current_text`.
current_position (Position): Current position
Returns:
best_neighbors (list[Position]): Best neighboring positions for each word
prob_list (list[float]): discrete probablity distribution for sampling a neighbor from `best_neighbors`
@@ -135,20 +140,24 @@ class ParticleSwarmOptimization(SearchMethod):
score_list = []
for i in range(len(neighbors_list)):
if not neighbors_list[i]:
candidate_list.append(current_position)
best_neighbors.append(current_position)
score_list.append(0)
continue
neighbor_results, self._search_over = self.get_goal_results(neighbors_list[i])
neighbor_results, self._search_over = self.get_goal_results(
neighbors_list[i]
)
if neighbor_results:
# This is incase query budget forces len(neighbor_results) == 0
neighbor_scores = np.array([r.score for in neighbor_results])
score_diff = neighbor_scores - current_result.score
best_idx = np.argmax(neighbor_scores)[0]
best_neighbors.append(Position(neighbors_list[i][best_idx], neighbors_results[best_idx]))
neighbor_scores = np.array([r.score for r in neighbor_results])
score_diff = neighbor_scores - current_position.score
best_idx = np.argmax(neighbor_scores)
best_neighbors.append(
Position(neighbors_list[i][best_idx], neighbor_results[best_idx])
)
score_list.append(score_diff[best_idx])
if self._search_over
if self._search_over:
break
prob_list = self._norm(score_list)
@@ -159,33 +168,30 @@ class ParticleSwarmOptimization(SearchMethod):
"""
For each word in `current_text`, find list of available transformations.
Args:
current_Text (AttackedText)
current_text (AttackedText): Current text
original_text (AttackedText): Original text for constraint check
Returns:
`list[list[AttackedText]]` representing list of candidate neighbors for each word
"""
words = attacked_text.words
neighbors_list = [[] for _ in range(len(words))]
neighbors_list = [[] for _ in range(len(current_text.words))]
transformations = self.get_transformations(
current_text, original_text=original_text
)
for transformed_text in transformations:
try:
diff_idx = next(iter(transformed_text.attack_attrs["newly_modified_indices"]))
neighbors_list[diff_idx].append(transformed_text)
except:
assert len(attacked_text.words) == len(transformed_text.words)
assert all(
[
w1 == w2
for w1, w2 in zip(attacked_text.words, transformed_text.words)
]
)
diff_idx = next(
iter(transformed_text.attack_attrs["newly_modified_indices"])
)
neighbors_list[diff_idx].append(transformed_text)
return neighbors_list
def _mutate(self, current_position, original_text):
neighbors_list = self._get_neighbors_list(current_position.attacked_text, original_text)
candidate_list, prob_list = self._get_best_neighbors(neighbors_list, original_text)
neighbors_list = self._get_neighbors_list(
current_position.attacked_text, original_text
)
candidate_list, prob_list = self._get_best_neighbors(
neighbors_list, current_position
)
if self._search_over:
return current_position
random_candidate = np.random.choice(candidate_list, 1, p=prob_list)[0]
@@ -199,8 +205,12 @@ class ParticleSwarmOptimization(SearchMethod):
Returns:
`list[Position]` representing population
"""
neighbors_list = self._get_neighbors_list(initial_position.attacked_text, initial_position.attacked_text)
best_neighbors, prob_list = self._get_best_neighbors(neighbors_list, initial_position)
neighbors_list = self._get_neighbors_list(
initial_position.attacked_text, initial_position.attacked_text
)
best_neighbors, prob_list = self._get_best_neighbors(
neighbors_list, initial_position
)
population = []
for _ in range(self.pop_size):
# Mutation step
@@ -210,18 +220,24 @@ class ParticleSwarmOptimization(SearchMethod):
def _perform_search(self, initial_result):
self._search_over = False
original_position = Position(initial_result)
original_position = Position(initial_result.attacked_text, initial_result)
# get word substitute candidates and generate population
population = self._generate_population(original_position)
global_elite = max(population, key=lambda x: x.score)
if self._search_over or global_elite.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return global_elite
if (
self._search_over
or global_elite.result.goal_status == GoalFunctionResultStatus.SUCCEEDED
):
return global_elite.result
# rank the scores from low to high and check if there is a successful attack
local_elites = deepcopy(population)
# set up hyper-parameters
V = np.random.uniform(-self.V_max, self.V_max, self.pop_size)
V_P = [[V[t] for _ in range(x_len)] for t in range(self.pop_size)]
V_P = [
[V[t] for _ in range(len(initial_result.attacked_text.words))]
for t in range(self.pop_size)
]
# start iterations
for i in range(self.max_iters):
@@ -233,41 +249,59 @@ class ParticleSwarmOptimization(SearchMethod):
P1 = C1
P2 = C2
new_population = []
for k in range(self.pop_size):
for k in range(len(population)):
# calculate the probability of turning each word
particle_words = population[k].words
local_elite_words = local_elites[k].words
assert len(particle_words) == len(local_elite_words), "PSO word length mismatch!"
assert len(particle_words) == len(
local_elite_words
), "PSO word length mismatch!"
for dim in range(len(particle_words)):
V_P[k][dim] = omega * V_P[k][dim] + (1 - omega) * (
self._equal(particle_words[dim], local_elite_words[dim])
+ self._equal(particle_words[dim], local_elite_words[dim])
)
turn_prob = [self._sigmoid(V_P[k][d]) for d in range(len(particle_words))]
turn_prob = [
self._sigmoid(V_P[k][d]) for d in range(len(particle_words))
]
if np.random.uniform() < P1:
# Move towards local elite
population[k] = self._turn(local_elites[k], population[k], turn_prob)
population[k] = self._turn(
local_elites[k],
population[k],
turn_prob,
initial_result.attacked_text,
)
if np.random.uniform() < P2:
# Move towards global elite
population[k] = self._turn(global_elite, population[k], turn_prob)
population[k] = self._turn(
global_elite,
population[k],
turn_prob,
initial_result.attacked_text,
)
# Check if there is any successful attack in the current population
pop_results, self._search_over = self.get_goal_results([p.attacked_text for p in population])
for k in population:
pop_results, self._search_over = self.get_goal_results(
[p.attacked_text for p in population]
)
for k in range(len(population)):
population[k].result = pop_results[k]
top_result = max(population, key=lambda x: x.score).result
if self._search_over or top_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
if (
self._search_over
or top_result.goal_status == GoalFunctionResultStatus.SUCCEEDED
):
return top_result
# Mutation based on the current change rate
for k in range(self.pop_size):
for k in range(len(population)):
p = population[k]
change_ratio = self._count_change_ratio(p, initial_result.attacked_text)
# Referred from the original source code
p_change = 1 - 2 * change_ratio
p_change = 1 - 2 * change_ratio
if np.random.uniform() < p_change:
population[k] = self._mutate(p, initial_result.attacked_text)
@@ -275,24 +309,30 @@ class ParticleSwarmOptimization(SearchMethod):
break
if self._search_over:
return top_result
# Check if there is any successful attack in the current population
pop_results, self._search_over = self.get_goal_results([p.attacked_text for p in population])
for k in population:
pop_results, self._search_over = self.get_goal_results(
[p.attacked_text for p in population]
)
for k in range(len(population)):
population[k].result = pop_results[k]
top_result = max(population, key=lambda x: x.score).result
if self._search_over or top_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
if (
self._search_over
or top_result.goal_status == GoalFunctionResultStatus.SUCCEEDED
):
return top_result
# Update the elite if the score is increased
for k in range(self.pop_size):
for k in range(len(population)):
if population[k].score > local_elites[k].score:
local_elites[k] = deepcopy(population[k])
if top_result.score > global_elite.score:
global_elite = deepcopy(top_result)
return global_elite
if top_result.score > global_elite.score:
elite_result = deepcopy(top_result)
global_elite = Position(elite_result.attacked_text, elite_result)
return global_elite.result
def check_transformation_compatibility(self, transformation):
"""The genetic algorithm is specifically designed for word
@@ -302,6 +342,7 @@ class ParticleSwarmOptimization(SearchMethod):
def extra_repr_keys(self):
return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]
class Position:
"""
Helper class for particle-swarm optimization.
@@ -310,14 +351,15 @@ class Position:
attacked_text (:obj:`AttackedText`): `AttackedText` for the transformed text
result (:obs:`GoalFunctionResult`, optional): `GoalFunctionResult` for the transformed text
"""
def __init__(self, attacked_text, result=None):
self.attacked_text = attacked_text
self.result = result
@property
def score(self):
if not result:
raise ValueError("\"result\" attribute undefined for Position")
if not self.result:
raise ValueError('"result" attribute undefined for Position')
return self.result.score
@property