diff --git a/textattack/__init__.py b/textattack/__init__.py index 610780ca..f327bfd0 100644 --- a/textattack/__init__.py +++ b/textattack/__init__.py @@ -1,9 +1,9 @@ name = "textattack" -from .attacks import * -from .datasets import TextAttackDataset +from .shared import * from . import attack_recipes +from . import attack_results from . import attacks from . import constraints from . import datasets diff --git a/textattack/attack_recipes/alzantot_2018_genetic_algorithm.py b/textattack/attack_recipes/alzantot_2018_genetic_algorithm.py index c7f1ab49..9e210724 100644 --- a/textattack/attack_recipes/alzantot_2018_genetic_algorithm.py +++ b/textattack/attack_recipes/alzantot_2018_genetic_algorithm.py @@ -9,7 +9,7 @@ ArXiv, abs/1801.00554. """ -from textattack.attacks.blackbox import GeneticAlgorithm +from textattack.search import GeneticAlgorithm from textattack.constraints.semantics import WordEmbeddingDistance from textattack.constraints.semantics.language_models import GoogleLanguageModel from textattack.transformations import WordSwapEmbedding diff --git a/textattack/attack_recipes/alzantot_2018_genetic_algorithm_adjusted.py b/textattack/attack_recipes/alzantot_2018_genetic_algorithm_adjusted.py index 8cb0d21a..2e612d3c 100644 --- a/textattack/attack_recipes/alzantot_2018_genetic_algorithm_adjusted.py +++ b/textattack/attack_recipes/alzantot_2018_genetic_algorithm_adjusted.py @@ -9,7 +9,7 @@ ArXiv, abs/1801.00554. """ -from textattack.attacks.blackbox import GeneticAlgorithm +from textattack.search import GeneticAlgorithm from textattack.constraints.semantics import WordEmbeddingDistance from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder, BERT from textattack.constraints.syntax import PartOfSpeech, LanguageTool diff --git a/textattack/attack_recipes/jin_2019_textfooler.py b/textattack/attack_recipes/jin_2019_textfooler.py index b4c1036d..e109f60d 100644 --- a/textattack/attack_recipes/jin_2019_textfooler.py +++ b/textattack/attack_recipes/jin_2019_textfooler.py @@ -8,7 +8,7 @@ """ -from textattack.attacks.blackbox import GreedyWordSwapWIR +from textattack.search import GreedyWordSwapWIR from textattack.constraints.semantics import WordEmbeddingDistance from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder from textattack.constraints.syntax import PartOfSpeech diff --git a/textattack/attack_recipes/jin_2019_textfooler_adjusted.py b/textattack/attack_recipes/jin_2019_textfooler_adjusted.py index 947bf472..75a329fa 100644 --- a/textattack/attack_recipes/jin_2019_textfooler_adjusted.py +++ b/textattack/attack_recipes/jin_2019_textfooler_adjusted.py @@ -8,7 +8,7 @@ """ -from textattack.attacks.blackbox import GreedyWordSwapWIR +from textattack.search import GreedyWordSwapWIR from textattack.constraints.semantics import WordEmbeddingDistance from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder, BERT from textattack.constraints.syntax import PartOfSpeech, LanguageTool diff --git a/textattack/attacks/__init__.py b/textattack/attack_results/__init__.py similarity index 100% rename from textattack/attacks/__init__.py rename to textattack/attack_results/__init__.py diff --git a/textattack/attacks/attack_result.py b/textattack/attack_results/attack_result.py similarity index 100% rename from textattack/attacks/attack_result.py rename to textattack/attack_results/attack_result.py diff --git a/textattack/attacks/failed_attack_result.py b/textattack/attack_results/failed_attack_result.py similarity index 100% rename from textattack/attacks/failed_attack_result.py rename to textattack/attack_results/failed_attack_result.py diff --git a/textattack/attacks/skipped_attack_result.py b/textattack/attack_results/skipped_attack_result.py similarity index 100% rename from textattack/attacks/skipped_attack_result.py rename to textattack/attack_results/skipped_attack_result.py diff --git a/textattack/attacks/blackbox/black_box_attack.py b/textattack/attacks/blackbox/black_box_attack.py deleted file mode 100644 index cf112cee..00000000 --- a/textattack/attacks/blackbox/black_box_attack.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch - -from textattack.attacks import Attack - -class BlackBoxAttack(Attack): - """ An abstract class that defines a black-box attack. - - - A black-box attack can access the prediction scores of a model, but - not any other information about its parameters or internal state. We - reduce the model here to just the __call__ function so that black-box - attacks that extend this class can obtain prediction scores, but not - any other information from the model. - - Arg: - model (nn.Module): The model to attack. - constraints (list(Constraint)): The list of constraints - for the model's transformations. - - """ - def __init__(self, model, constraints=[]): - self.model_description = model.__class__.__name__ - self.model = model.__call__ - self.tokenizer = model.tokenizer - super().__init__(constraints=constraints) - - def _call_model(self, *args, **kwargs): - with torch.no_grad(): - return super()._call_model(*args, **kwargs) diff --git a/textattack/attacks/blackbox/genetic_algorithm/__init__.py b/textattack/attacks/blackbox/genetic_algorithm/__init__.py deleted file mode 100644 index e3fed2b0..00000000 --- a/textattack/attacks/blackbox/genetic_algorithm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .genetic_algorithm import GeneticAlgorithm \ No newline at end of file diff --git a/textattack/attacks/blackbox/greedy_word_swap/__init__.py b/textattack/attacks/blackbox/greedy_word_swap/__init__.py deleted file mode 100644 index 3ab82ad2..00000000 --- a/textattack/attacks/blackbox/greedy_word_swap/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .greedy_word_swap import GreedyWordSwap \ No newline at end of file diff --git a/textattack/attacks/blackbox/greedy_word_swap_wir/__init__.py b/textattack/attacks/blackbox/greedy_word_swap_wir/__init__.py deleted file mode 100644 index e7b4089d..00000000 --- a/textattack/attacks/blackbox/greedy_word_swap_wir/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .greedy_word_swap_wir import GreedyWordSwapWIR \ No newline at end of file diff --git a/textattack/attacks/whitebox/hot_flip/hot_flip.py b/textattack/attacks/whitebox/hot_flip/hot_flip.py deleted file mode 100644 index 333ab77b..00000000 --- a/textattack/attacks/whitebox/hot_flip/hot_flip.py +++ /dev/null @@ -1,5 +0,0 @@ -from textattack.attacks.whitebox import WhiteBoxAttack - -class HotFlip(WhiteBoxAttack): - def __init__(self): - raise NotImplementedError() \ No newline at end of file diff --git a/textattack/attacks/whitebox/white_box_attack.py b/textattack/attacks/whitebox/white_box_attack.py deleted file mode 100644 index b31b6bed..00000000 --- a/textattack/attacks/whitebox/white_box_attack.py +++ /dev/null @@ -1,20 +0,0 @@ -from textattack.attacks import Attack - -class WhiteBoxAttack(Attack): - """ An abstract class that defines a white-box attack. - - - A white-box attack can information about the model, so it stores the entire - model as the field `self.model`. - - Arg: - model (nn.Module): The model to attack. - constraints (list(Constraint)): The list of constraints - for the model's transformations. - - """ - def __init__(self, model, constraints=[]): - self.model_description = model.__class__.__name__ - self.model = model - self.tokenizer = model.tokenizer - super().__init__(constraints=constraints, is_black_box=False) diff --git a/textattack/attacks/attack_logger.py b/textattack/loggers/attack_logger.py similarity index 100% rename from textattack/attacks/attack_logger.py rename to textattack/loggers/attack_logger.py diff --git a/textattack/run_attack.py b/textattack/run_attack.py index 0484ea82..208d2234 100644 --- a/textattack/run_attack.py +++ b/textattack/run_attack.py @@ -101,9 +101,9 @@ CONSTRAINT_CLASS_NAMES = { } ATTACK_CLASS_NAMES = { - 'greedy-word': 'attacks.blackbox.GreedyWordSwap', - 'ga-word': 'attacks.blackbox.GeneticAlgorithm', - 'greedy-word-wir': 'attacks.blackbox.GreedyWordSwapWIR', + 'greedy-word': 'search.GreedyWordSwap', + 'ga-word': 'search.GeneticAlgorithm', + 'greedy-word-wir': 'search.GreedyWordSwapWIR', } @@ -291,7 +291,7 @@ def main(args): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Cache TensorFlow Hub models here, if not otherwise specified. if 'TFHUB_CACHE_DIR' not in os.environ: - os.environ['TFHUB_CACHE_DIR'] = './tensorflow-hub' + os.environ['TFHUB_CACHE_DIR'] = '~/.cache/tensorflow-hub' start_time = time.time() diff --git a/textattack/attacks/blackbox/__init__.py b/textattack/search/__init__.py similarity index 75% rename from textattack/attacks/blackbox/__init__.py rename to textattack/search/__init__.py index 169ec8cf..8f2e8b59 100644 --- a/textattack/attacks/blackbox/__init__.py +++ b/textattack/search/__init__.py @@ -1,5 +1,3 @@ -from .black_box_attack import BlackBoxAttack - from .greedy_word_swap import GreedyWordSwap from .greedy_word_swap_wir import GreedyWordSwapWIR from .genetic_algorithm import GeneticAlgorithm diff --git a/textattack/attacks/blackbox/genetic_algorithm/genetic_algorithm.py b/textattack/search/genetic_algorithm.py similarity index 93% rename from textattack/attacks/blackbox/genetic_algorithm/genetic_algorithm.py rename to textattack/search/genetic_algorithm.py index c7691b31..84374bcb 100644 --- a/textattack/attacks/blackbox/genetic_algorithm/genetic_algorithm.py +++ b/textattack/search/genetic_algorithm.py @@ -8,12 +8,14 @@ Algorithm from Generating Natural Language Adversarial Examples by Alzantot et. import numpy as np +from .search import Search + from textattack.attacks import AttackResult, FailedAttackResult -from textattack.attacks.blackbox import BlackBoxAttack +from textattack.search import BlackBoxAttack from textattack.transformations import WordSwap from copy import deepcopy -class GeneticAlgorithm(BlackBoxAttack): +class GeneticAlgorithm(Search): ''' Attacks a model using a genetic algorithm. @@ -27,12 +29,11 @@ class GeneticAlgorithm(BlackBoxAttack): ValueError: If the transformation is not a subclass of WordSwap. ''' - def __init__(self, model, transformations=[], pop_size=20, max_iters=50, temp=0.3): + def __init__(self, get_transformations, pop_size=20, max_iters=50, temp=0.3): if not isinstance(transformations[0], WordSwap): raise ValueError(f'Transformation is of type {type(transformation)}, should be a subclass of WordSwap') - super().__init__(model) + self.get_transformations = get_transformations self.model = model - self.transformation = transformations[0] self.max_iters = max_iters self.pop_size = pop_size self.temp = temp @@ -50,8 +51,7 @@ class GeneticAlgorithm(BlackBoxAttack): Returns: Whether a replacement which decreased the score was found. """ - transformations = self.get_transformations(self.transformation, - pop_member.tokenized_text, + transformations = self.get_transformations(pop_member.tokenized_text, original_text=self.original_tokenized_text, indices_to_replace=[idx]) if not len(transformations): @@ -142,8 +142,7 @@ class GeneticAlgorithm(BlackBoxAttack): ''' words = tokenized_text.words neighbors_list = [[] for _ in range(len(words))] - transformations = self.get_transformations(self.transformation, - tokenized_text, + transformations = self.get_transformations(tokenized_text, original_text=self.original_tokenized_text, apply_constraints=False) for transformed_text in transformations: @@ -153,7 +152,7 @@ class GeneticAlgorithm(BlackBoxAttack): neighbors_len = np.array([len(x) for x in neighbors_list]) return neighbors_len - def attack_one(self, original_label, tokenized_text): + def __call__(self, original_label, tokenized_text): self.original_tokenized_text = tokenized_text original_prob = self._call_model([tokenized_text]).squeeze().max() neighbors_len = self._get_neighbors_len(tokenized_text) diff --git a/textattack/attacks/blackbox/greedy_word_swap/greedy_word_swap.py b/textattack/search/greedy_word_swap.py similarity index 84% rename from textattack/attacks/blackbox/greedy_word_swap/greedy_word_swap.py rename to textattack/search/greedy_word_swap.py index 9e87be92..06767bab 100644 --- a/textattack/attacks/blackbox/greedy_word_swap/greedy_word_swap.py +++ b/textattack/search/greedy_word_swap.py @@ -1,7 +1,9 @@ -from textattack.attacks import AttackResult, FailedAttackResult -from textattack.attacks.blackbox import BlackBoxAttack +from .search import Search -class GreedyWordSwap(BlackBoxAttack): +from textattack.attacks import AttackResult, FailedAttackResult +from textattack.search import BlackBoxAttack + +class GreedyWordSwap(Search): """ An attack that greedily chooses from a list of possible perturbations. @@ -12,12 +14,11 @@ class GreedyWordSwap(BlackBoxAttack): max_depth (:obj:`int`, optional): The maximum number of words to change. Defaults to 32. """ - def __init__(self, model, transformations=[], max_depth=32): - super().__init__(model) - self.transformation = transformations[0] + def __init__(self, get_transformations, max_depth=32): + self.get_transformations = get_transformations self.max_depth = max_depth - def attack_one(self, original_label, tokenized_text): + def __call__(self, original_label, tokenized_text): original_tokenized_text = tokenized_text original_prob = self._call_model([tokenized_text]).squeeze().max() num_words_changed = 0 @@ -27,9 +28,7 @@ class GreedyWordSwap(BlackBoxAttack): while num_words_changed <= self.max_depth and len(unswapped_word_indices): num_words_changed += 1 transformed_text_candidates = self.get_transformations( - self.transformation, - tokenized_text, - indices_to_replace=unswapped_word_indices) + tokenized_text, indices_to_replace=unswapped_word_indices) if len(transformed_text_candidates) == 0: # If we did not find any possible perturbations, give up. break diff --git a/textattack/attacks/blackbox/greedy_word_swap_wir/greedy_word_swap_wir.py b/textattack/search/greedy_word_swap_wir.py similarity index 95% rename from textattack/attacks/blackbox/greedy_word_swap_wir/greedy_word_swap_wir.py rename to textattack/search/greedy_word_swap_wir.py index 73aeedf2..3f92b56d 100644 --- a/textattack/attacks/blackbox/greedy_word_swap_wir/greedy_word_swap_wir.py +++ b/textattack/search/greedy_word_swap_wir.py @@ -1,9 +1,11 @@ import torch +from .search import Search + from textattack.attacks import AttackResult, FailedAttackResult from textattack.attacks.blackbox import BlackBoxAttack -class GreedyWordSwapWIR(BlackBoxAttack): +class GreedyWordSwapWIR(Search): """ An attack that greedily chooses from a list of possible perturbations for each index, after ranking indices by importance. @@ -17,12 +19,12 @@ class GreedyWordSwapWIR(BlackBoxAttack): max_depth (:obj:`int`, optional): The maximum number of words to change. Defaults to 32. """ - def __init__(self, model, transformations=[], max_depth=32): + def __init__(self, get_transformations, max_depth=32): super().__init__(model) - self.transformation = transformations[0] + self.get_transformations = get_transformations self.max_depth = max_depth - def attack_one(self, original_label, tokenized_text): + def __call__(self, original_label, tokenized_text): original_tokenized_text = tokenized_text num_words_changed = 0 diff --git a/textattack/search/search.py b/textattack/search/search.py new file mode 100644 index 00000000..a3c83883 --- /dev/null +++ b/textattack/search/search.py @@ -0,0 +1,10 @@ +class Search: + """ This is an abstract class that defines a textual search method. + `self.__call__` takes the original label and text of a text input + and searches for a potential adversarial example. + + Generally, a search method queries `get_transformations` to explore + candidate perturbed phrases. + """ + def __call__(self, original_label, tokenized_text): + raise NotImplementedError() \ No newline at end of file diff --git a/textattack/attacks/whitebox/__init__.py b/textattack/shared/__init__.py similarity index 100% rename from textattack/attacks/whitebox/__init__.py rename to textattack/shared/__init__.py diff --git a/textattack/attacks/attack.py b/textattack/shared/attack.py similarity index 94% rename from textattack/attacks/attack.py rename to textattack/shared/attack.py index 88fe149d..0cad6e44 100644 --- a/textattack/attacks/attack.py +++ b/textattack/shared/attack.py @@ -8,9 +8,9 @@ import time from textattack import utils as utils from textattack.constraints import Constraint +from textattack.loggers.attack_logger import AttackLogger from textattack.tokenized_text import TokenizedText -from textattack.attacks.attack_logger import AttackLogger -from textattack.attacks import AttackResult, FailedAttackResult +from textattack.attack_results import AttackResult, FailedAttackResult class Attack: """ @@ -21,7 +21,7 @@ class Attack: constraints: A list of constraints to add to the attack """ - def __init__(self, constraints=[], is_black_box=True): + def __init__(self, search, transformation, constraints=[], is_black_box=True): """ Initialize an attack object. Attacks can be run multiple times. """ if not self.model: @@ -32,9 +32,10 @@ class Attack: else: raise NameError('Cannot instantiate attack without tokenizer') # Transformation and corresponding constraints. + self.search = search + self.transformation = transformation self.constraints = [] - if constraints: - self.add_constraints(constraints) + self.add_constraints(constraints) # Logger self.logger = AttackLogger() self.is_black_box = is_black_box @@ -94,7 +95,7 @@ class Attack: for constraint in constraints: self.add_constraint(constraint) - def get_transformations(self, transformation, text, original_text=None, + def get_transformations(self, text, original_text=None, apply_constraints=True, **kwargs): """ Filters a list of transformations by self.constraints. @@ -110,6 +111,8 @@ class Attack: A filtered list of transformations where each transformation matches the constraints """ + if not self.transformation: + raise RuntimeError('Cannot call `get_transformations` without a transformation set.') transformations = np.array(transformation(text, **kwargs)) if apply_constraints: return self._filter_transformations(transformations, text, original_text) @@ -133,7 +136,7 @@ class Attack: `label`. """ - raise NotImplementedError() + return self.search(label, tokenized_text) def _call_model(self, tokenized_text_list, batch_size=8): """ diff --git a/textattack/tokenized_text.py b/textattack/shared/tokenized_text.py similarity index 100% rename from textattack/tokenized_text.py rename to textattack/shared/tokenized_text.py diff --git a/textattack/utils.py b/textattack/shared/utils.py similarity index 100% rename from textattack/utils.py rename to textattack/shared/utils.py