mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
make requested fixes
This commit is contained in:
@@ -9,7 +9,7 @@ import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import is_word_swap
|
||||
from textattack.shared.validators import consists_of_word_swaps
|
||||
|
||||
class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
@@ -26,9 +26,6 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.temp = temp
|
||||
self.give_up_if_no_improvement = give_up_if_no_improvement
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
return transformation.consists_of(is_word_swap)
|
||||
|
||||
def _replace_at_index(self, pop_member, idx):
|
||||
"""
|
||||
Select the best replacement for word at position (idx)
|
||||
@@ -133,7 +130,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
neighbors_len = np.array([len(x) for x in neighbors_list])
|
||||
return neighbors_len
|
||||
|
||||
def __call__(self, initial_result):
|
||||
def _perform_search(self, initial_result):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
@@ -173,6 +170,12 @@ class GeneticAlgorithm(SearchMethod):
|
||||
|
||||
return pop[0].result
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
"""
|
||||
The genetic algorithm is specifically designed for word substitutions.
|
||||
"""
|
||||
return consists_of_word_swaps(transformation)
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return ['pop_size', 'max_iters', 'temp', 'give_up_if_no_improvement']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user