Mirror of https://github.com/QData/TextAttack.git
add kuleshov recipe; rename attack_methods --> search_methods
@@ -2,13 +2,13 @@ name = "textattack"

 from . import attack_recipes
 from . import attack_results
-from . import attack_methods
 from . import constraints
 from . import datasets
 from . import goal_functions
 from . import goal_function_results
 from . import loggers
 from . import models
+from . import search_methods
 from . import shared
 from . import tokenizers
 from . import transformations
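
This hunk is the mechanical core of the commit: the `attack_methods` package becomes `search_methods`, so downstream imports change accordingly. A before/after sketch, taken from the recipe hunks below:

    # before this commit:
    # from textattack.attack_methods import GreedyWordSwapWIR
    # after this commit:
    from textattack.search_methods import GreedyWordSwapWIR
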
@@ -2,6 +2,7 @@ from .alzantot_2018 import Alzantot2018
 from .alzantot_2018_adjusted import Alzantot2018Adjusted
 from .deepwordbug_gao_2018 import DeepWordBugGao2018
 from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
+from .kuleshov_2017 import Kuleshov2017
 from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
 from .textfooler_jin_2019 import TextFoolerJin2019
 from .textfooler_jin_2019_adjusted import TextFoolerJin2019Adjusted
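
Recipes in this module are plain functions that take a model and return a configured attack (see `def Alzantot2018(model)` below), so the new recipe should be callable the same way. A hedged usage sketch; `model` and the driver call are assumptions, not shown in this diff:

    from textattack.attack_recipes import Kuleshov2017

    attack = Kuleshov2017(model)                # `model`: an assumed victim model
    # results = attack.attack_dataset(dataset)  # hypothetical driver call; the
                                                # Attack API is not in this diff
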
@@ -30,7 +30,7 @@ def Alzantot2018(model):
     # Maximum words perturbed percentage of 20%
     #
     constraints.append(
-        WordsPerturbed(max_percent=20)
+        WordsPerturbed(max_percent=0.2)
     )
     #
     # Maximum word embedding euclidean distance of 0.5.
@@ -8,9 +8,9 @@

 """

-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations import CompositeTransformation
 from textattack.transformations.black_box import \
     WordSwapNeighboringCharacterSwap, \
@@ -11,12 +11,12 @@
 paper).
 """

-from textattack.attack_methods import BeamSearch
+from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.overlap import WordsPerturbed
 from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.search_methods import BeamSearch
 from textattack.transformations.white_box import GradientBasedWordSwap
-from textattack.goal_functions import UntargetedClassification

 def HotFlipEbrahimi2017(model):
     #
@@ -8,8 +8,8 @@
 """

 from textattack.constraints.overlap import WordsPerturbed
-from textattack.constraints.grammaticality.language_models import Google1BillionWordsLanguageModel
-from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.constraints.grammaticality.language_models import GPT2
+from textattack.constraints.semantics import ThoughtVector
 from textattack.goal_functions import UntargetedClassification
 from textattack.search_methods import GreedyWordSwap
 from textattack.transformations.black_box import WordSwapEmbedding
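
The swapped imports suggest the Kuleshov recipe now pairs a GPT-2 language-model constraint and a ThoughtVector similarity constraint with greedy word swapping. A minimal sketch of how such pieces typically compose, modeled on the other recipes in this commit; every threshold and constructor argument below is illustrative, not the recipe's actual body:

    def Kuleshov2017Sketch(model):
        goal_function = UntargetedClassification(model)
        constraints = [
            WordsPerturbed(max_percent=0.5),    # illustrative threshold
            ThoughtVector(max_mse_dist=0.2),    # illustrative threshold
            GPT2(max_log_prob_diff=2.0),        # illustrative threshold
        ]
        transformation = WordSwapEmbedding()
        # assumed wiring: search methods construct the Attack in this codebase
        return GreedyWordSwap(goal_function, transformation=transformation,
                              constraints=constraints)
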
@@ -12,9 +12,9 @@

 """

-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import NonOverlappingOutput
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding

 def Seq2SickCheng2018BlackBox(model, goal_function='non_overlapping'):
@@ -8,12 +8,12 @@

 """

-from textattack.attack_methods import GreedyWordSwapWIR
+from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding
-from textattack.goal_functions import UntargetedClassification

 def TextFoolerJin2019(model):
     #
@@ -8,12 +8,12 @@

 """

-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder, BERT
 from textattack.constraints.grammaticality import PartOfSpeech, LanguageTool
-from textattack.transformations.black_box import WordSwapEmbedding
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations.black_box import WordSwapEmbedding

 def TextFoolerJin2019Adjusted(model, SE_thresh=0.98, sentence_encoder='bert'):
     #
@@ -34,7 +34,7 @@ class LanguageModelConstraint(Constraint):
             x_adv_prob = self.get_log_prob_at_index(x_adv, i)
             if self.max_log_prob_diff is None:
                 x_prob, x_adv_prob = math.log(p1), math.log(p2)
-            return (x_prob - x_adv_prob).abs() <= self.max_log_prob_diff
+            return abs(x_prob - x_adv_prob) <= self.max_log_prob_diff

     def extra_repr_keys(self):
         return ['max_log_prob_diff']
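
This is a type fix: after the branch above passes through `math.log`, `x_prob` and `x_adv_prob` are plain Python floats, which have no `.abs()` method; the builtin works on floats and torch tensors alike. A quick demonstration:

    import math

    diff = math.log(0.5) - math.log(0.4)
    # diff.abs()  -> AttributeError: 'float' object has no attribute 'abs'
    abs(diff)     # 0.223..., works for floats and tensors
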
@@ -1,3 +1,4 @@
+import math
 from textattack.constraints import Constraint

 class WordsPerturbed(Constraint):
@@ -6,6 +7,8 @@ class WordsPerturbed(Constraint):
     def __init__(self, max_num_words=None, max_percent=None):
         if (max_num_words is None) and (max_percent is None):
             raise ValueError('must set either max perc or max num words')
+        if max_percent and not (0 <= max_percent <= 1):
+            raise ValueError('max perc must be between 0 and 1')
         self.max_num_words = max_num_words
         self.max_percent = max_percent

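
The new check rejects percentage-style values up front, which is exactly what the Alzantot recipe change earlier in this commit (max_percent=20 -> 0.2) accommodates:

    WordsPerturbed(max_percent=0.2)  # OK: a fraction of the text
    WordsPerturbed(max_percent=20)   # ValueError: max perc must be between 0 and 1
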
@@ -16,7 +19,7 @@ class WordsPerturbed(Constraint):
         num_words_diff = len(x_adv.all_words_diff(original_text))
         if self.max_percent:
             min_num_words = min(len(x_adv.words), len(original_text.words))
-            max_words_perturbed = round(min_num_words * (self.max_percent / 100))
+            max_words_perturbed = math.ceil(min_num_words * (self.max_percent))
             max_percent_met = num_words_diff <= max_words_perturbed
         else:
             max_percent_met = True
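
With `max_percent` now a fraction, the budget rounds up instead of to the nearest integer, so short texts keep a nonzero perturbation budget. A worked example for a 7-word text at 20%:

    import math

    math.ceil(7 * 0.2)      # new formula: 2 words may be perturbed
    round(7 * (20 / 100))   # old formula with the old-style input 20: 1
    round(7 * (0.2 / 100))  # old formula fed the new-style input 0.2: 0
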
@@ -17,6 +17,7 @@ class ThoughtVector(Constraint):
     """
     def __init__(self, embedding_type='paragramcf', max_mse_dist=None, min_cos_sim=None):
         self.word_embedding = WordEmbedding(embedding_type)
+        self.embedding_type = embedding_type

         if (max_mse_dist or min_cos_sim) is None:
             raise ValueError('Must set max_mse_dist or min_cos_sim')
@@ -26,9 +27,18 @@ class ThoughtVector(Constraint):

     @functools.lru_cache(maxsize=2**10)
     def _get_thought_vector(self, tokenized_text):
-        return torch.sum([self.word_embedding[word] for word in tokenized_text.words])
+        """ Sums the embeddings of all the words in `tokenized_text` into a
+            "thought vector".
+        """
+        embeddings = []
+        for word in tokenized_text.words:
+            embedding = self.word_embedding[word]
+            if embedding is not None: # out-of-vocab words do not have embeddings
+                embeddings.append(embedding)
+        embeddings = torch.tensor(embeddings)
+        return torch.sum(embeddings, dim=0)

-    def __call__(self, x, x_adv):
+    def __call__(self, x, x_adv, original_text=None):
         """ Returns true if (x, x_adv) are closer than `self.min_cos_sim`
             and `self.max_mse_dist`. """

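
The rewrite skips out-of-vocabulary words, whose lookup now returns None (see the `WordEmbedding.__getitem__` hunk at the end of this commit), and stacks the surviving vectors before summing. A self-contained sketch of the arithmetic with made-up two-dimensional vectors:

    import torch

    embeddings = [[1.0, 2.0], [3.0, 4.0]]       # two in-vocab word vectors
    torch.sum(torch.tensor(embeddings), dim=0)  # tensor([4., 6.])
    # the old one-liner passed a Python list straight to torch.sum, which
    # expects a tensor, and it kept None entries for out-of-vocab words
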
@@ -47,7 +57,7 @@ class ThoughtVector(Constraint):
             return False
         # Check MSE distance.
         if self.max_mse_dist:
-            mse_dist = torch.sum((e1 - e2) ** 2)
+            mse_dist = torch.sum((thought_vector_1 - thought_vector_2) ** 2)
             if mse_dist > self.max_mse_dist:
                 return False
         return True
@@ -10,8 +10,9 @@ class UntargetedClassification(ClassificationGoalFunction):
     below this score. Otherwise, goal is to change the overall predicted
     class.
     """
-    def __init__(self, target_max_score=None):
+    def __init__(self, *args, target_max_score=None, **kwargs):
         self.target_max_score = target_max_score
+        super().__init__(*args, **kwargs)

     def _is_goal_complete(self, model_output, ground_truth_output):
         if self.target_max_score:
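
Forwarding `*args`/`**kwargs` to `super().__init__` keeps `target_max_score` keyword-only while letting positional arguments reach the base class. A hedged call sketch; the base-class signature is inferred from how the recipes construct goal functions, not shown in this diff:

    goal_function = UntargetedClassification(model, target_max_score=0.5)
    # the model now reaches ClassificationGoalFunction.__init__ positionally
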
@@ -76,7 +76,7 @@ class Attack:
         transformations = original_transformations[:]
         for C in self.constraints:
             if len(transformations) == 0: break
-            transformations = C.call_many(text, transformations, original_text)
+            transformations = C.call_many(text, transformations, original_text=original_text)
         # Default to false for all original transformations.
         for original_transformation in original_transformations:
             self.constraints_cache[original_transformation] = False
@@ -216,7 +216,7 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
     return table_html

 def has_letter(word):
-    """ Returns true if `word` contains at least one character in [A-Za-z].
+    """ Returns true if `word` contains at least one character in [A-Za-z]. """
     for c in word:
         if c.isalpha(): return True
     return False
@@ -48,4 +48,13 @@ class WordEmbedding:
         self.cos_sim_mat = {}

     def __getitem__(self, index):
+        """ Gets a word embedding by word or ID.
+
+            If word or ID not found, returns None.
+        """
+        if isinstance(index, str):
+            try:
+                index = self.word2index[index]
+            except KeyError:
+                return None
         return self.embeddings[index]
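
A short usage sketch of the new lookup; the embedding type matches the ThoughtVector default above, and the example words are illustrative:

    embedding = WordEmbedding('paragramcf')
    vector = embedding['hello']      # str key: resolved through word2index
    missing = embedding['qwzxqwzx']  # out-of-vocab -> None instead of KeyError
    assert missing is None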