mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
allow maximization goal functions
This commit is contained in:
@@ -10,6 +10,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from textattack.goal_function_results import GoalFunctionResultStatus
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
@@ -52,12 +53,12 @@ class GeneticAlgorithm(SearchMethod):
|
||||
if not len(transformations):
|
||||
return False
|
||||
orig_result, self.search_over = self.get_goal_results(
|
||||
[pop_member.attacked_text], self.correct_output
|
||||
[pop_member.attacked_text]
|
||||
)
|
||||
if self.search_over:
|
||||
return False
|
||||
new_x_results, self.search_over = self.get_goal_results(
|
||||
transformations, self.correct_output
|
||||
transformations
|
||||
)
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
@@ -157,7 +158,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results(
|
||||
[pm.attacked_text for pm in pop], self.correct_output
|
||||
[pm.attacked_text for pm in pop]
|
||||
)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
@@ -171,7 +172,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
logits = ((-pop_scores) / self.temp).exp()
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
if pop[0].result.succeeded:
|
||||
if pop[0].result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
||||
return pop[0].result
|
||||
|
||||
if pop[0].result.score > cur_score:
|
||||
|
||||
Reference in New Issue
Block a user