mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
update attackedtext references, need to update tokenization
This commit is contained in:
@@ -45,14 +45,14 @@ class GeneticAlgorithm(SearchMethod):
|
||||
Whether a replacement which increased the score was found.
|
||||
"""
|
||||
transformations = self.get_transformations(
|
||||
pop_member.tokenized_text,
|
||||
original_text=self.original_tokenized_text,
|
||||
pop_member.attacked_text,
|
||||
original_text=self.original_attacked_text,
|
||||
indices_to_modify=[idx],
|
||||
)
|
||||
if not len(transformations):
|
||||
return False
|
||||
orig_result, self.search_over = self.get_goal_results(
|
||||
[pop_member.tokenized_text], self.correct_output
|
||||
[pop_member.attacked_text], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
return False
|
||||
@@ -62,7 +62,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
new_x_scores = new_x_scores - orig_result[0].score
|
||||
if len(new_x_scores) and new_x_scores.max() > 0:
|
||||
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
|
||||
pop_member.attacked_text = transformations[new_x_scores.argmax()]
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -99,7 +99,7 @@ class GeneticAlgorithm(SearchMethod):
|
||||
pop = []
|
||||
for _ in range(self.pop_size):
|
||||
pop_member = PopulationMember(
|
||||
self.original_tokenized_text, deepcopy(neighbors_len), initial_result
|
||||
self.original_attacked_text, deepcopy(neighbors_len), initial_result
|
||||
)
|
||||
self._perturb(pop_member)
|
||||
pop.append(pop_member)
|
||||
@@ -116,8 +116,8 @@ class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
indices_to_replace = []
|
||||
words_to_replace = []
|
||||
x1_text = pop_member1.tokenized_text
|
||||
x2_words = pop_member2.tokenized_text.words
|
||||
x1_text = pop_member1.attacked_text
|
||||
x2_words = pop_member2.attacked_text.words
|
||||
new_neighbors_len = deepcopy(pop_member1.neighbors_len)
|
||||
for i in range(len(x1_text.words)):
|
||||
if np.random.uniform() < 0.5:
|
||||
@@ -129,35 +129,35 @@ class GeneticAlgorithm(SearchMethod):
|
||||
)
|
||||
return PopulationMember(new_text, deepcopy(new_neighbors_len))
|
||||
|
||||
def _get_neighbors_len(self, tokenized_text):
|
||||
def _get_neighbors_len(self, attacked_text):
|
||||
"""
|
||||
Generates this neighbors_len list
|
||||
Args:
|
||||
tokenized_text: The original text
|
||||
attacked_text: The original text
|
||||
Returns:
|
||||
A list of number of candidate neighbors for each word
|
||||
"""
|
||||
words = tokenized_text.words
|
||||
words = attacked_text.words
|
||||
neighbors_list = [[] for _ in range(len(words))]
|
||||
transformations = self.get_transformations(
|
||||
tokenized_text, original_text=self.original_tokenized_text
|
||||
attacked_text, original_text=self.original_attacked_text
|
||||
)
|
||||
for transformed_text in transformations:
|
||||
diff_idx = tokenized_text.first_word_diff_index(transformed_text)
|
||||
diff_idx = attacked_text.first_word_diff_index(transformed_text)
|
||||
neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
|
||||
neighbors_list = [np.array(x) for x in neighbors_list]
|
||||
neighbors_len = np.array([len(x) for x in neighbors_list])
|
||||
return neighbors_len
|
||||
|
||||
def _perform_search(self, initial_result):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.original_attacked_text = initial_result.attacked_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
neighbors_len = self._get_neighbors_len(self.original_attacked_text)
|
||||
pop = self._generate_population(neighbors_len, initial_result)
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results, self.search_over = self.get_goal_results(
|
||||
[pm.tokenized_text for pm in pop], self.correct_output
|
||||
[pm.attacked_text for pm in pop], self.correct_output
|
||||
)
|
||||
if self.search_over:
|
||||
if not len(pop_results):
|
||||
@@ -213,11 +213,11 @@ class PopulationMember:
|
||||
A member of the population during the course of the genetic algorithm.
|
||||
|
||||
Args:
|
||||
tokenized_text: The ``AttackedText`` of the population member.
|
||||
attacked_text: The ``AttackedText`` of the population member.
|
||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenized_text, neighbors_len, result=None):
|
||||
self.tokenized_text = tokenized_text
|
||||
def __init__(self, attacked_text, neighbors_len, result=None):
|
||||
self.attacked_text = attacked_text
|
||||
self.neighbors_len = neighbors_len
|
||||
self.result = result
|
||||
|
||||
Reference in New Issue
Block a user