mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
improve augmentation; merge in fix-docs
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""
|
||||
Algorithm from Generating Natural Language Adversarial Examples by Alzantot et. al
|
||||
Reimplementatio of search method from Generating Natural Language Adversarial Examples
|
||||
by Alzantot et. al
|
||||
`<arxiv.org/abs/1804.07998>`_
|
||||
`<github.com/nesl/nlp_adversarial_examples>`_
|
||||
"""
|
||||
@@ -8,26 +9,19 @@ import numpy as np
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from .attack import Attack
|
||||
from textattack.attack_results import FailedAttackResult, SuccessfulAttackResult
|
||||
from textattack.transformations import WordSwap
|
||||
from textattack.search_methods import SearchMethod
|
||||
from textattack.shared.validators import transformation_consists_of_word_swaps
|
||||
|
||||
class GeneticAlgorithm(Attack):
|
||||
class GeneticAlgorithm(SearchMethod):
|
||||
"""
|
||||
Attacks a model using a genetic algorithm.
|
||||
Attacks a model with word substiutitions using a genetic algorithm.
|
||||
|
||||
Args:
|
||||
goal_function: A function for determining how well a perturbation is doing at achieving the attack's goal.
|
||||
transformation: The type of transformation to use. Should be a subclass of WordSwap.
|
||||
pop_size (:obj:`int`, optional): The population size. Defauls to 20.
|
||||
max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 50.
|
||||
Raises:
|
||||
ValueError: If the transformation is not a subclass of WordSwap.
|
||||
"""
|
||||
def __init__(self, goal_function, transformation, constraints=[], pop_size=20, max_iters=50, temp=0.3,
|
||||
give_up_if_no_improvement=False):
|
||||
if not isinstance(transformation, WordSwap):
|
||||
raise ValueError(f'Transformation is of type {type(transformation)}, should be a subclass of WordSwap')
|
||||
super().__init__(goal_function, transformation, constraints=constraints)
|
||||
|
||||
def __init__(self, pop_size=20, max_iters=50, temp=0.3, give_up_if_no_improvement=False):
|
||||
self.max_iters = max_iters
|
||||
self.pop_size = pop_size
|
||||
self.temp = temp
|
||||
@@ -37,20 +31,22 @@ class GeneticAlgorithm(Attack):
|
||||
"""
|
||||
Select the best replacement for word at position (idx)
|
||||
in (pop_member) to maximize score.
|
||||
|
||||
Args:
|
||||
pop_member: The population member being perturbed.
|
||||
idx: The index at which to replace a word.
|
||||
|
||||
Returns:
|
||||
Whether a replacement which decreased the score was found.
|
||||
"""
|
||||
transformations = self.get_transformations(pop_member.tokenized_text,
|
||||
original_text=self.original_tokenized_text,
|
||||
indices_to_replace=[idx])
|
||||
indices_to_modify=[idx])
|
||||
if not len(transformations):
|
||||
return False
|
||||
new_x_results = self.goal_function.get_results(transformations, self.correct_output)
|
||||
new_x_results = self.get_goal_results(transformations, self.correct_output)
|
||||
new_x_scores = torch.Tensor([r.score for r in new_x_results])
|
||||
orig_score = self.goal_function.get_results([pop_member.tokenized_text], self.correct_output)[0].score
|
||||
orig_score = self.get_goal_results([pop_member.tokenized_text], self.correct_output)[0].score
|
||||
new_x_scores = new_x_scores - orig_score
|
||||
if new_x_scores.max() > 0:
|
||||
pop_member.tokenized_text = transformations[new_x_scores.argmax()]
|
||||
@@ -126,8 +122,7 @@ class GeneticAlgorithm(Attack):
|
||||
words = tokenized_text.words
|
||||
neighbors_list = [[] for _ in range(len(words))]
|
||||
transformations = self.get_transformations(tokenized_text,
|
||||
original_text=self.original_tokenized_text,
|
||||
apply_constraints=False)
|
||||
original_text=self.original_tokenized_text)
|
||||
for transformed_text in transformations:
|
||||
diff_idx = tokenized_text.first_word_diff_index(transformed_text)
|
||||
neighbors_list[diff_idx].append(transformed_text.words[diff_idx])
|
||||
@@ -135,15 +130,14 @@ class GeneticAlgorithm(Attack):
|
||||
neighbors_len = np.array([len(x) for x in neighbors_list])
|
||||
return neighbors_len
|
||||
|
||||
def attack_one(self, tokenized_text, correct_output):
|
||||
self.original_tokenized_text = tokenized_text
|
||||
self.correct_output = correct_output
|
||||
original_result = self.goal_function.get_results([tokenized_text], correct_output)[0]
|
||||
neighbors_len = self._get_neighbors_len(tokenized_text)
|
||||
def _perform_search(self, initial_result):
|
||||
self.original_tokenized_text = initial_result.tokenized_text
|
||||
self.correct_output = initial_result.output
|
||||
neighbors_len = self._get_neighbors_len(self.original_tokenized_text)
|
||||
pop = self._generate_population(neighbors_len)
|
||||
cur_score = original_result.score
|
||||
cur_score = initial_result.score
|
||||
for i in range(self.max_iters):
|
||||
pop_results = self.goal_function.get_results([pm.tokenized_text for pm in pop], correct_output)
|
||||
pop_results = self.get_goal_results([pm.tokenized_text for pm in pop], self.correct_output)
|
||||
for idx, result in enumerate(pop_results):
|
||||
pop[idx].result = pop_results[idx]
|
||||
pop = sorted(pop, key=lambda x: -x.result.score)
|
||||
@@ -154,10 +148,7 @@ class GeneticAlgorithm(Attack):
|
||||
select_probs = (logits / logits.sum()).cpu().numpy()
|
||||
|
||||
if pop[0].result.succeeded:
|
||||
return SuccessfulAttackResult(
|
||||
original_result,
|
||||
pop[0].result
|
||||
)
|
||||
return pop[0].result
|
||||
|
||||
if pop[0].result.score > cur_score:
|
||||
cur_score = pop[0].result.score
|
||||
@@ -177,13 +168,23 @@ class GeneticAlgorithm(Attack):
|
||||
|
||||
pop = elite + children
|
||||
|
||||
return FailedAttackResult(original_result, pop[0].result)
|
||||
return pop[0].result
|
||||
|
||||
def check_transformation_compatibility(self, transformation):
|
||||
"""
|
||||
The genetic algorithm is specifically designed for word substitutions.
|
||||
"""
|
||||
return transformation_consists_of_word_swaps(transformation)
|
||||
|
||||
def extra_repr_keys(self):
|
||||
return ['pop_size', 'max_iters', 'temp', 'give_up_if_no_improvement']
|
||||
|
||||
class PopulationMember:
|
||||
"""
|
||||
A member of the population during the course of the genetic algorithm.
|
||||
|
||||
Args:
|
||||
tokenized_text: The tokenized text of the population member.
|
||||
tokenized_text: The ``TokenizedText`` of the population member.
|
||||
neighbors_len: A list of the number of candidate neighbors list for each word.
|
||||
"""
|
||||
def __init__(self, tokenized_text, neighbors_len):
|
||||
|
||||
Reference in New Issue
Block a user