mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
discussed fixes
This commit is contained in:
@@ -26,6 +26,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.pop_size = pop_size
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
self.search_over = False
|
||||
|
||||
def _replace_at_index(self, pop_member, idx):
|
||||
"""
|
||||
@@ -37,18 +38,20 @@ class GeneticAlgorithm(SearchMethod):
|
||||
idx: The index at which to replace a word.
|
||||
|
||||
Returns:
|
||||
Whether a replacement which decreased the score was found.
|
||||
Whether a replacement which increased the score was found.
|
||||
"""
|
||||
transformations = self.get_transformations(pop_member.tokenized_text,
|
||||
original_text=self.original_tokenized_text,
|
||||
indices_to_modify=[idx])
|
||||
if not len(transformations):
|
||||
return False
|
||||
new_x_results = self.get_goal_results(transformations, self.correct_output)
|
||||
orig_result, self.search_over = self.get_goal_results([pop_member.tokenized_text], self.correct_output)
|
||||
if self.search_over:
|
||||
return False
|
||||
new_x_results, self.search_over = self.get_goal_results(transformations, self.correct_output)
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
orig_score = self.get_goal_results([pop_member.tokenized_text], self.correct_output)[0].score
|
||||
new_x_scores = new_x_scores - orig_score
|
||||
if new_x_scores.max() > 0:
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
if len(new_x_scores) and new_x_scores.max() > 0:
|
||||
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
|
||||
return True
|
||||
return False
|
||||
@@ -65,7 +68,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
if non_zero_indices == 0:
|
||||
return
|
||||
iterations = 0
|
||||
while iterations < non_zero_indices:
|
||||
while iterations < non_zero_indices and not self.search_over:
|
||||
w_select_probs = neighbors_len / np.sum(neighbors_len)
|
||||
rand_idx = np.random.choice(x_len, 1, p=w_select_probs)[0]
|
||||
if self._replace_at_index(pop_member, rand_idx):
|
||||
@@ -137,10 +140,11 @@ class GeneticAlgorithm(SearchMethod):
|
||||
pop = self._generate_population(neighbors_len)
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
||||
if not len(pop_results):
|
||||
# Over query budget
|
||||
return pop[0].result
|
||||
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
return pop[0].result
|
||||
return max(pop_results, key=lambda x: x.score)
|
||||
for idx, result in enumerate(pop_results):
|
||||
pop[idx].result = pop_results[idx]
|
||||
pop = sorted(pop, key=lambda x: -x.result.score)
|
||||
|
||||
Reference in New Issue
Block a user