1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

make requested fixes

This commit is contained in:
uvafan
2020-05-18 18:30:23 -04:00
parent 0794f6ed38
commit 8c265fe23f
20 changed files with 113 additions and 176 deletions

View File

@@ -9,7 +9,7 @@ import torch
from copy import deepcopy
from textattack.search_methods import SearchMethod
from textattack.shared.validators import is_word_swap
from textattack.shared.validators import consists_of_word_swaps
class GeneticAlgorithm(SearchMethod):
"""
@@ -26,9 +26,6 @@ class GeneticAlgorithm(SearchMethod):
self.temp = temp
self.give_up_if_no_improvement = give_up_if_no_improvement
def check_transformation_compatibility(self, transformation):
return transformation.consists_of(is_word_swap)
def _replace_at_index(self, pop_member, idx):
"""
Select the best replacement for word at position (idx)
@@ -133,7 +130,7 @@ class GeneticAlgorithm(SearchMethod):
neighbors_len = np.array([len(x) for x in neighbors_list])
return neighbors_len
def __call__(self, initial_result):
def _perform_search(self, initial_result):
self.original_tokenized_text = initial_result.tokenized_text
self.correct_output = initial_result.output
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
@@ -173,6 +170,12 @@ class GeneticAlgorithm(SearchMethod):
return pop[0].result
def check_transformation_compatibility(self, transformation):
"""
The genetic algorithm is specifically designed for word substitutions.
"""
return consists_of_word_swaps(transformation)
def extra_repr_keys(self):
return ['pop_size', 'max_iters', 'temp', 'give_up_if_no_improvement']