mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

add LM constraint to args + formatting

Jack Morris
2020-06-28 11:32:00 -04:00
parent 5f3e2b2961
commit f588200119
10 changed files with 166 additions and 109 deletions

View File

@@ -104,11 +104,14 @@ def FasterGeneticAlgorithmJia2019(model):
    #
    # Language Model
    #
    #
    #
    constraints.append(LearningToWriteLanguageModel(window_size=6, max_log_prob_diff=5.,
        compare_against_original=True))
    # constraints.append(LearningToWriteLanguageModel(window_size=5))
    #
    #
    constraints.append(
        LearningToWriteLanguageModel(
            window_size=6, max_log_prob_diff=5.0, compare_against_original=True
        )
    )
    # constraints.append(LearningToWriteLanguageModel(window_size=5))
    #
    # Goal is untargeted classification
    #
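
For reference outside the recipe, here is a minimal, hedged sketch of instantiating the new constraint on its own with the same settings used above. The import path matches the package __init__ shown later in this diff; the scoring behavior described in the comments is a summary, not the verbatim implementation:

from textattack.constraints.grammaticality.language_models import (
    LearningToWriteLanguageModel,
)

# Same settings as the recipe above: score a 6-word window around each swapped
# word with the "Learning to Write" language model and (presumably) reject
# candidates whose windowed log-probability differs from the original text's by
# more than 5.0. compare_against_original=True means every comparison is made
# against the original text rather than the latest perturbation.
lm_constraint = LearningToWriteLanguageModel(
    window_size=6, max_log_prob_diff=5.0, compare_against_original=True
)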

View File

@@ -243,6 +243,7 @@ CONSTRAINT_CLASS_NAMES = {
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    #
    # Overlap constraints
    #
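
As background, each entry here maps a command-line constraint name to a dotted class path. A rough sketch of how such a lookup can be resolved follows; the constraint_from_name helper is hypothetical, shown only to illustrate the mapping, and is not TextAttack's actual argument parser:

import importlib

CONSTRAINT_CLASS_NAMES = {
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
}

def constraint_from_name(name, **kwargs):
    # Hypothetical helper: split the registry value into module and class,
    # import the module, and construct the constraint with the given kwargs.
    module_path, class_name = CONSTRAINT_CLASS_NAMES[name].rsplit(".", 1)
    constraint_cls = getattr(importlib.import_module(module_path), class_name)
    return constraint_cls(**kwargs)

# e.g. constraint_from_name("learning-to-write", window_size=6)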

View File

@@ -2,4 +2,4 @@ from .language_model_constraint import LanguageModelConstraint
from .google_language_model import Google1BillionWordsLanguageModel
from .gpt2 import GPT2
from .learning_to_write import LearningToWriteLanguageModel
from .learning_to_write import LearningToWriteLanguageModel

View File

@@ -31,7 +31,7 @@ class LanguageModelConstraint(ABC, Constraint):
    def _check_constraint(self, transformed_text, current_text, original_text=None):
        if self.compare_against_original:
            current_text = original_text
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
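
To make the intent of this base-class hunk concrete, here is a hedged, self-contained sketch of how a check along these lines proceeds. The get_log_probs_at_index callback and the comparison direction are assumptions based on names visible elsewhere in this diff, not the verbatim TextAttack implementation:

def passes_lm_constraint(
    get_log_probs_at_index, transformed_text, reference_text, max_log_prob_diff
):
    # For every newly modified word index, query the language model's
    # log-probability at that index for the reference and transformed texts,
    # and fail if the score falls by more than max_log_prob_diff.
    # (The real class may compare an absolute difference instead.)
    indices = transformed_text.attack_attrs.get("newly_modified_indices", set())
    for i in indices:
        ref_prob, new_prob = get_log_probs_at_index(
            (reference_text, transformed_text), i
        )
        if ref_prob - new_prob > max_log_prob_diff:
            return False
    return True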

View File

@@ -1 +1 @@
from .learning_to_write import LearningToWriteLanguageModel
from .learning_to_write import LearningToWriteLanguageModel

View File

@@ -1,7 +1,8 @@
import torch
from torch import nn
from torch.nn.functional import log_softmax
from torch.autograd import Variable
from torch.nn.functional import log_softmax
class AdaptiveSoftmax(nn.Module):
    def __init__(self, input_size, cutoffs, scale_down=4):
@@ -11,11 +12,11 @@ class AdaptiveSoftmax(nn.Module):
        self.output_size = cutoffs[0] + len(cutoffs) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoffs)-1):
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // scale_down, False),
                nn.Linear(input_size // scale_down, cutoffs[i+1] - cutoffs[i], False)
            )
        for i in range(len(cutoffs) - 1):
            seq = nn.Sequential(
                nn.Linear(input_size, input_size // scale_down, False),
                nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False),
            )
            self.tail.append(seq)
    def reset(self, init=0.1):
@@ -27,14 +28,14 @@ class AdaptiveSoftmax(nn.Module):
    def set_target(self, target):
        self.id = []
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1]))
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                self.id.append(Variable(mask.float().nonzero().squeeze(1)))
            else:
                self.id.append(None)
    def forward(self, inp):
        assert(len(inp.size()) == 2)
        assert len(inp.size()) == 2
        output = [self.head(inp)]
        for i in range(len(self.id)):
            if self.id[i] is not None:
@@ -44,22 +45,25 @@ class AdaptiveSoftmax(nn.Module):
        return output
    def log_prob(self, inp):
        assert(len(inp.size()) == 2)
        assert len(inp.size()) == 2
        head_out = self.head(inp)
        n = inp.size(0)
        prob = torch.zeros(n, self.cutoffs[-1]).cuda()
        lsm_head = log_softmax(head_out, dim=head_out.dim()-1)
        prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data)
        lsm_head = log_softmax(head_out, dim=head_out.dim() - 1)
        prob.narrow(1, 0, self.output_size).add_(
            lsm_head.narrow(1, 0, self.output_size).data
        )
        for i in range(len(self.tail)):
            pos = self.cutoffs[i]
            i_size = self.cutoffs[i+1] - pos
            i_size = self.cutoffs[i + 1] - pos
            buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1)
            buff = buff.expand(n, i_size)
            temp = self.tail[i](inp)
            lsm_tail = log_softmax(temp, dim=temp.dim()-1)
            lsm_tail = log_softmax(temp, dim=temp.dim() - 1)
            prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data)
        return prob
class AdaptiveLoss(nn.Module):
    def __init__(self, cutoffs):
        super().__init__()
@@ -74,8 +78,8 @@ class AdaptiveLoss(nn.Module):
    def remap_target(self, target):
        new_target = [target.clone()]
        for i in range(len(self.cutoffs)-1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1]))
        for i in range(len(self.cutoffs) - 1):
            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
            if mask.sum() > 0:
                new_target[0][mask] = self.cutoffs[0] + i
@@ -86,12 +90,12 @@ class AdaptiveLoss(nn.Module):
    def forward(self, inp, target):
        n = inp[0].size(0)
        target = self.remap_target(target.data)
        target = self.remap_target(target.data)
        loss = 0
        for i in range(len(inp)):
            if inp[i] is not None:
                assert(target[i].min() >= 0 and target[i].max() <= inp[i].size(1))
                assert target[i].min() >= 0 and target[i].max() <= inp[i].size(1)
                criterion = self.criterions[i]
                loss += criterion(inp[i], Variable(target[i]))
        loss /= n
        return loss
        return loss
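
Because the adaptive softmax code above is sparsely commented, the following worked example spells out how the cutoffs partition the vocabulary. The cutoff values are the ones load_model() passes in the next file; the rest is plain arithmetic:

# Cutoffs used by load_model() below for the 793,471-word vocabulary.
cutoffs = [4200, 35000, 180000, 793471]

# The head softmax covers the 4,200 most frequent words plus one "cluster"
# logit per tail bucket: output_size = cutoffs[0] + len(cutoffs) - 1.
head_size = cutoffs[0] + len(cutoffs) - 1
assert head_size == 4203

# Tail bucket i covers word ids [cutoffs[i], cutoffs[i+1]). Its probability is
# P(bucket i | context) * P(word | bucket i, context), which is why log_prob()
# adds the head's bucket logit (buff) to the tail bucket's log-softmax.
tail_sizes = [cutoffs[i + 1] - cutoffs[i] for i in range(len(cutoffs) - 1)]
assert tail_sizes == [30800, 145000, 613471]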

View File

@@ -1,68 +1,81 @@
import numpy as np
import os
import numpy as np
import torch
import torchfile
from .rnn_model import RNNModel
class QueryHandler():
    def __init__(self, model, word_to_idx, mapto, device):
        self.model = model
        self.word_to_idx = word_to_idx
        self.mapto = mapto
        self.device = device
    def query(self, sentences, swapped_words, batch_size=32):
        """ Since we don't filter prefixes for OOV ahead of time, it's possible that
class QueryHandler:
    def __init__(self, model, word_to_idx, mapto, device):
        self.model = model
        self.word_to_idx = word_to_idx
        self.mapto = mapto
        self.device = device
    def query(self, sentences, swapped_words, batch_size=32):
        """ Since we don't filter prefixes for OOV ahead of time, it's possible that
        some of them will have different lengths. When this is the case,
        we can't do RNN prediction in batch.
        This method _tries_ to do prediction in batch, and, when it fails,
        just does prediction sequentially and concatenates all of the results.
        """
        try:
            return self.try_query(sentences, swapped_words, batch_size=batch_size)
        except:
            probs = []
            for s, w in zip(sentences, swapped_words):
                probs.append(self.try_query([s], [w], batch_size=1)[0])
            return probs
    def try_query(self, sentences, swapped_words, batch_size=32):
        # TODO use caching
        sentence_length = len(sentences[0])
        if any(len(s) != sentence_length for s in sentences):
            raise ValueError('Only same length batches are allowed')
        try:
            return self.try_query(sentences, swapped_words, batch_size=batch_size)
        except:
            probs = []
            for s, w in zip(sentences, swapped_words):
                probs.append(self.try_query([s], [w], batch_size=1)[0])
            return probs
    def try_query(self, sentences, swapped_words, batch_size=32):
        # TODO use caching
        sentence_length = len(sentences[0])
        if any(len(s) != sentence_length for s in sentences):
            raise ValueError("Only same length batches are allowed")
        log_probs = []
        for start in range(0, len(sentences), batch_size):
            swapped_words_batch = swapped_words[
                start : min(len(sentences), start + batch_size)
            ]
            batch = sentences[start : min(len(sentences), start + batch_size)]
            raw_idx_list = [[] for i in range(sentence_length + 1)]
            for i, s in enumerate(batch):
                s = [word for word in s if word in self.word_to_idx]
                words = ["<S>"] + s
                word_idxs = [self.word_to_idx[w] for w in words]
                for t in range(sentence_length + 1):
                    if t < len(word_idxs):
                        raw_idx_list[t].append(word_idxs[t])
            orig_num_idxs = len(raw_idx_list)
            raw_idx_list = [x for x in raw_idx_list if len(x)]
            num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
            all_raw_idxs = torch.tensor(
                raw_idx_list, device=self.device, dtype=torch.long
            )
            word_idxs = self.mapto[all_raw_idxs]
            hidden = self.model.init_hidden(len(batch))
            source = word_idxs[:-1, :]
            target = word_idxs[1:, :]
            decode, hidden = self.model(source, hidden)
            decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
            for i in range(len(batch)):
                if swapped_words_batch[i] not in self.word_to_idx:
                    log_probs.append(float("-inf"))
                else:
                    log_probs.append(
                        sum(
                            [
                                decode[t, i, target[t, i]].item()
                                for t in range(sentence_length - num_idxs_dropped)
                            ]
                        )
                    )
        return log_probs
        log_probs = []
        for start in range(0, len(sentences), batch_size):
            swapped_words_batch = swapped_words[start:min(len(sentences), start + batch_size)]
            batch = sentences[start:min(len(sentences), start + batch_size)]
            raw_idx_list = [[] for i in range(sentence_length+1)]
            for i, s in enumerate(batch):
                s = [word for word in s if word in self.word_to_idx]
                words = ['<S>'] + s
                word_idxs = [self.word_to_idx[w] for w in words]
                for t in range(sentence_length+1):
                    if t < len(word_idxs):
                        raw_idx_list[t].append(word_idxs[t])
            orig_num_idxs = len(raw_idx_list)
            raw_idx_list = [x for x in raw_idx_list if len(x)]
            num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
            all_raw_idxs = torch.tensor(raw_idx_list, device=self.device,
                                        dtype=torch.long)
            word_idxs = self.mapto[all_raw_idxs]
            hidden = self.model.init_hidden(len(batch))
            source = word_idxs[:-1,:]
            target = word_idxs[1:,:]
            decode, hidden = self.model(source, hidden)
            decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
            for i in range(len(batch)):
                if swapped_words_batch[i] not in self.word_to_idx:
                    log_probs.append(float('-inf'))
                else:
                    log_probs.append(sum([decode[t, i, target[t, i]].item() for t in range(sentence_length-num_idxs_dropped)]))
        return log_probs
def util_reverse(item):
    new_item = np.zeros(len(item))
@@ -70,20 +83,33 @@ def util_reverse(item):
        new_item[val] = idx
    return new_item
def load_model(lm_folder_path, device):
    word_map = torchfile.load(os.path.join(lm_folder_path, 'word_map.th7'))
    word_map = [w.decode('utf-8') for w in word_map]
    word_to_idx = {w: i for i, w in enumerate(word_map)}
    word_freq = torchfile.load(os.path.join(os.path.join(lm_folder_path, 'word_freq.th7')))
    mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)
    model_file = open(os.path.join(lm_folder_path, 'lm-state-dict.pt'), 'rb')
    model = RNNModel('GRU', 793471, 256, 2048, 1, [4200, 35000, 180000, 793471], dropout=0.01, proj=True, lm1b=True)
    model.load_state_dict(torch.load(model_file))
    model.full = True # Use real softmax--important!
    model.to(device)
    model.eval()
    model_file.close()
    return QueryHandler(model, word_to_idx, mapto, device)
def load_model(lm_folder_path, device):
    word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7"))
    word_map = [w.decode("utf-8") for w in word_map]
    word_to_idx = {w: i for i, w in enumerate(word_map)}
    word_freq = torchfile.load(
        os.path.join(os.path.join(lm_folder_path, "word_freq.th7"))
    )
    mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)
    model_file = open(os.path.join(lm_folder_path, "lm-state-dict.pt"), "rb")
    model = RNNModel(
        "GRU",
        793471,
        256,
        2048,
        1,
        [4200, 35000, 180000, 793471],
        dropout=0.01,
        proj=True,
        lm1b=True,
    )
    model.load_state_dict(torch.load(model_file))
    model.full = True  # Use real softmax--important!
    model.to(device)
    model.eval()
    model_file.close()
    return QueryHandler(model, word_to_idx, mapto, device)
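
For orientation, here is a hedged sketch of how this query handler is driven by the constraint in learning_to_write.py: one tokenized window per candidate swap plus the swapped word, returning one log-probability per window. The module import path is inferred from the package layout in this diff, and lm_folder_path is a placeholder for wherever the L2W files live:

import textattack
from textattack.constraints.grammaticality.language_models.learning_to_write.language_model_helpers import (
    load_model,
)

lm_folder_path = "/path/to/learning-to-write"  # placeholder for the downloaded LM folder
query_handler = load_model(lm_folder_path, textattack.shared.utils.device)

# One tokenized window per candidate swap, plus the word that was swapped in.
# query() presumably returns one log-probability per window, or -inf when the
# swapped word is out of vocabulary.
windows = [
    ["the", "movie", "was", "truly", "awful"],
    ["the", "movie", "was", "truly", "dreadful"],
]
swapped = ["awful", "dreadful"]
log_probs = query_handler.query(windows, swapped, batch_size=32)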

View File

@@ -1,7 +1,9 @@
import textattack
import torch
from textattack.constraints.grammaticality.language_models import LanguageModelConstraint
import textattack
from textattack.constraints.grammaticality.language_models import (
LanguageModelConstraint,
)
from .language_model_helpers import load_model
@@ -24,10 +26,13 @@ class LearningToWriteLanguageModel(LanguageModelConstraint):
    https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689
    """
    CACHE_PATH = 'constraints/grammaticality/language-models/learning-to-write'
    CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"
    def __init__(self, window_size=5, **kwargs):
        self.window_size = window_size
        lm_folder_path = textattack.shared.utils.download_if_needed(LearningToWriteLanguageModel.CACHE_PATH)
        lm_folder_path = textattack.shared.utils.download_if_needed(
            LearningToWriteLanguageModel.CACHE_PATH
        )
        self.query_handler = load_model(lm_folder_path, textattack.shared.utils.device)
        super().__init__(**kwargs)
@@ -39,7 +44,9 @@ class LearningToWriteLanguageModel(LanguageModelConstraint):
        query_words = []
        for attacked_text in text_list:
            word = attacked_text.words[word_index]
            window_text = attacked_text.text_window_around_index(word_index, self.window_size)
            window_text = attacked_text.text_window_around_index(
                word_index, self.window_size
            )
            query = textattack.shared.utils.words_from_text(window_text)
            queries.append(query)
            query_words.append(word)
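
As a small illustration of the windowing step above, the snippet below builds the word window around a swapped index and tokenizes it the same way. The example sentence is made up, the AttackedText helper is assumed to be the class these attacked_text objects come from, and the exact window boundaries depend on text_window_around_index, so the commented result is only indicative:

from textattack.shared import AttackedText
from textattack.shared.utils import words_from_text

text = AttackedText("the movie was truly awful and far too long")
# 6-word window around the word at index 4 ("awful").
window_text = text.text_window_around_index(4, 6)
query = words_from_text(window_text)  # e.g. ["was", "truly", "awful", "and", "far", "too"]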

View File

@@ -1,28 +1,44 @@
import torch.nn as nn
from torch.autograd import Variable
import torch.nn as nn
from .adaptive_softmax import AdaptiveSoftmax
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder. Based on official pytorch examples"""
    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, cutoffs, proj=False, dropout=0.5, tie_weights=False,
                 lm1b=False):
    def __init__(
        self,
        rnn_type,
        ntoken,
        ninp,
        nhid,
        nlayers,
        cutoffs,
        proj=False,
        dropout=0.5,
        tie_weights=False,
        lm1b=False,
    ):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.lm1b = lm1b
        if rnn_type == 'GRU':
        if rnn_type == "GRU":
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError( """An invalid option for `--model` was supplied,
                    options are ['GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
                raise ValueError(
                    """An invalid option for `--model` was supplied,
                    options are ['GRU', 'RNN_TANH' or 'RNN_RELU']"""
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.proj = proj
@@ -71,7 +87,7 @@ class RNNModel(nn.Module):
        if self.proj:
            output = self.proj_layer(output)
        output = output.view(output.size(0)*output.size(1), output.size(2))
        output = output.view(output.size(0) * output.size(1), output.size(2))
        if self.full:
            decode = self.softmax.log_prob(output)
@@ -82,4 +98,4 @@ class RNNModel(nn.Module):
    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
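
To make the model's interface concrete, here is a hedged usage sketch mirroring the configuration load_model() passes in language_model_helpers.py. The import path is inferred from the package layout, no pretrained weights are loaded, the toy input ids are made up, and a GPU is assumed because log_prob() above allocates its buffer with .cuda():

import torch

from textattack.constraints.grammaticality.language_models.learning_to_write.rnn_model import (
    RNNModel,
)

model = RNNModel(
    "GRU", 793471, 256, 2048, 1,
    [4200, 35000, 180000, 793471],
    dropout=0.01, proj=True, lm1b=True,
).cuda()
model.full = True  # as in load_model(): return full log-probabilities
model.eval()

# forward() takes (sequence_length, batch) word ids and an initial hidden
# state, and returns per-step scores over the vocabulary plus the new hidden state.
source = torch.zeros(4, 2, dtype=torch.long).cuda()  # toy word ids
hidden = model.init_hidden(2)
decode, hidden = model(source, hidden)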

View File

@@ -75,7 +75,7 @@ class AttackLogManager:
        failed_attacks = 0
        skipped_attacks = 0
        successful_attacks = 0
        max_words_changed = 0
        max_words_changed = 0
        for i, result in enumerate(self.results):
            all_num_words[i] = len(result.original_result.attacked_text.words)
            if isinstance(result, FailedAttackResult):