mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
low prob assignment for zero transformation cases
This commit is contained in:
@@ -25,9 +25,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
Higher temperature increases the sensitivity to lower probability candidates.
|
||||
give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
|
||||
max_crossover_retries (int): Maximum number of crossover retries if resulting child fails to pass the constraints.
|
||||
Setting it to 0 means we immediately take one of the parents at random as the child.
|
||||
compare_againt_original (bool): If True, the reference text for constraints is the original text.
|
||||
Else, the reference text is the most recent text from which the new text is generated.
|
||||
Setting it to 0 means we immediately take one of the parents at random as the child.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -44,7 +42,6 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.max_crossover_retries = max_crossover_retries
|
||||
self.compare_against_original = compare_against_original
|
||||
|
||||
# internal flag to indicate if search should end immediately
|
||||
self._search_over = False
|
||||
@@ -60,7 +57,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
num_words = pop_member.num_candidates_per_word.shape[0]
|
||||
num_candidates_per_word = np.copy(pop_member.num_candidates_per_word)
|
||||
non_zero_indices = np.sum(np.sign(pop_member.num_candidates_per_word))
|
||||
non_zero_indices = np.count_nonzero(num_candidates_per_word)
|
||||
if non_zero_indices == 0:
|
||||
return
|
||||
iterations = 0
|
||||
@@ -68,16 +65,11 @@ class GeneticAlgorithm(SearchMethod):
|
||||
w_select_probs = num_candidates_per_word / np.sum(num_candidates_per_word)
|
||||
rand_idx = np.random.choice(num_words, 1, p=w_select_probs)[0]
|
||||
|
||||
if self.compare_against_original:
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[rand_idx],
|
||||
)
|
||||
else:
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text, indices_to_modify=[rand_idx],
|
||||
)
|
||||
transformations = self.get_transformations(
|
||||
pop_member.attacked_text,
|
||||
original_text=original_result.attacked_text,
|
||||
indices_to_modify=[rand_idx]
|
||||
)
|
||||
|
||||
if not len(transformations):
|
||||
iterations += 1
|
||||
@@ -179,12 +171,15 @@ class GeneticAlgorithm(SearchMethod):
|
||||
)
|
||||
num_candidates_per_word[diff_idx] += 1
|
||||
|
||||
total_candidates = np.sum(num_candidates_per_word)
|
||||
# Just b/c there are no candidates now doesn't mean we never want to select the word for perturbation
|
||||
# Therefore, we give small non-zero probability for words with no candidates
|
||||
# Epsilon is some small number to approximately assign 1% probability
|
||||
total_candidates = np.sum(num_candidates_per_word)
|
||||
num_zero_elements = len(words) - np.count_zero(num_candidates_per_word)
|
||||
epsilon = min(1, int(total_candidates / (100 - num_zero_elements)))
|
||||
for i in range(len(num_candidates_per_word)):
|
||||
if num_candidates_per_word[i] == 0:
|
||||
num_candidates_per_word[i] = int(total_candidates * 0.01)
|
||||
num_candidates_per_word[i] = epsilon
|
||||
|
||||
population = []
|
||||
for _ in range(self.pop_size):
|
||||
|
||||
Reference in New Issue
Block a user