mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add kuleshov recipe; rename attack_methods --> search_methods

Jack Morris
2020-05-06 23:19:13 -04:00
parent dd848c7ac1
commit f140541560
16 changed files with 44 additions and 20 deletions
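
A quick sketch of what this commit means for downstream code, based only on the diffs below (the Kuleshov2017 call signature is an assumption, inferred from the sibling recipes, which all take a model):

# Before this commit:
# from textattack.attack_methods import GreedyWordSwapWIR   # old module name

# After this commit:
from textattack.search_methods import GreedyWordSwapWIR     # renamed module
from textattack.attack_recipes import Kuleshov2017          # newly exported recipe

# attack = Kuleshov2017(model)  # hypothetical usage; signature assumed from sibling recipes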

View File

@@ -2,13 +2,13 @@ name = "textattack"
 from . import attack_recipes
 from . import attack_results
-from . import attack_methods
 from . import constraints
 from . import datasets
 from . import goal_functions
 from . import goal_function_results
 from . import loggers
 from . import models
+from . import search_methods
 from . import shared
 from . import tokenizers
 from . import transformations

View File

@@ -2,6 +2,7 @@ from .alzantot_2018 import Alzantot2018
 from .alzantot_2018_adjusted import Alzantot2018Adjusted
 from .deepwordbug_gao_2018 import DeepWordBugGao2018
 from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
+from .kuleshov_2017 import Kuleshov2017
 from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
 from .textfooler_jin_2019 import TextFoolerJin2019
 from .textfooler_jin_2019_adjusted import TextFoolerJin2019Adjusted

View File

@@ -30,7 +30,7 @@ def Alzantot2018(model):
     # Maximum words perturbed percentage of 20%
     #
     constraints.append(
-        WordsPerturbed(max_percent=20)
+        WordsPerturbed(max_percent=0.2)
     )
     #
     # Maximum word embedding euclidean distance of 0.5.

View File

@@ -8,9 +8,9 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations import CompositeTransformation
 from textattack.transformations.black_box import \
     WordSwapNeighboringCharacterSwap, \

View File

@@ -11,12 +11,12 @@
     paper).
 """
-from textattack.attack_methods import BeamSearch
+from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.overlap import WordsPerturbed
 from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.search_methods import BeamSearch
 from textattack.transformations.white_box import GradientBasedWordSwap
-from textattack.goal_functions import UntargetedClassification

 def HotFlipEbrahimi2017(model):
     #

View File

@@ -8,8 +8,8 @@
 """
 from textattack.constraints.overlap import WordsPerturbed
-from textattack.constraints.grammaticality.language_models import Google1BillionWordsLanguageModel
-from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.constraints.grammaticality.language_models import GPT2
+from textattack.constraints.semantics import ThoughtVector
 from textattack.goal_functions import UntargetedClassification
 from textattack.search_methods import GreedyWordSwap
 from textattack.transformations.black_box import WordSwapEmbedding

View File

@@ -12,9 +12,9 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import NonOverlappingOutput
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding

 def Seq2SickCheng2018BlackBox(model, goal_function='non_overlapping'):

View File

@@ -8,12 +8,12 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
+from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding
-from textattack.goal_functions import UntargetedClassification

 def TextFoolerJin2019(model):
     #
# #

View File

@@ -8,12 +8,12 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder, BERT
 from textattack.constraints.grammaticality import PartOfSpeech, LanguageTool
-from textattack.transformations.black_box import WordSwapEmbedding
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations.black_box import WordSwapEmbedding

 def TextFoolerJin2019Adjusted(model, SE_thresh=0.98, sentence_encoder='bert'):
     #

View File

@@ -34,7 +34,7 @@ class LanguageModelConstraint(Constraint):
         x_adv_prob = self.get_log_prob_at_index(x_adv, i)
         if self.max_log_prob_diff is None:
             x_prob, x_adv_prob = math.log(p1), math.log(p2)
-        return (x_prob - x_adv_prob).abs() <= self.max_log_prob_diff
+        return abs(x_prob - x_adv_prob) <= self.max_log_prob_diff

     def extra_repr_keys(self):
         return ['max_log_prob_diff']
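
Why this one-line fix matters, as a minimal sketch: math.log returns a plain Python float, and floats have no .abs() method, so the old expression raised AttributeError in the branch where the log-probs come from math.log; the builtin abs() works on both floats and torch tensors.

import math

x_prob, x_adv_prob = math.log(0.9), math.log(0.7)
# (x_prob - x_adv_prob).abs()     # AttributeError: 'float' object has no attribute 'abs'
diff = abs(x_prob - x_adv_prob)   # works; approximately 0.251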

View File

@@ -1,3 +1,4 @@
+import math
 from textattack.constraints import Constraint

 class WordsPerturbed(Constraint):
@@ -6,6 +7,8 @@ class WordsPerturbed(Constraint):
     def __init__(self, max_num_words=None, max_percent=None):
         if (max_num_words is None) and (max_percent is None):
             raise ValueError('must set either max perc or max num words')
+        if max_percent and not (0 <= max_percent <= 1):
+            raise ValueError('max perc must be between 0 and 1')
         self.max_num_words = max_num_words
         self.max_percent = max_percent
@@ -16,7 +19,7 @@ class WordsPerturbed(Constraint):
         num_words_diff = len(x_adv.all_words_diff(original_text))
         if self.max_percent:
             min_num_words = min(len(x_adv.words), len(original_text.words))
-            max_words_perturbed = round(min_num_words * (self.max_percent / 100))
+            max_words_perturbed = math.ceil(min_num_words * (self.max_percent))
             max_percent_met = num_words_diff <= max_words_perturbed
         else:
             max_percent_met = True
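
The recipe change above (max_percent=20 becoming max_percent=0.2 in the Alzantot2018 recipe) follows from this file: max_percent is now validated as a fraction in [0, 1] rather than a whole-number percentage, and the rounding changed from round() to math.ceil(). A small sketch of the boundary behavior:

import math

min_num_words = 7
old_limit = round(min_num_words * (20 / 100))  # round(1.4) == 1
new_limit = math.ceil(min_num_words * 0.2)     # ceil(1.4)  == 2, slightly more permissive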

View File

@@ -17,6 +17,7 @@ class ThoughtVector(Constraint):
     """
     def __init__(self, embedding_type='paragramcf', max_mse_dist=None, min_cos_sim=None):
         self.word_embedding = WordEmbedding(embedding_type)
+        self.embedding_type = embedding_type
         if (max_mse_dist or min_cos_sim) is None:
             raise ValueError('Must set max_mse_dist or min_cos_sim')
@@ -26,9 +27,18 @@ class ThoughtVector(Constraint):
     @functools.lru_cache(maxsize=2**10)
     def _get_thought_vector(self, tokenized_text):
-        return torch.sum([self.word_embedding[word] for word in tokenized_text.words])
+        """ Sums the embeddings of all the words in `tokenized_text` into a
+            "thought vector".
+        """
+        embeddings = []
+        for word in tokenized_text.words:
+            embedding = self.word_embedding[word]
+            if embedding is not None: # out-of-vocab words do not have embeddings
+                embeddings.append(embedding)
+        embeddings = torch.tensor(embeddings)
+        return torch.sum(embeddings, dim=0)

-    def __call__(self, x, x_adv):
+    def __call__(self, x, x_adv, original_text=None):
         """ Returns true if (x, x_adv) are closer than `self.min_cos_sim`
             and `self.max_mse_dist`. """
@@ -47,7 +57,7 @@ class ThoughtVector(Constraint):
             return False
         # Check MSE distance.
         if self.max_mse_dist:
-            mse_dist = torch.sum((e1 - e2) ** 2)
+            mse_dist = torch.sum((thought_vector_1 - thought_vector_2) ** 2)
             if mse_dist > self.max_mse_dist:
                 return False
         return True
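
Two things changed in _get_thought_vector: torch.sum does not accept a plain Python list, so the old one-liner would raise a TypeError, and the rewrite now skips out-of-vocabulary words, relying on the None-returning WordEmbedding.__getitem__ added in the last file of this commit. A self-contained sketch of the same pattern, with a toy dict standing in for WordEmbedding:

import torch

toy_embedding = {'hello': [1.0, 0.0], 'world': [0.0, 1.0]}  # hypothetical vectors
words = ['hello', 'world', 'xyzzy']                         # 'xyzzy' is out-of-vocab

vectors = [toy_embedding.get(w) for w in words]
vectors = [v for v in vectors if v is not None]             # drop OOV words, as above
thought_vector = torch.sum(torch.tensor(vectors), dim=0)    # tensor([1., 1.])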

View File

@@ -10,8 +10,9 @@ class UntargetedClassification(ClassificationGoalFunction):
         below this score. Otherwise, goal is to change the overall predicted
         class.
     """
-    def __init__(self, target_max_score=None):
+    def __init__(self, *args, target_max_score=None, **kwargs):
         self.target_max_score = target_max_score
+        super().__init__(*args, **kwargs)

     def _is_goal_complete(self, model_output, ground_truth_output):
         if self.target_max_score:
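
The constructor change is the standard cooperative-inheritance pattern: the subclass claims one keyword-only option and forwards everything else to ClassificationGoalFunction, so any parent parameters keep working. The pattern in isolation (Base and Sub are illustrative names, not TextAttack classes):

class Base:
    def __init__(self, model):
        self.model = model

class Sub(Base):
    def __init__(self, *args, target_max_score=None, **kwargs):
        self.target_max_score = target_max_score
        super().__init__(*args, **kwargs)   # parent still receives `model`

s = Sub('some-model', target_max_score=0.5)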

View File

@@ -76,7 +76,7 @@ class Attack:
             transformations = original_transformations[:]
             for C in self.constraints:
                 if len(transformations) == 0: break
-                transformations = C.call_many(text, transformations, original_text)
+                transformations = C.call_many(text, transformations, original_text=original_text)
             # Default to false for all original transformations.
             for original_transformation in original_transformations:
                 self.constraints_cache[original_transformation] = False

View File

@@ -216,7 +216,7 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
     return table_html

 def has_letter(word):
-    """ Returns true if `word` contains at least one character in [A-Za-z].
+    """ Returns true if `word` contains at least one character in [A-Za-z]. """
     for c in word:
         if c.isalpha(): return True
     return False

View File

@@ -48,4 +48,13 @@ class WordEmbedding:
         self.cos_sim_mat = {}

     def __getitem__(self, index):
+        """ Gets a word embedding by word or ID.
+            If word or ID not found, returns None.
+        """
+        if isinstance(index, str):
+            try:
+                index = self.word2index[index]
+            except KeyError:
+                return None
         return self.embeddings[index]
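
A usage sketch of the new lookup behavior (the vocabulary and return values depend on the pretrained embedding; 'hello' and 'xyzzy' here are illustrative):

embedding = WordEmbedding('paragramcf')
vec_by_word = embedding['hello']  # str: mapped through word2index, then indexed
vec_by_id = embedding[42]         # int: indexes self.embeddings directly
missing = embedding['xyzzy']      # None if the word is out-of-vocab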