1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

ga bug fix

This commit is contained in:
uvafan
2020-06-05 20:37:02 -04:00
parent 252b9d2064
commit e1a66fabf7

View File

@@ -77,17 +77,18 @@ class GeneticAlgorithm(SearchMethod):
neighbors_len[rand_idx] = 0 neighbors_len[rand_idx] = 0
iterations += 1 iterations += 1
def _generate_population(self, neighbors_len): def _generate_population(self, neighbors_len, initial_result):
""" """
Generates a population of texts each with one word replaced Generates a population of texts each with one word replaced
Args: Args:
neighbors_len: A list of the number of candidate neighbors for each word. neighbors_len: A list of the number of candidate neighbors for each word.
initial_result: The result to instantiate the population with
Returns: Returns:
The population. The population.
""" """
pop = [] pop = []
for _ in range(self.pop_size): for _ in range(self.pop_size):
pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len)) pop_member = PopulationMember(self.original_tokenized_text, deepcopy(neighbors_len), initial_result)
self._perturb(pop_member) self._perturb(pop_member)
pop.append(pop_member) pop.append(pop_member)
return pop return pop
@@ -137,7 +138,7 @@ class GeneticAlgorithm(SearchMethod):
self.original_tokenized_text = initial_result.tokenized_text self.original_tokenized_text = initial_result.tokenized_text
self.correct_output = initial_result.output self.correct_output = initial_result.output
neighbors_len = self._get_neighbors_len(self.original_tokenized_text) neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
pop = self._generate_population(neighbors_len) pop = self._generate_population(neighbors_len, initial_result)
cur_score = initial_result.score cur_score = initial_result.score
for i in range(self.max_iters): for i in range(self.max_iters):
pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output) pop_results, self.search_over = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
@@ -193,6 +194,7 @@ class PopulationMember:
tokenized_text: The ``TokenizedText`` of the population member. tokenized_text: The ``TokenizedText`` of the population member.
neighbors_len: A list of the number of candidate neighbors list for each word. neighbors_len: A list of the number of candidate neighbors list for each word.
""" """
def __init__(self, tokenized_text, neighbors_len): def __init__(self, tokenized_text, neighbors_len, result):
self.tokenized_text = tokenized_text self.tokenized_text = tokenized_text
self.neighbors_len = neighbors_len self.neighbors_len = neighbors_len
self.result = result