1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

discussed fixes

This commit is contained in:
uvafan
2020-06-05 18:22:48 -04:00
parent d24bd67c0b
commit 298518afca
8 changed files with 49 additions and 48 deletions

View File

@@ -26,6 +26,7 @@ class GeneticAlgorithm(SearchMethod):
self.pop_size = pop_size
self.temp = temp
self.give_up_if_no_improvement = give_up_if_no_improvement
self.search_over = False
def _replace_at_index(self, pop_member, idx):
"""
@@ -37,18 +38,20 @@ class GeneticAlgorithm(SearchMethod):
idx: The index at which to replace a word.
Returns:
Whether a replacement which decreased the score was found.
Whether a replacement which increased the score was found.
"""
transformations = self.get_transformations(pop_member.tokenized_text,
original_text=self.original_tokenized_text,
indices_to_modify=[idx])
if not len(transformations):
return False
new_x_results = self.get_goal_results(transformations, self.correct_output)
orig_result, self.search_over = self.get_goal_results([pop_member.tokenized_text], self.correct_output)
if self.search_over:
return False
new_x_results, self.search_over = self.get_goal_results(transformations, self.correct_output)
new_x_scores = torch.Tensor([r.score for r in new_x_results])
orig_score = self.get_goal_results([pop_member.tokenized_text], self.correct_output)[0].score
new_x_scores = new_x_scores - orig_score
if new_x_scores.max() > 0:
new_x_scores = new_x_scores - orig_result[0].score
if len(new_x_scores) and new_x_scores.max() > 0:
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
return True
return False
@@ -65,7 +68,7 @@ class GeneticAlgorithm(SearchMethod):
if non_zero_indices == 0:
return
iterations = 0
while iterations < non_zero_indices:
while iterations < non_zero_indices and not self.search_over:
w_select_probs = neighbors_len / np.sum(neighbors_len)
rand_idx = np.random.choice(x_len, 1, p=w_select_probs)[0]
if self._replace_at_index(pop_member, rand_idx):
@@ -137,10 +140,11 @@ class GeneticAlgorithm(SearchMethod):
pop = self._generate_population(neighbors_len)
cur_score = initial_result.score
for i in range(self.max_iters):
pop_results = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
if not len(pop_results):
# Over query budget
return pop[0].result
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
if self.search_over:
if not len(pop_results):
return pop[0].result
return max(pop_results, key=lambda x: x.score)
for idx, result in enumerate(pop_results):
pop[idx].result = pop_results[idx]
pop = sorted(pop, key=lambda x: -x.result.score)