mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

support sentence & word constraints

This commit is contained in:
Jack Morris
2019-10-13 15:36:47 -04:00
parent 58782d3448
commit 9d3c2c4e40
10 changed files with 184 additions and 53 deletions

View File

@@ -9,6 +9,7 @@
 - upload sample models and datasets
 - add logger... we should never call print()
 - make it much quieter when we load pretrained BERT. It's so noisy right now :(
+- try to refer to 'text' not 'sentences' (better terminology)
 """
 import difflib

View File

@@ -0,0 +1 @@
+from .constraint import TextConstraint, WordConstraint

View File

@@ -1,5 +1,22 @@
""" Abstract classes that represent constraints on text adversarial examples. """
class Constraint:
""" A constraint that evaluates if x_adv meets a certain constraint. """
""" A constraint that evaluates if (x,x_adv) meets a certain constraint. """
def __call__(self, x, x_adv):
""" Returns True if C(x,x_adv) is true. """
raise NotImplementedError()
raise NotImplementedError()
class WordConstraint(Constraint):
""" A constraint that evaluates if an (original, perturbed) word pair
meets a certain constraint. """
def __call__(self, x, x_adv):
""" Returns True if C(x,x_adv) is true. """
raise NotImplementedError()
class TextConstraint(Constraint):
""" A constraint that evaluates if an (original, perturbed) text input pair
meets a certain constraint. """
def __call__(self, x, x_adv):
""" Returns True if C(x,x_adv) is true. """
raise NotImplementedError()
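
Editor's note: a minimal sketch (not part of this commit) of a concrete WordConstraint under the API above; `WordLengthConstraint` and its threshold are hypothetical.

    class WordLengthConstraint(WordConstraint):
        """ Accepts a swap only if the replacement's length is within
            `max_diff` characters of the original word's length. """
        def __init__(self, max_diff=3):
            self.max_diff = max_diff
        def __call__(self, x, x_adv):
            # x and x_adv are an (original, perturbed) word pair.
            return abs(len(x) - len(x_adv)) <= self.max_diff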

View File

@@ -1,27 +1,34 @@
-from .models import InferSent
-from .utils import cache_path
+import os
+
+from models import InferSent
+from utils import download_if_needed
 import numpy as np
 import torch
 
-class UniversalSentenceEncoder(constraint):
+from .constraint import TextConstraint
+
+class UniversalSentenceEncoder(TextConstraint):
     """ Constraint using cosine similarity between Universal Sentence Encodings
         of x and x_adv.
 
         Uses InferSent sentence embeddings. """
+    MODEL_PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/infersent-encoder'
+    WORD_EMBEDDING_PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/word_embeddings'
+
     def __init__(self, threshold=0.8):
         self.threshold = threshold
         self.model = self.get_infersent_model()
 
     def get_infersent_model(self):
         infersent_version = 2
-        MODEL_PATH = cache_path(f'infersent-encoder/infersent{infersent_version}.pkl')
+        model_path = os.path.join(UniversalSentenceEncoder.MODEL_PATH, f'infersent{infersent_version}.pkl')
+        download_if_needed(model_path)
         params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                         'pool_type': 'max', 'dpout_model': 0.0, 'version': infersent_version}
         infersent = InferSent(params_model)
-        infersent.load_state_dict(torch.load(MODEL_PATH))
+        infersent.load_state_dict(torch.load(model_path))
-        W2V_PATH = 'word_embeddings/fastText/crawl-300d-2M.vec'
+        W2V_PATH = os.path.join(UniversalSentenceEncoder.WORD_EMBEDDING_PATH,
+                                'fastText', 'crawl-300d-2M.vec')
         infersent.set_w2v_path(W2V_PATH)
         infersent.build_vocab_k_words(K=100000)
         return infersent
@@ -36,26 +43,6 @@ class UniversalSentenceEncoder(constraint):
         cos = torch.nn.CosineSimilarity(dim=0)
         return cos(original_embedding, perturbed_embedding)
 
-    def OLD_DLEETE_AFTER_RECREATING_IN_SEARCH_score(self, text, index_to_replace, candidates, tokenized=False,
-        recover_adv=default_recover_adv):
-        """ Returns cosine similarity between the USE encodings of x and x_adv.
-        """
-        raw_sentences = np.array(text, dtype=object)
-        raw_sentences = np.tile(raw_sentences, (len(candidates), 1))
-        raw_sentences[list(range(len(candidates))), np.tile(index_to_replace, len(candidates))] = np.array(candidates)
-        raw_sentences = list(map(lambda s: recover_adv(s), raw_sentences))
-        original_embedding = self.model.encode([recover_adv(text)], tokenize=True)[0]
-        altered_embeddings = self.model.encode(raw_sentences, tokenize=True)
-        cos = torch.nn.CosineSimilarity(dim=0)
-        def cos_similarity(embedding):
-            return cos(torch.from_numpy(original_embedding), torch.from_numpy(embedding))
-        return list(np.apply_along_axis(cos_similarity, 1, altered_embeddings))
-
     def __call__(self, x, x_adv):
         return self.score(x, x_adv) >= self.threshold
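
Editor's note: a usage sketch (not part of this commit). It assumes the InferSent model and word-embedding files exist at the hard-coded paths above, and that `x` and `x_adv` are the original and perturbed text inputs.

    use_constraint = UniversalSentenceEncoder(threshold=0.8)
    # True when the cosine similarity of the two InferSent encodings
    # is at least the threshold.
    keep = use_constraint(x, x_adv)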

datasets/dataset.py Normal file
View File

@@ -0,0 +1,41 @@
+""" @TODO
+    - support tensorflow_datasets, pytorch dataloader and other built-in datasets
+    - batch support
+"""
+
+class TextAttackDataset:
+    """ A dataset for text attacks.
+
+        Any iterable of (label, text_input) pairs qualifies as
+        a TextAttackDataset.
+    """
+    def __init__(self):
+        """ Loads a full dataset from disk. Typically stores tuples in
+            `self.examples`.
+        """
+        raise NotImplementedError()
+
+    def __iter__(self):
+        return self.examples.__iter__()
+
+    def __next__(self):
+        return self.examples.__next__()
+
+    def _load_text_file(self, text_file_name, n=None):
+        """ Loads (label, text) pairs from a text file.
+
+            Format must look like:
+                1 this is a great little ...
+                0 "i love hot n juicy . ...
+                0 "\""this world needs a ...
+        """
+        examples = []
+        i = 0
+        for raw_line in open(text_file_name, 'r').readlines():
+            tokens = raw_line.strip().split()
+            label = int(tokens[0])
+            text = ' '.join(tokens[1:])
+            examples.append((label, text))
+            i += 1
+            if n and i >= n: break
+        return examples
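
Editor's note: a minimal subclass sketch (not part of this commit); the file name 'reviews.txt' is hypothetical, and the file must follow the format shown in _load_text_file's docstring.

    class MyTextDataset(TextAttackDataset):
        """ Loads (label, text) pairs from a local text file. """
        def __init__(self, path, n=None):
            self.examples = self._load_text_file(path, n=n)

    for label, text in MyTextDataset('reviews.txt', n=2):
        print(label, text)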

View File

@@ -1,5 +1,5 @@
 import utils
-from dataset import TextAttackDataset
+from .dataset import TextAttackDataset
 
 class YelpSentiment(TextAttackDataset):
     DATA_PATH = '/p/qdata/jm8wx/research/OLD/TextFooler/data/yelp'

View File

@@ -0,0 +1,40 @@
+from constraints import TextConstraint
+
+class Perturbation:
+    """ An abstract class that generates perturbations for a given text
+        input to produce a potential adversarial example. """
+    def __init__(self, constraints=[]):
+        self.constraints = []
+        if constraints:
+            self.add_constraints(constraints)
+
+    def perturb(self, tokenized_text):
+        """ Returns a list of all possible perturbations for `text`
+            that match provided constraints. """
+        raise NotImplementedError()
+
+    def _filter_perturbations(self, original_text, perturbations):
+        """ Filters a list of perturbations based on internal constraints. """
+        good_perturbations = []
+        for perturbed_text in perturbations:
+            meets_constraints = True
+            for c in self.constraints:
+                meets_constraints = meets_constraints and c(original_text,
+                    perturbed_text)
+                if not meets_constraints: break
+            if meets_constraints: good_perturbations.append(perturbed_text)
+        return good_perturbations
+
+    def add_constraint(self, constraint):
+        """ Add constraint to attack. """
+        self.constraints.append(constraint)
+
+    def add_constraints(self, constraints):
+        """ Add multiple constraints. """
+        for constraint in constraints:
+            if not isinstance(constraint, TextConstraint):
+                raise ValueError('Cannot add constraint of type', type(constraint))
+            self.add_constraint(constraint)
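
Editor's note: a sketch (not part of this commit) of a TextConstraint that Perturbation._filter_perturbations could apply; `MaxWordsChanged` is hypothetical and assumes both inputs expose .words() like the TokenizedText used elsewhere in this commit.

    class MaxWordsChanged(TextConstraint):
        """ Accepts x_adv only if at most `max_changed` words differ from x. """
        def __init__(self, max_changed=1):
            self.max_changed = max_changed
        def __call__(self, x, x_adv):
            changed = sum(w1 != w2 for w1, w2 in zip(x.words(), x_adv.words()))
            return changed <= self.max_changed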

View File

@@ -0,0 +1,63 @@
+from .perturbation import Perturbation
+from constraints import TextConstraint, WordConstraint
+
+class WordSwap(Perturbation):
+    """ An abstract class that takes a sentence and perturbs it by replacing
+        some of its words.
+
+        Other classes can achieve this by inheriting from WordSwap and
+        overriding self._get_replacement_words.
+    """
+    def add_constraints(self, constraints):
+        """ Add multiple constraints. """
+        for constraint in constraints:
+            if not (isinstance(constraint, TextConstraint) or isinstance(constraint, WordConstraint)):
+                raise ValueError('Cannot add constraint of type', type(constraint))
+            self.add_constraint(constraint)
+
+    def _get_replacement_words(self, word):
+        raise NotImplementedError()
+
+    def _filter_perturbations(self, original_text, perturbations, word_swaps):
+        """ Filters a list of perturbations based on internal constraints. """
+        good_perturbations = []
+        # word_swaps[i] is the (original_word, new_word) pair behind perturbations[i].
+        for p, (old_word, new_word) in zip(perturbations, word_swaps):
+            meets_constraints = True
+            for c in self.constraints:
+                if isinstance(c, TextConstraint):
+                    meets_constraints = meets_constraints and c(original_text, p)
+                elif isinstance(c, WordConstraint):
+                    meets_constraints = meets_constraints and c(old_word, new_word)
+                if not meets_constraints: break
+            if meets_constraints: good_perturbations.append(p)
+        return good_perturbations
+
+    def perturb(self, tokenized_text, indices_to_replace=None):
+        """ Returns a list of all possible perturbations for `text`.
+            If indices_to_replace is set, only replaces words at those
+            indices.
+        """
+        words = tokenized_text.words()
+        if not indices_to_replace:
+            indices_to_replace = list(range(len(words)))
+        perturbations = []
+        word_swaps = []
+        for i in indices_to_replace:
+            word_to_replace = words[i]
+            replacement_words = self._get_replacement_words(word_to_replace)
+            for r in replacement_words:
+                perturbations.append(tokenized_text.replace_word_at_index(i, r))
+                word_swaps.append((word_to_replace, r))
+        return self._filter_perturbations(tokenized_text, perturbations, word_swaps)
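
Editor's note: a minimal concrete subclass sketch (not part of this commit); `WordSwapFixed` and its replacement table are hypothetical.

    class WordSwapFixed(WordSwap):
        """ Proposes swaps from a fixed replacement dictionary. """
        def __init__(self, replacements, constraints=[]):
            super().__init__(constraints=constraints)
            self.replacements = replacements
        def _get_replacement_words(self, word):
            # Return every known alternative for `word` (possibly none).
            return self.replacements.get(word, [])

    # e.g. WordSwapFixed({'good': ['great', 'fine']}).perturb(tokenized_text)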

View File

@@ -2,12 +2,13 @@ import numpy as np
 import os
 import utils
-from perturbation import Perturbation
+from .word_swap import WordSwap
 
-class WordSwapCounterfit(Perturbation):
+class WordSwapCounterfit(WordSwap):
     PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/word_embeddings/'
     def __init__(self, word_embedding_folder='paragram_300_sl999'):
+        super().__init__()
         if word_embedding_folder == 'paragram_300_sl999':
             word_embeddings_file = 'paragram_300_sl999.npy'
             word_list_file = 'wordlist.pickle'
@@ -50,24 +51,4 @@ class WordSwapCounterfit(Perturbation):
             return candidate_words
         except KeyError:
             # This word is not in our word embedding database, so return an empty list.
             return []
-
-    def perturb(self, tokenized_text, indices_to_replace=None):
-        """ Returns a list of all possible perturbations for `text`.
-            If indices_to_replace is set, only replaces words at those
-            indices.
-        """
-        words = tokenized_text.words()
-        if not indices_to_replace:
-            indices_to_replace = list(range(len(words)))
-        perturbations = []
-        for i in indices_to_replace:
-            word_to_replace = words[i]
-            replacement_words = self._get_replacement_words(word_to_replace)
-            new_tokenized_texts = [tokenized_text.replace_word_at_index(i, r)
-                                   for r in replacement_words]
-            perturbations.extend(new_tokenized_texts)
-        return perturbations
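
Editor's note: an end-to-end sketch (not part of this commit) of how these pieces compose after this change; it assumes the hard-coded data and embedding paths above resolve, and that `tokenized_text` is a TokenizedText-style object.

    perturbation = WordSwapCounterfit()
    perturbation.add_constraints([UniversalSentenceEncoder(threshold=0.8)])
    # perturb() proposes word swaps, then filters them through the
    # text-level (TextConstraint) and word-level (WordConstraint) checks.
    candidates = perturbation.perturb(tokenized_text)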

View File

@@ -1,6 +1,6 @@
 numpy
 torch==1.3.0
 transformers==2.0.0
-# visdom
\ No newline at end of file
+# visdom