1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

ga bug fix

This commit is contained in:
uvafan
2020-06-05 20:37:02 -04:00
parent 252b9d2064
commit e1a66fabf7

View File

@@ -77,17 +77,18 @@ class GeneticAlgorithm(SearchMethod):
neighbors_len[rand_idx] = 0
iterations += 1
def _generate_population(self, neighbors_len):
def _generate_population(self, neighbors_len, initial_result):
"""
Generates a population of texts each with one word replaced
Args:
neighbors_len: A list of the number of candidate neighbors for each word.
initial_result: The result to instantiate the population with
Returns:
The population.
"""
pop = []
for _ in range(self.pop_size):
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len))
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len), initial_result)
self._perturb(pop_member)
pop.append(pop_member)
return pop
@@ -137,7 +138,7 @@ class GeneticAlgorithm(SearchMethod):
self.original_tokenized_text = initial_result.tokenized_text
self.correct_output = initial_result.output
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
pop = self._generate_population(neighbors_len)
pop = self._generate_population(neighbors_len, initial_result)
cur_score = initial_result.score
for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
@@ -193,6 +194,7 @@ class PopulationMember:
tokenized_text: The ``TokenizedText`` of the population member.
neighbors_len: A list of the number of candidate neighbors list for each word.
"""
def __init__(self, tokenized_text, neighbors_len):
def __init__(self, tokenized_text, neighbors_len, result):
self.tokenized_text = tokenized_text
self.neighbors_len = neighbors_len
self.result = result