mirror of https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00

Commit: support sentence & word constraints
@@ -9,6 +9,7 @@
 - upload sample models and datasets
 - add logger... we should never call print()
 - make it much quieter when we load pretrained BERT. It's so noisy right now :(
+- try to refer to 'text' not 'sentences' (better terminology)
 """

 import difflib
@@ -0,0 +1 @@
+from .constraint import TextConstraint, WordConstraint
@@ -1,5 +1,22 @@
 """ Abstract classes that represent constraints on text adversarial examples. """
 class Constraint:
-    """ A constraint that evaluates if x_adv meets a certain constraint. """
+    """ A constraint that evaluates if (x,x_adv) meets a certain constraint. """
     def __call__(self, x, x_adv):
         """ Returns True if C(x,x_adv) is true. """
         raise NotImplementedError()
+
+
+class WordConstraint(Constraint):
+    """ A constraint that evaluates if an (original, perturbed) word pair
+        meets a certain constraint. """
+    def __call__(self, x, x_adv):
+        """ Returns True if C(x,x_adv) is true. """
+        raise NotImplementedError()
+
+
+class TextConstraint(Constraint):
+    """ A constraint that evaluates if an (original, perturbed) text input pair
+        meets a certain constraint. """
+    def __call__(self, x, x_adv):
+        """ Returns True if C(x,x_adv) is true. """
+        raise NotImplementedError()
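A minimal sketch of a concrete word-level constraint under this interface; MaxLengthDelta is a hypothetical name, not part of this commit, and it subclasses the WordConstraint added above:

# Hypothetical example: reject replacement words whose length differs
# too much from the original word.
class MaxLengthDelta(WordConstraint):
    """ Accepts an (x, x_adv) word pair if their lengths differ by <= delta. """
    def __init__(self, delta=3):
        self.delta = delta

    def __call__(self, x, x_adv):
        """ Returns True if C(x,x_adv) is true. """
        return abs(len(x) - len(x_adv)) <= self.delta

c = MaxLengthDelta(delta=3)
print(c('good', 'great'))      # True:  |4 - 5| = 1
print(c('good', 'fantastic'))  # False: |4 - 9| = 5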
@@ -1,27 +1,34 @@
-from .models import InferSent
-from .utils import cache_path
+from models import InferSent
+from utils import download_if_needed
+import os  # needed for os.path.join below

 import numpy as np
 import torch

-class UniversalSentenceEncoder(constraint):
+from .constraint import SentenceConstraint
+
+class UniversalSentenceEncoder(SentenceConstraint):
     """ Constraint using cosine similarity between Universal Sentence Encodings
         of x and x_adv.

         Uses InferSent sentence embeddings. """

+    MODEL_PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/infersent-encoder'
+    WORD_EMBEDDING_PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/word_embeddings'

     def __init__(self, threshold=0.8):
         self.threshold = threshold
         self.model = self.get_infersent_model()

     def get_infersent_model(self):
         infersent_version = 2
-        MODEL_PATH = cache_path(f'infersent-encoder/infersent{infersent_version}.pkl')
+        model_path = os.path.join(UniversalSentenceEncoder.MODEL_PATH, f'infersent{infersent_version}.pkl')
+        download_if_needed(model_path)
         params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                         'pool_type': 'max', 'dpout_model': 0.0, 'version': infersent_version}
         infersent = InferSent(params_model)
         infersent.load_state_dict(torch.load(MODEL_PATH))
-        W2V_PATH = 'word_embeddings/fastText/crawl-300d-2M.vec'
+        W2V_PATH = os.path.join(UniversalSentenceEncoder.WORD_EMBEDDING_PATH,
+                                'fastText', 'crawl-300d-2M.vec')
         infersent.set_w2v_path(W2V_PATH)
         infersent.build_vocab_k_words(K=100000)
         return infersent
@@ -36,26 +43,6 @@ class UniversalSentenceEncoder(constraint):

         cos = torch.nn.CosineSimilarity(dim=0)
         return cos(original_embedding, perturbed_embedding)

-    def OLD_DLEETE_AFTER_RECREATING_IN_SEARCH_score(self, text, index_to_replace, candidates, tokenized=False,
-                                                    recover_adv=default_recover_adv):
-        """ Returns cosine similarity between the USE encodings of x and x_adv.
-        """
-        raw_sentences = np.array(text, dtype=object)
-        raw_sentences = np.tile(raw_sentences, (len(candidates), 1))
-        raw_sentences[list(range(len(candidates))), np.tile(index_to_replace, len(candidates))] = np.array(candidates)
-        raw_sentences = list(map(lambda s: recover_adv(s), raw_sentences))
-
-        original_embedding = self.model.encode([recover_adv(text)], tokenize=True)[0]
-        altered_embeddings = self.model.encode(raw_sentences, tokenize=True)
-
-        cos = torch.nn.CosineSimilarity(dim=0)
-        def cos_similarity(embedding):
-            return cos(torch.from_numpy(original_embedding), torch.from_numpy(embedding))
-
-        return list(np.apply_along_axis(cos_similarity, 1, altered_embeddings))

     def __call__(self, x, x_adv):
         return self.score(x, x_adv) >= self.threshold
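The thresholding in __call__ reduces to a single cosine comparison. A self-contained sketch of that check, where random tensors stand in for InferSent's sentence encodings of x and x_adv:

import torch

# Stand-ins for the InferSent encodings of x and x_adv.
torch.manual_seed(0)
original_embedding = torch.rand(4096)
perturbed_embedding = original_embedding + 0.01 * torch.rand(4096)

threshold = 0.8  # mirrors the default in __init__ above
cos = torch.nn.CosineSimilarity(dim=0)
score = cos(original_embedding, perturbed_embedding)

# The same comparison __call__ makes: accept x_adv only if it is
# similar enough to x.
print(score.item(), score.item() >= threshold)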
41  datasets/dataset.py  Normal file
@@ -0,0 +1,41 @@
+""" @TODO
+- support tensorflow_datasets, pytorch dataloader and other built-in datasets
+- batch support
+"""
+class TextAttackDataset:
+    """ A dataset for text attacks.
+
+        Any iterable of (label, text_input) pairs qualifies as
+        a TextAttackDataset.
+    """
+    def __init__(self):
+        """ Loads a full dataset from disk. Typically stores tuples in
+            `self.examples`.
+        """
+        raise NotImplementedError()
+
+    def __iter__(self):
+        return self.examples.__iter__()
+
+    def __next__(self):
+        # Only valid when self.examples is itself an iterator, not a list.
+        return self.examples.__next__()
+
+    def _load_text_file(self, text_file_name, n=None):
+        """ Loads (label, text) pairs from a text file.
+
+            Format must look like:
+
+            1 this is a great little ...
+            0 "i love hot n juicy . ...
+            0 "\""this world needs a ...
+        """
+        examples = []
+        i = 0
+        for raw_line in open(text_file_name, 'r').readlines():
+            tokens = raw_line.strip().split()
+            label = int(tokens[0])
+            text = ' '.join(tokens[1:])
+            examples.append((label, text))
+            i += 1
+            if n and i >= n: break
+        return examples
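A minimal sketch of a concrete dataset under this contract; ToySentiment and its examples are hypothetical, not from this commit, and it reuses the TextAttackDataset base class added above:

# Hypothetical example: any class that fills self.examples with
# (label, text) pairs works as a TextAttackDataset.
class ToySentiment(TextAttackDataset):
    def __init__(self):
        self.examples = [
            (1, 'this is a great little movie'),
            (0, 'i did not enjoy a single minute of it'),
        ]

for label, text in ToySentiment():
    print(label, text)
# 1 this is a great little movie
# 0 i did not enjoy a single minute of it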
@@ -1,5 +1,5 @@
 import utils
-from dataset import TextAttackDataset
+from .dataset import TextAttackDataset

 class YelpSentiment(TextAttackDataset):
     DATA_PATH = '/p/qdata/jm8wx/research/OLD/TextFooler/data/yelp'
40  perturbations/perturbation.py  Normal file
@@ -0,0 +1,40 @@
+from constraints import TextConstraint
+
+class Perturbation:
+    """ Generates perturbations for a given text input.
+
+        An abstract class for perturbing a string of text to produce
+        a potential adversarial example. """
+    def __init__(self, constraints=[]):
+        print('Perturbation init')
+        self.constraints = []
+        if constraints:
+            self.add_constraints(constraints)
+
+    def perturb(self, tokenized_text):
+        """ Returns a list of all possible perturbations for `text`
+            that match provided constraints."""
+        raise NotImplementedError()
+
+    def _filter_perturbations(self, original_text, perturbations):
+        """ Filters a list of perturbations based on internal constraints. """
+        good_perturbations = []
+        for perturbed_text in perturbations:
+            meets_constraints = True
+            for c in self.constraints:
+                meets_constraints = meets_constraints and c(original_text,
+                    perturbed_text)
+                if not meets_constraints: break
+            if meets_constraints: good_perturbations.append(perturbed_text)
+        return good_perturbations
+
+    def add_constraint(self, constraint):
+        """ Add constraint to attack. """
+        self.constraints.append(constraint)
+
+    def add_constraints(self, constraints):
+        """ Add multiple constraints. """
+        for constraint in constraints:
+            if not isinstance(constraint, TextConstraint):
+                raise ValueError('Cannot add constraint of type', type(constraint))
+            self.add_constraint(constraint)
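A hypothetical usage sketch of the constraint filtering above; the class and strings below are illustrative, not from this commit, and assume the TextConstraint and Perturbation classes from the hunks above:

# Hypothetical example: a trivial TextConstraint plugged into
# Perturbation's filtering.
class NoLongerThanOriginal(TextConstraint):
    """ Accepts x_adv only if it is no longer than x, in characters. """
    def __call__(self, x, x_adv):
        return len(x_adv) <= len(x)

p = Perturbation(constraints=[NoLongerThanOriginal()])
kept = p._filter_perturbations(
    'a good movie',
    ['a fine movie', 'an extraordinarily good movie'])
print(kept)  # ['a fine movie'] -- the longer candidate is filtered out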
63  perturbations/word_swap.py  Normal file
@@ -0,0 +1,63 @@
+from .perturbation import Perturbation
+from constraints import TextConstraint, WordConstraint
+
+class WordSwap(Perturbation):
+    """ An abstract class that takes a sentence and perturbs it by replacing
+        some of its words.
+
+        Other classes can achieve this by inheriting from WordSwap and
+        overriding self._get_replacement_words.
+    """
+    # def __init__(self):
+    #     print('WordSwap init')
+    #     super().__init__()
+
+    def add_constraints(self, constraints):
+        """ Add multiple constraints. """
+        for constraint in constraints:
+            if not (isinstance(constraint, TextConstraint) or isinstance(constraint, WordConstraint)):
+                raise ValueError('Cannot add constraint of type', type(constraint))
+            self.add_constraint(constraint)
+
+    def _get_replacement_words(self, word):
+        raise NotImplementedError()
+
+    def _filter_perturbations(self, original_text, perturbations, word_swaps):
+        """ Filters a list of perturbations based on internal constraints. """
+        good_perturbations = []
+        # word_swaps is parallel to perturbations: (ow, nw) is the
+        # (original, replacement) word pair that produced p.
+        for p, (ow, nw) in zip(perturbations, word_swaps):
+            meets_constraints = True
+            for c in self.constraints:
+                if isinstance(c, TextConstraint):
+                    meets_constraints = meets_constraints and c(original_text, p)
+                elif isinstance(c, WordConstraint):
+                    meets_constraints = meets_constraints and c(ow, nw)
+                if not meets_constraints: break
+            if meets_constraints: good_perturbations.append(p)
+        return good_perturbations
+
+    def perturb(self, tokenized_text, indices_to_replace=None):
+        """ Returns a list of all possible perturbations for `text`.
+
+            If indices_to_replace is set, only replaces words at those
+            indices.
+        """
+        words = tokenized_text.words()
+        if not indices_to_replace:
+            indices_to_replace = list(range(len(words)))
+
+        perturbations = []
+        word_swaps = []
+        for i in indices_to_replace:
+            word_to_replace = words[i]
+            replacement_words = self._get_replacement_words(word_to_replace)
+            new_tokenized_texts = []
+            for r in replacement_words:
+                new_tokenized_texts.append(tokenized_text.replace_word_at_index(i, r))
+                word_swaps.append((word_to_replace, r))
+            perturbations.extend(new_tokenized_texts)

+        return self._filter_perturbations(tokenized_text, perturbations, word_swaps)
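A sketch of the intended extension point: a WordSwap subclass only needs _get_replacement_words. Everything below is a stand-in, not from this commit; in particular ToyTokenizedText mimics the tokenized-text API (words(), replace_word_at_index()) that lives elsewhere in the repo:

# Hypothetical example: a toy synonym-table word swap.
SYNONYMS = {'good': ['great', 'fine'], 'movie': ['film']}

class WordSwapSynonym(WordSwap):
    def _get_replacement_words(self, word):
        return SYNONYMS.get(word, [])

class ToyTokenizedText:
    """ Minimal stand-in for the repo's tokenized-text object. """
    def __init__(self, text):
        self.text = text
    def words(self):
        return self.text.split()
    def replace_word_at_index(self, i, new_word):
        words = self.words()
        words[i] = new_word
        return ToyTokenizedText(' '.join(words))
    def __repr__(self):
        return self.text

for t in WordSwapSynonym().perturb(ToyTokenizedText('a good movie')):
    print(t)
# a great movie / a fine movie / a good film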
@@ -2,12 +2,13 @@ import numpy as np
 import os
 import utils

-from perturbation import Perturbation
+from .word_swap import WordSwap

-class WordSwapCounterfit(Perturbation):
+class WordSwapCounterfit(WordSwap):
     PATH = '/p/qdata/jm8wx/research/RobustNLP/AttackGeneration/word_embeddings/'

     def __init__(self, word_embedding_folder='paragram_300_sl999'):
         super().__init__()
         if word_embedding_folder == 'paragram_300_sl999':
             word_embeddings_file = 'paragram_300_sl999.npy'
             word_list_file = 'wordlist.pickle'
@@ -50,24 +51,4 @@ class WordSwapCounterfit(Perturbation):
             return candidate_words
         except KeyError:
             # This word is not in our word embedding database, so return an empty list.
             return []

-    def perturb(self, tokenized_text, indices_to_replace=None):
-        """ Returns a list of all possible perturbations for `text`.
-
-            If indices_to_replace is set, only replaces words at those
-            indices.
-        """
-        words = tokenized_text.words()
-        if not indices_to_replace:
-            indices_to_replace = list(range(len(words)))
-
-        perturbations = []
-        for i in indices_to_replace:
-            word_to_replace = words[i]
-            replacement_words = self._get_replacement_words(word_to_replace)
-            new_tokenized_texts = [tokenized_text.replace_word_at_index(i, r)
-                                   for r in replacement_words]
-            perturbations.extend(new_tokenized_texts)
-
-        return perturbations
-        return []
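The truncated _get_replacement_words above appears to return nearest neighbors from a counter-fitted word-embedding matrix, returning [] on a KeyError as shown. A toy sketch of that kind of cosine nearest-neighbor lookup; all names and vectors below are stand-ins, not from this commit:

import numpy as np

# Toy stand-ins for the embedding matrix and word list.
word_list = {'good': 0, 'great': 1, 'terrible': 2}
embeddings = np.array([[ 1.0, 0.1],   # good
                       [ 0.9, 0.2],   # great
                       [-1.0, 0.0]])  # terrible
index_to_word = {i: w for w, i in word_list.items()}

def candidate_words(word, k=1):
    try:
        v = embeddings[word_list[word]]
    except KeyError:
        # Word not in the embedding database, so return an empty list.
        return []
    # Cosine similarity of `word` against every embedding, best first.
    sims = embeddings @ v / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(v))
    ranked = [i for i in np.argsort(-sims) if i != word_list[word]]
    return [index_to_word[i] for i in ranked[:k]]

print(candidate_words('good'))     # ['great']
print(candidate_words('unknown'))  # []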
@@ -1,6 +1,6 @@
 numpy
 torch==1.3.0
 transformers==2.0.0


 # visdom