diff --git a/scripts/benchmark_models.py b/scripts/benchmark_models.py index 327a256a..83a517a0 100644 --- a/scripts/benchmark_models.py +++ b/scripts/benchmark_models.py @@ -7,9 +7,9 @@ from run_attack_args_helper import * import textattack.models as models -def _cb(s): return textattack.shared.utils.color_text_by_method(str(s), color='blue', method='stdout') -def _cg(s): return textattack.shared.utils.color_text_by_method(str(s), color='green', method='stdout') -def _cr(s): return textattack.shared.utils.color_text_by_method(str(s), color='red', method='stdout') +def _cb(s): return textattack.shared.utils.color_text(str(s), color='blue', method='stdout') +def _cg(s): return textattack.shared.utils.color_text(str(s), color='green', method='stdout') +def _cr(s): return textattack.shared.utils.color_text(str(s), color='red', method='stdout') def _pb(): print(_cg('-' * 60)) from collections import Counter diff --git a/scripts/run_attack_args_helper.py b/scripts/run_attack_args_helper.py index 32fa0851..5cae1d0a 100644 --- a/scripts/run_attack_args_helper.py +++ b/scripts/run_attack_args_helper.py @@ -10,7 +10,7 @@ RECIPE_NAMES = { 'alzantot': 'textattack.attack_recipes.Alzantot2018GeneticAlgorithm', 'alz-adjusted': 'textattack.attack_recipes.Alzantot2018GeneticAlgorithmAdjusted', 'deepwordbug': 'textattack.attack_recipes.Gao2018DeepWordBug', - 'seq2sick': 'textattack.attack_recipes.Cheng2018Seq2Sick', + 'seq2sick': 'textattack.attack_recipes.Cheng2018Seq2SickBlackBox', 'textfooler': 'textattack.attack_recipes.Jin2019TextFooler', 'tf-adjusted': 'textattack.attack_recipes.Jin2019TextFoolerAdjusted', } diff --git a/textattack/attack_methods/attack.py b/textattack/attack_methods/attack.py index a58c9f4a..98743154 100644 --- a/textattack/attack_methods/attack.py +++ b/textattack/attack_methods/attack.py @@ -13,9 +13,9 @@ class Attack: An attack generates adversarial examples on text. This is an abstract class that contains main helper functionality for - attacks. An attack is comprised of a search method and a transformation, as - well as one or more linguistic constraints that successful examples must - meet. + attacks. An attack is comprised of a search method, a goal function, and a + transformation, as well as one or more linguistic constraints that + successful examples must meet. Args: goal_function: A function for determining how well a perturbation is doing at achieving the attack's goal. @@ -29,7 +29,7 @@ class Attack: """ self.goal_function = goal_function if not self.goal_function: - raise NameError('Cannot instantiate attack without self.goal_function for prediction scores') + raise NameError('Cannot instantiate attack without self.goal_function for predictions') if not hasattr(self, 'tokenizer'): if hasattr(self.goal_function.model, 'tokenizer'): self.tokenizer = self.goal_function.model.tokenizer @@ -130,7 +130,7 @@ class Attack: examples. If `False`, returns `num_examples` total examples. Returns: - results (List[Tuple[Int, TokenizedText, Boolean]]): a list of + results (Iterable[Tuple[GoalFunctionResult, Boolean]]): a list of objects containing (text, ground_truth_output, was_skipped) """ examples = [] diff --git a/textattack/attack_methods/greedy_word_swap_wir.py b/textattack/attack_methods/greedy_word_swap_wir.py index fb731a44..bf6abd4f 100644 --- a/textattack/attack_methods/greedy_word_swap_wir.py +++ b/textattack/attack_methods/greedy_word_swap_wir.py @@ -52,7 +52,6 @@ class GreedyWordSwapWIR(Attack): new_text_label = None i = 0 while ((self.max_depth is None) or num_words_changed <= self.max_depth) and i < len(index_order): - # import pdb; pdb.set_trace() transformed_text_candidates = self.get_transformations( tokenized_text, original_tokenized_text, diff --git a/textattack/attack_recipes/__init__.py b/textattack/attack_recipes/__init__.py index e9fd8919..ad1e1be3 100644 --- a/textattack/attack_recipes/__init__.py +++ b/textattack/attack_recipes/__init__.py @@ -1,6 +1,6 @@ from .alzantot_2018_genetic_algorithm import Alzantot2018GeneticAlgorithm from .alzantot_2018_genetic_algorithm_adjusted import Alzantot2018GeneticAlgorithmAdjusted -from .cheng_2018_seq2sick import Cheng2018Seq2Sick +from .cheng_2018_seq2sick_blackbox import Cheng2018Seq2SickBlackBox from .jin_2019_textfooler import Jin2019TextFooler from .jin_2019_textfooler_adjusted import Jin2019TextFoolerAdjusted from .gao_2018_deepwordbug import Gao2018DeepWordBug \ No newline at end of file diff --git a/textattack/attack_recipes/cheng_2018_seq2sick.py b/textattack/attack_recipes/cheng_2018_seq2sick_blackbox.py similarity index 94% rename from textattack/attack_recipes/cheng_2018_seq2sick.py rename to textattack/attack_recipes/cheng_2018_seq2sick_blackbox.py index b59f80d3..6ff1a229 100644 --- a/textattack/attack_recipes/cheng_2018_seq2sick.py +++ b/textattack/attack_recipes/cheng_2018_seq2sick_blackbox.py @@ -17,7 +17,7 @@ from textattack.constraints.overlap import LevenshteinEditDistance from textattack.goal_functions import NonOverlappingOutput from textattack.transformations import WordSwapEmbedding -def Cheng2018Seq2Sick(model, goal_function='non_overlapping'): +def Cheng2018Seq2SickBlackBox(model, goal_function='non_overlapping'): # # Goal is non-overlapping output. # diff --git a/textattack/datasets/dataset.py b/textattack/datasets/dataset.py index d788c1ab..af52b230 100644 --- a/textattack/datasets/dataset.py +++ b/textattack/datasets/dataset.py @@ -39,7 +39,6 @@ class TextAttackDataset: self.i = 0 file_path = utils.download_if_needed(file_name) self.examples = pickle.load( open(file_path, "rb" ) ) - import pdb; pdb.set_trace() self.examples = self.examples[offset:] def _load_classification_text_file(self, text_file_name, offset=0): diff --git a/textattack/datasets/translation/translation_datasets.py b/textattack/datasets/translation/translation_datasets.py index f096e0d7..71884860 100644 --- a/textattack/datasets/translation/translation_datasets.py +++ b/textattack/datasets/translation/translation_datasets.py @@ -1,4 +1,3 @@ -import gluonnlp from textattack.datasets import TextAttackDataset class NewsTest2013EnglishToGerman(TextAttackDataset): @@ -20,4 +19,4 @@ class NewsTest2013EnglishToGerman(TextAttackDataset): """ DATA_PATH = 'datasets/translation/NewsTest2013EnglishToGerman' def __init__(self, offset=0): - self._load_pickle_file(NewsTest2013EnglishToGerman.DATA_PATH) \ No newline at end of file + self._load_pickle_file(NewsTest2013EnglishToGerman.DATA_PATH, offset=offset) \ No newline at end of file diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index 7ad196e4..9fc44306 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -139,8 +139,8 @@ class GoalFunction: outputs = self._call_model_uncached(uncached_list) for text, output in zip(uncached_list, outputs): self._call_model_cache[text] = output - final_scores = [self._call_model_cache[text] for text in tokenized_text_list] - return final_scores + all_outputs = [self._call_model_cache[text] for text in tokenized_text_list] + return all_outputs def extra_repr_keys(self): return []