mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
ga bug fix
This commit is contained in:
@@ -77,17 +77,18 @@ class GeneticAlgorithm(SearchMethod):
|
||||
neighbors_len[rand_idx] = 0
|
||||
iterations += 1
|
||||
|
||||
def _generate_population(self, neighbors_len):
|
||||
def _generate_population(self, neighbors_len, initial_result):
|
||||
"""
|
||||
Generates a population of texts each with one word replaced
|
||||
Args:
|
||||
neighbors_len: A list of the number of candidate neighbors for each word.
|
||||
initial_result: The result to instantiate the population with
|
||||
Returns:
|
||||
The population.
|
||||
"""
|
||||
pop = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len))
|
||||
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len), initial_result)
|
||||
self._perturb(pop_member)
|
||||
pop.append(pop_member)
|
||||
return pop
|
||||
@@ -137,7 +138,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
pop = self._generate_population(neighbors_len)
|
||||
pop = self._generate_population(neighbors_len, initial_result)
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
||||
@@ -193,6 +194,7 @@ class PopulationMember:
|
||||
tokenized_text: The ``TokenizedText`` of the population member.
|
||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
def __init__(self, tokenized_text, neighbors_len):
|
||||
def __init__(self, tokenized_text, neighbors_len, result):
|
||||
self.tokenized_text = tokenized_text
|
||||
self.neighbors_len = neighbors_len
|
||||
self.result = result
|
||||
|
||||
Reference in New Issue
Block a user