From 0430055c1c887d79df30157aded87989fdf4458f Mon Sep 17 00:00:00 2001 From: uvafan Date: Wed, 1 Jul 2020 21:23:07 -0400 Subject: [PATCH] make review fixes --- .../classification_goal_function.py | 4 ++++ .../classification/input_reduction.py | 22 ++++++++++++------- .../classification/targeted_classification.py | 7 ++---- .../untargeted_classification.py | 3 --- textattack/loggers/csv_logger.py | 2 +- .../search_methods/greedy_word_swap_wir.py | 2 +- textattack/shared/attack.py | 2 +- 7 files changed, 23 insertions(+), 19 deletions(-) diff --git a/textattack/goal_functions/classification/classification_goal_function.py b/textattack/goal_functions/classification/classification_goal_function.py index 5647f6bb..d40f123a 100644 --- a/textattack/goal_functions/classification/classification_goal_function.py +++ b/textattack/goal_functions/classification/classification_goal_function.py @@ -52,3 +52,7 @@ class ClassificationGoalFunction(GoalFunction): def extra_repr_keys(self): return [] + + def _get_displayed_output(self, raw_output): + return int(raw_output.argmax()) + diff --git a/textattack/goal_functions/classification/input_reduction.py b/textattack/goal_functions/classification/input_reduction.py index 98317aad..3f7a0f9e 100644 --- a/textattack/goal_functions/classification/input_reduction.py +++ b/textattack/goal_functions/classification/input_reduction.py @@ -3,8 +3,12 @@ from .classification_goal_function import ClassificationGoalFunction class InputReduction(ClassificationGoalFunction): """ - An targeted attack on classification models which attempts to maximize the - score of the target label until it is the predicted label. + Attempts to reduce the input down to as few words as possible while maintaining + the same predicted label. + + From Feng, Wallace, Grissom, Iyyer, Rodriguez, Boyd-Graber. (2018). + Pathologies of Neural Models Make Interpretations Difficult. + ArXiv, abs/1804.07781.
""" def __init__(self, *args, target_num_words=1, **kwargs): @@ -21,16 +25,18 @@ class InputReduction(ClassificationGoalFunction): return self.ground_truth_output != model_output.argmax() def _get_score(self, model_output, attacked_text): + # Give the lowest score possible to inputs which don't maintain the ground truth label. if self.ground_truth_output != model_output.argmax(): - return float("-inf") + return 0 + cur_num_words = attacked_text.num_words initial_num_words = self.initial_attacked_text.num_words - num_words_score = (initial_num_words - cur_num_words) / initial_num_words - model_score = model_output[self.ground_truth_output] - return num_words_score + model_score / initial_num_words - def _get_displayed_output(self, raw_output): - return int(raw_output.argmax()) + # The main goal is to reduce the number of words (num_words_score) + # Higher model score for the ground truth label is used as a tiebreaker (model_score) + num_words_score = max((initial_num_words - cur_num_words) / initial_num_words, 0) + model_score = model_output[self.ground_truth_output] + return min(num_words_score + model_score / initial_num_words, 1) def extra_repr_keys(self): return ["target_num_words"] diff --git a/textattack/goal_functions/classification/targeted_classification.py b/textattack/goal_functions/classification/targeted_classification.py index 1dae9c9a..fe223465 100644 --- a/textattack/goal_functions/classification/targeted_classification.py +++ b/textattack/goal_functions/classification/targeted_classification.py @@ -3,8 +3,8 @@ from .classification_goal_function import ClassificationGoalFunction class TargetedClassification(ClassificationGoalFunction): """ - An targeted attack on classification models which attempts to maximize the - score of the target label until it is the predicted label. + A targeted attack on classification models which attempts to maximize the + score of the target label. Complete when the target label is the predicted label.
""" def __init__(self, *args, target_class=0, **kwargs): @@ -24,8 +24,5 @@ class TargetedClassification(ClassificationGoalFunction): else: return model_output[self.target_class] - def _get_displayed_output(self, raw_output): - return int(raw_output.argmax()) - def extra_repr_keys(self): return ["target_class"] diff --git a/textattack/goal_functions/classification/untargeted_classification.py b/textattack/goal_functions/classification/untargeted_classification.py index 1158caed..5ceb9eb1 100644 --- a/textattack/goal_functions/classification/untargeted_classification.py +++ b/textattack/goal_functions/classification/untargeted_classification.py @@ -35,6 +35,3 @@ class UntargetedClassification(ClassificationGoalFunction): return abs(model_output.item() - self.ground_truth_output) else: return 1 - model_output[self.ground_truth_output] - - def _get_displayed_output(self, raw_output): - return int(raw_output.argmax()) diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py index 3863210b..7fc0ecf4 100644 --- a/textattack/loggers/csv_logger.py +++ b/textattack/loggers/csv_logger.py @@ -21,7 +21,7 @@ class CSVLogger(Logger): def log_attack_result(self, result): original_text, perturbed_text = result.diff_color(self.color_method) - result_type = result.__class__.__name__[:-12] + result_type = result.__class__.__name__.replace('AttackResult','') row = { "original_text": original_text, "perturbed_text": perturbed_text, diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py index 92c0178c..a48f8ad4 100644 --- a/textattack/search_methods/greedy_word_swap_wir.py +++ b/textattack/search_methods/greedy_word_swap_wir.py @@ -142,7 +142,7 @@ class GreedyWordSwapWIR(SearchMethod): def check_transformation_compatibility(self, transformation): """ - Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swaps transformations. 
+ Since it ranks words by their importance, GreedyWordSwapWIR is limited to word swap and deletion transformations. """ return transformation_consists_of_word_swaps_and_deletions(transformation) diff --git a/textattack/shared/attack.py b/textattack/shared/attack.py index 0d012da5..ca05cafc 100644 --- a/textattack/shared/attack.py +++ b/textattack/shared/attack.py @@ -178,7 +178,7 @@ class Attack: initial_result: The initial ``GoalFunctionResult`` from which to perturb. Returns: - Either a ``SuccessfulAttackResult``, ``FailedAttackResult``, + A ``SuccessfulAttackResult``, ``FailedAttackResult``, or ``MaximizedAttackResult``. """ final_result = self.search_method(initial_result)