1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

allow maximization goal functions

This commit is contained in:
uvafan
2020-06-23 23:33:48 -04:00
parent fe109267a1
commit 0fcfb51b7f
19 changed files with 115 additions and 78 deletions

View File

@@ -10,6 +10,7 @@ from copy import deepcopy
import numpy as np
import torch
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import transformation_consists_of_word_swaps
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
if not len(transformations):
return False
orig_result, self.search_over = self.get_goal_results(
[pop_member.attacked_text], self.correct_output
[pop_member.attacked_text]
)
if self.search_over:
return False
new_x_results, self.search_over = self.get_goal_results(
transformations, self.correct_output
transformations
)
new_x_scores = torch.Tensor([r.score for r in new_x_results])
new_x_scores = new_x_scores - orig_result[0].score
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
cur_score = initial_result.score
for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results(
[pm.attacked_text for pm in pop], self.correct_output
[pm.attacked_text for pm in pop]
)
if self.search_over:
if not len(pop_results):
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
logits = ((-pop_scores) / self.temp).exp()
select_probs = (logits / logits.sum()).cpu().numpy()
if pop[0].result.succeeded:
if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return pop[0].result
if pop[0].result.score > cur_score: