mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
ga bug fix
This commit is contained in:
@@ -77,17 +77,18 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
neighbors_len[rand_idx] = 0
|
neighbors_len[rand_idx] = 0
|
||||||
iterations += 1
|
iterations += 1
|
||||||
|
|
||||||
def _generate_population(self, neighbors_len):
|
def _generate_population(self, neighbors_len, initial_result):
|
||||||
"""
|
"""
|
||||||
Generates a population of texts each with one word replaced
|
Generates a population of texts each with one word replaced
|
||||||
Args:
|
Args:
|
||||||
neighbors_len: A list of the number of candidate neighbors for each word.
|
neighbors_len: A list of the number of candidate neighbors for each word.
|
||||||
|
initial_result: The result to instantiate the population with
|
||||||
Returns:
|
Returns:
|
||||||
The population.
|
The population.
|
||||||
"""
|
"""
|
||||||
pop = []
|
pop = []
|
||||||
for _ in range(self.pop_size):
|
for _ in range(self.pop_size):
|
||||||
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len))
|
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len), initial_result)
|
||||||
self._perturb(pop_member)
|
self._perturb(pop_member)
|
||||||
pop.append(pop_member)
|
pop.append(pop_member)
|
||||||
return pop
|
return pop
|
||||||
@@ -137,7 +138,7 @@ class GeneticAlgorithm(SearchMethod):
|
|||||||
self.original_tokenized_text = initial_result.tokenized_text
|
self.original_tokenized_text = initial_result.tokenized_text
|
||||||
self.correct_output = initial_result.output
|
self.correct_output = initial_result.output
|
||||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||||
pop = self._generate_population(neighbors_len)
|
pop = self._generate_population(neighbors_len, initial_result)
|
||||||
cur_score = initial_result.score
|
cur_score = initial_result.score
|
||||||
for i in range(self.max_iters):
|
for i in range(self.max_iters):
|
||||||
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
||||||
@@ -193,6 +194,7 @@ class PopulationMember:
|
|||||||
tokenized_text: The ``TokenizedText`` of the population member.
|
tokenized_text: The ``TokenizedText`` of the population member.
|
||||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||||
"""
|
"""
|
||||||
def __init__(self, tokenized_text, neighbors_len):
|
def __init__(self, tokenized_text, neighbors_len, result):
|
||||||
self.tokenized_text = tokenized_text
|
self.tokenized_text = tokenized_text
|
||||||
self.neighbors_len = neighbors_len
|
self.neighbors_len = neighbors_len
|
||||||
|
self.result = result
|
||||||
|
|||||||
Reference in New Issue
Block a user