mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add kuleshov recipe; rename attack_methods --> search_methods

This commit is contained in:
Jack Morris
2020-05-06 23:19:13 -04:00
parent dd848c7ac1
commit f140541560
16 changed files with 44 additions and 20 deletions
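
For downstream code, the rename is a pure import-path change. A minimal before/after sketch, using one of the classes touched in the diffs below:

    # before this commit
    from textattack.attack_methods import GreedyWordSwapWIR
    # after this commit
    from textattack.search_methods import GreedyWordSwapWIR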

View File

@@ -2,13 +2,13 @@ name = "textattack"
 from . import attack_recipes
 from . import attack_results
-from . import attack_methods
 from . import constraints
 from . import datasets
 from . import goal_functions
 from . import goal_function_results
 from . import loggers
 from . import models
+from . import search_methods
 from . import shared
 from . import tokenizers
 from . import transformations

View File

@@ -2,6 +2,7 @@ from .alzantot_2018 import Alzantot2018
 from .alzantot_2018_adjusted import Alzantot2018Adjusted
 from .deepwordbug_gao_2018 import DeepWordBugGao2018
 from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
+from .kuleshov_2017 import Kuleshov2017
 from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
 from .textfooler_jin_2019 import TextFoolerJin2019
 from .textfooler_jin_2019_adjusted import TextFoolerJin2019Adjusted
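
With Kuleshov2017 exported here, the new recipe can be built like the other recipe functions. A sketch, assuming `model` is a TextAttack-compatible victim model:

    from textattack.attack_recipes import Kuleshov2017
    attack = Kuleshov2017(model)  # recipe functions take the victim model, like Alzantot2018(model)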

View File

@@ -30,7 +30,7 @@ def Alzantot2018(model):
     # Maximum words perturbed percentage of 20%
     #
     constraints.append(
-        WordsPerturbed(max_percent=20)
+        WordsPerturbed(max_percent=0.2)
     )
     #
     # Maximum word embedding euclidean distance of 0.5.
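
Note the semantics change: `max_percent` is now a fraction in [0, 1] rather than a whole-number percentage (the range check added to WordsPerturbed below enforces this), so the same 20% budget is written as:

    WordsPerturbed(max_percent=0.2)  # at most 20% of words perturbed
    # WordsPerturbed(max_percent=20) would now raise ValueError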

View File

@@ -8,9 +8,9 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations import CompositeTransformation
 from textattack.transformations.black_box import \
     WordSwapNeighboringCharacterSwap, \

View File

@@ -11,12 +11,12 @@
     paper).
 """
-from textattack.attack_methods import BeamSearch
-from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.overlap import WordsPerturbed
 from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.search_methods import BeamSearch
 from textattack.transformations.white_box import GradientBasedWordSwap
+from textattack.goal_functions import UntargetedClassification
 
 def HotFlipEbrahimi2017(model):
     #

View File

@@ -8,8 +8,8 @@
 """
 from textattack.constraints.overlap import WordsPerturbed
-from textattack.constraints.grammaticality.language_models import Google1BillionWordsLanguageModel
-from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.constraints.grammaticality.language_models import GPT2
+from textattack.constraints.semantics import ThoughtVector
 from textattack.goal_functions import UntargetedClassification
 from textattack.search_methods import GreedyWordSwap
 from textattack.transformations.black_box import WordSwapEmbedding

View File

@@ -12,9 +12,9 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.overlap import LevenshteinEditDistance
 from textattack.goal_functions import NonOverlappingOutput
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding
 
 def Seq2SickCheng2018BlackBox(model, goal_function='non_overlapping'):

View File

@@ -8,12 +8,12 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
-from textattack.goal_functions import UntargetedClassification
 from textattack.constraints.grammaticality import PartOfSpeech
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
+from textattack.search_methods import GreedyWordSwapWIR
 from textattack.transformations.black_box import WordSwapEmbedding
+from textattack.goal_functions import UntargetedClassification
 
 def TextFoolerJin2019(model):
     #

View File

@@ -8,12 +8,12 @@
 """
-from textattack.attack_methods import GreedyWordSwapWIR
 from textattack.constraints.semantics import WordEmbeddingDistance
 from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder, BERT
 from textattack.constraints.grammaticality import PartOfSpeech, LanguageTool
-from textattack.transformations.black_box import WordSwapEmbedding
 from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations.black_box import WordSwapEmbedding
 
 def TextFoolerJin2019Adjusted(model, SE_thresh=0.98, sentence_encoder='bert'):
     #

View File

@@ -34,7 +34,7 @@ class LanguageModelConstraint(Constraint):
             x_adv_prob = self.get_log_prob_at_index(x_adv, i)
             if self.max_log_prob_diff is None:
                 x_prob, x_adv_prob = math.log(p1), math.log(p2)
-            return (x_prob - x_adv_prob).abs() <= self.max_log_prob_diff
+            return abs(x_prob - x_adv_prob) <= self.max_log_prob_diff
 
     def extra_repr_keys(self):
         return ['max_log_prob_diff']
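
The fix matters because `math.log` returns a plain Python float, and floats have no `.abs()` method; the builtin `abs()` works for both floats and tensors:

    import math
    x_prob, x_adv_prob = math.log(0.9), math.log(0.7)
    abs(x_prob - x_adv_prob)        # 0.2513... -- fine for plain floats
    # (x_prob - x_adv_prob).abs()  -- AttributeError on floats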

View File

@@ -1,3 +1,4 @@
+import math
 from textattack.constraints import Constraint
 
 class WordsPerturbed(Constraint):
@@ -6,6 +7,8 @@ class WordsPerturbed(Constraint):
     def __init__(self, max_num_words=None, max_percent=None):
         if (max_num_words is None) and (max_percent is None):
             raise ValueError('must set either max perc or max num words')
+        if max_percent and not (0 <= max_percent <= 1):
+            raise ValueError('max perc must be between 0 and 1')
         self.max_num_words = max_num_words
         self.max_percent = max_percent
@@ -16,7 +19,7 @@ class WordsPerturbed(Constraint):
         num_words_diff = len(x_adv.all_words_diff(original_text))
         if self.max_percent:
             min_num_words = min(len(x_adv.words), len(original_text.words))
-            max_words_perturbed = round(min_num_words * (self.max_percent / 100))
+            max_words_perturbed = math.ceil(min_num_words * (self.max_percent))
             max_percent_met = num_words_diff <= max_words_perturbed
         else:
             max_percent_met = True
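
Two behavioral changes combine in the last hunk: the stray `/ 100` is gone (since `max_percent` is now a fraction), and `round` becomes `math.ceil`, which always allows at least one perturbed word on short texts. A worked example with a 7-word text and a 20% budget:

    import math
    min_num_words, max_percent = 7, 0.2
    round(min_num_words * max_percent)      # 1 -- old rounding of 1.4
    math.ceil(min_num_words * max_percent)  # 2 -- new behavior rounds the budget up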

View File

@@ -17,6 +17,7 @@ class ThoughtVector(Constraint):
     """
     def __init__(self, embedding_type='paragramcf', max_mse_dist=None, min_cos_sim=None):
         self.word_embedding = WordEmbedding(embedding_type)
+        self.embedding_type = embedding_type
         if (max_mse_dist or min_cos_sim) is None:
             raise ValueError('Must set max_mse_dist or min_cos_sim')
@@ -26,9 +27,18 @@ class ThoughtVector(Constraint):
     @functools.lru_cache(maxsize=2**10)
     def _get_thought_vector(self, tokenized_text):
-        return torch.sum([self.word_embedding[word] for word in tokenized_text.words])
+        """ Sums the embeddings of all the words in `tokenized_text` into a
+            "thought vector".
+        """
+        embeddings = []
+        for word in tokenized_text.words:
+            embedding = self.word_embedding[word]
+            if embedding is not None: # out-of-vocab words do not have embeddings
+                embeddings.append(embedding)
+        embeddings = torch.tensor(embeddings)
+        return torch.sum(embeddings, dim=0)
 
-    def __call__(self, x, x_adv):
+    def __call__(self, x, x_adv, original_text=None):
         """ Returns true if (x, x_adv) are closer than `self.min_cos_sim`
             and `self.max_mse_dist`. """
@@ -47,7 +57,7 @@ class ThoughtVector(Constraint):
             return False
         # Check MSE distance.
         if self.max_mse_dist:
-            mse_dist = torch.sum((e1 - e2) ** 2)
+            mse_dist = torch.sum((thought_vector_1 - thought_vector_2) ** 2)
             if mse_dist > self.max_mse_dist:
                 return False
         return True
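
A self-contained sketch of the "thought vector" computation (hypothetical helper name), assuming the embedding lookup returns None for out-of-vocabulary words as WordEmbedding.__getitem__ does below:

    import torch

    def sum_thought_vector(words, lookup):
        # Skip out-of-vocab words, then sum the remaining embeddings.
        vecs = [lookup(w) for w in words]
        vecs = [torch.as_tensor(v) for v in vecs if v is not None]
        return torch.stack(vecs).sum(dim=0) if vecs else None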

View File

@@ -10,8 +10,9 @@ class UntargetedClassification(ClassificationGoalFunction):
         below this score. Otherwise, goal is to change the overall predicted
         class.
     """
-    def __init__(self, target_max_score=None):
+    def __init__(self, *args, target_max_score=None, **kwargs):
         self.target_max_score = target_max_score
+        super().__init__(*args, **kwargs)
 
     def _is_goal_complete(self, model_output, ground_truth_output):
         if self.target_max_score:
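
Forwarding `*args` and `**kwargs` to the parent constructor lets callers pass ClassificationGoalFunction's own arguments (the victim model, presumably) through the subclass while keeping `target_max_score` keyword-only. A hedged usage sketch:

    goal_function = UntargetedClassification(model, target_max_score=0.9)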

View File

@@ -76,7 +76,7 @@ class Attack:
         transformations = original_transformations[:]
         for C in self.constraints:
             if len(transformations) == 0: break
-            transformations = C.call_many(text, transformations, original_text)
+            transformations = C.call_many(text, transformations, original_text=original_text)
 
         # Default to false for all original transformations.
         for original_transformation in original_transformations:
             self.constraints_cache[original_transformation] = False
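
Passing `original_text` by keyword matches the widened constraint signatures in this commit (e.g. `ThoughtVector.__call__(self, x, x_adv, original_text=None)` above), so constraints that accept it as an optional keyword argument receive it correctly.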

View File

@@ -216,7 +216,7 @@ def html_table_from_rows(rows, title=None, header=None, style_dict=None):
     return table_html
 
 def has_letter(word):
-    """ Returns true if `word` contains at least one character in [A-Za-z].
-    """
+    """ Returns true if `word` contains at least one character in [A-Za-z]. """
     for c in word:
         if c.isalpha(): return True
     return False

View File

@@ -48,4 +48,13 @@ class WordEmbedding:
         self.cos_sim_mat = {}
 
     def __getitem__(self, index):
         """ Gets a word embedding by word or ID.
+
+            If word or ID not found, returns None.
+        """
+        if isinstance(index, str):
+            try:
+                index = self.word2index[index]
+            except KeyError:
+                return None
+        return self.embeddings[index]
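
A hedged usage sketch of the new lookup, assuming a loaded WordEmbedding instance `emb` and a made-up out-of-vocabulary string:

    vec = emb['cat']             # embedding vector for an in-vocabulary word
    assert emb['zzqxv'] is None  # out-of-vocab words return None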