mirror of https://github.com/QData/TextAttack.git

Commit: add LM constraint to args + formatting
@@ -104,11 +104,14 @@ def FasterGeneticAlgorithmJia2019(model):
     #
     # Language Model
     #
-    constraints.append(LearningToWriteLanguageModel(window_size=6, max_log_prob_diff=5.,
-        compare_against_original=True))
-    # constraints.append(LearningToWriteLanguageModel(window_size=5))
+    constraints.append(
+        LearningToWriteLanguageModel(
+            window_size=6, max_log_prob_diff=5.0, compare_against_original=True
+        )
+    )
+    # constraints.append(LearningToWriteLanguageModel(window_size=5))
     #
     # Goal is untargeted classification
     #
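For orientation (not part of the commit): judging from the parameters, the constraint scores a six-word window of text around each swapped word and rejects swaps whose language-model log-probability falls more than 5.0 below the reference text's. A minimal, hypothetical driver sketch, assuming TextAttack at this commit and a wrapped classifier `model`:

    from textattack.attack_recipes import FasterGeneticAlgorithmJia2019

    # Builds the Jia et al. 2019 genetic-algorithm attack; with this commit,
    # its constraint list includes LearningToWriteLanguageModel.
    attack = FasterGeneticAlgorithmJia2019(model)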
@@ -243,6 +243,7 @@ CONSTRAINT_CLASS_NAMES = {
     "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
     "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
     "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
+    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
     #
     # Overlap constraints
     #
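This registry maps command-line constraint names to fully qualified class paths. A hedged sketch of how such an entry can be resolved at runtime (the `resolve` helper is illustrative, not TextAttack's exact parsing code):

    import importlib

    def resolve(class_path):
        # Split "pkg.module.ClassName" into module path and class name, then import.
        module_name, class_name = class_path.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), class_name)

    cls = resolve(
        "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel"
    )
    constraint = cls(window_size=6)  # note: instantiating downloads the LM weights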
@@ -2,4 +2,4 @@ from .language_model_constraint import LanguageModelConstraint
 
 from .google_language_model import Google1BillionWordsLanguageModel
 from .gpt2 import GPT2
-from .learning_to_write import LearningToWriteLanguageModel
+from .learning_to_write import LearningToWriteLanguageModel
@@ -31,7 +31,7 @@ class LanguageModelConstraint(ABC, Constraint):
     def _check_constraint(self, transformed_text, current_text, original_text=None):
         if self.compare_against_original:
             current_text = original_text
 
         try:
             indices = transformed_text.attack_attrs["newly_modified_indices"]
         except KeyError:
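The rest of the method lies outside this hunk; roughly, subclasses supply `get_log_probs_at_index`, and the constraint rejects a transformation whose language-model score drops too far at any newly modified position. A sketch of the assumed comparison, not a verbatim copy of the method body:

    # For each word the transformation just changed, compare LM scores of the
    # reference text and the transformed text at that position.
    for i in indices:
        probs = self.get_log_probs_at_index((current_text, transformed_text), i)
        ref_prob, new_prob = probs[0], probs[1]
        if ref_prob - new_prob > self.max_log_prob_diff:
            return False
    return True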
@@ -1 +1 @@
-from .learning_to_write import LearningToWriteLanguageModel
+from .learning_to_write import LearningToWriteLanguageModel
@@ -1,7 +1,8 @@
 import torch
 from torch import nn
-from torch.nn.functional import log_softmax
 from torch.autograd import Variable
+from torch.nn.functional import log_softmax
 
 
 class AdaptiveSoftmax(nn.Module):
     def __init__(self, input_size, cutoffs, scale_down=4):
@@ -11,11 +12,11 @@ class AdaptiveSoftmax(nn.Module):
         self.output_size = cutoffs[0] + len(cutoffs) - 1
         self.head = nn.Linear(input_size, self.output_size)
         self.tail = nn.ModuleList()
-        for i in range(len(cutoffs)-1):
-            seq = nn.Sequential(
-                nn.Linear(input_size, input_size // scale_down, False),
-                nn.Linear(input_size // scale_down, cutoffs[i+1] - cutoffs[i], False)
-            )
+        for i in range(len(cutoffs) - 1):
+            seq = nn.Sequential(
+                nn.Linear(input_size, input_size // scale_down, False),
+                nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False),
+            )
             self.tail.append(seq)
 
     def reset(self, init=0.1):
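Concretely, with the cutoffs this commit uses in `load_model` below (`[4200, 35000, 180000, 793471]`), the head predicts the 4200 most frequent words plus one slot per tail cluster, while each tail handles a band of rarer words through a low-rank bottleneck. A worked shape example (illustrative numbers only):

    cutoffs = [4200, 35000, 180000, 793471]
    input_size, scale_down = 2048, 4

    output_size = cutoffs[0] + len(cutoffs) - 1  # 4203 head outputs (4200 words + 3 clusters)
    tail_shapes = [
        (input_size // scale_down, cutoffs[i + 1] - cutoffs[i])
        for i in range(len(cutoffs) - 1)
    ]
    # [(512, 30800), (512, 145000), (512, 613471)]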
@@ -27,14 +28,14 @@ class AdaptiveSoftmax(nn.Module):
     def set_target(self, target):
         self.id = []
         for i in range(len(self.cutoffs) - 1):
-            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1]))
+            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
             if mask.sum() > 0:
                 self.id.append(Variable(mask.float().nonzero().squeeze(1)))
             else:
                 self.id.append(None)
 
     def forward(self, inp):
-        assert(len(inp.size()) == 2)
+        assert len(inp.size()) == 2
         output = [self.head(inp)]
         for i in range(len(self.id)):
             if self.id[i] is not None:
@@ -44,22 +45,25 @@ class AdaptiveSoftmax(nn.Module):
         return output
 
     def log_prob(self, inp):
-        assert(len(inp.size()) == 2)
+        assert len(inp.size()) == 2
         head_out = self.head(inp)
         n = inp.size(0)
         prob = torch.zeros(n, self.cutoffs[-1]).cuda()
-        lsm_head = log_softmax(head_out, dim=head_out.dim()-1)
-        prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data)
+        lsm_head = log_softmax(head_out, dim=head_out.dim() - 1)
+        prob.narrow(1, 0, self.output_size).add_(
+            lsm_head.narrow(1, 0, self.output_size).data
+        )
         for i in range(len(self.tail)):
             pos = self.cutoffs[i]
-            i_size = self.cutoffs[i+1] - pos
+            i_size = self.cutoffs[i + 1] - pos
             buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1)
             buff = buff.expand(n, i_size)
             temp = self.tail[i](inp)
-            lsm_tail = log_softmax(temp, dim=temp.dim()-1)
+            lsm_tail = log_softmax(temp, dim=temp.dim() - 1)
             prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data)
         return prob
 
 
 class AdaptiveLoss(nn.Module):
     def __init__(self, cutoffs):
         super().__init__()
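`log_prob` stitches the pieces back into one full-vocabulary distribution: for a word in tail i, log P(w | h) = log P(cluster_i | h) (from the head) + log P(w | cluster_i, h) (from the tail). A small smoke test under assumed conditions (a CUDA device is required, since `log_prob` allocates its output with `.cuda()`; the import path is inferred from the relative imports above):

    import torch
    from textattack.constraints.grammaticality.language_models.learning_to_write.adaptive_softmax import (
        AdaptiveSoftmax,
    )

    softmax = AdaptiveSoftmax(input_size=2048, cutoffs=[4200, 35000, 180000, 793471]).cuda()
    hidden = torch.randn(8, 2048).cuda()
    log_p = softmax.log_prob(hidden)  # shape (8, 793471): one log-prob per vocab word
    print(log_p.exp().sum(dim=1))     # each row sums to ~1.0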
@@ -74,8 +78,8 @@ class AdaptiveLoss(nn.Module):
     def remap_target(self, target):
         new_target = [target.clone()]
-        for i in range(len(self.cutoffs)-1):
-            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1]))
+        for i in range(len(self.cutoffs) - 1):
+            mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1]))
 
             if mask.sum() > 0:
                 new_target[0][mask] = self.cutoffs[0] + i
@@ -86,12 +90,12 @@ class AdaptiveLoss(nn.Module):
     def forward(self, inp, target):
         n = inp[0].size(0)
         target = self.remap_target(target.data)
         loss = 0
         for i in range(len(inp)):
             if inp[i] is not None:
-                assert(target[i].min() >= 0 and target[i].max() <= inp[i].size(1))
+                assert target[i].min() >= 0 and target[i].max() <= inp[i].size(1)
                 criterion = self.criterions[i]
                 loss += criterion(inp[i], Variable(target[i]))
         loss /= n
         return loss
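For completeness, a hedged sketch of how the two classes pair up during training (illustrative; `self.criterions` is built in code outside these hunks):

    criterion = AdaptiveLoss([4200, 35000, 180000, 793471])
    softmax.set_target(targets)       # targets: 1-D LongTensor of word ids
    outputs = softmax(hidden_states)  # [head_logits, tail_logits_or_None, ...]
    loss = criterion(outputs, targets)
    loss.backward()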
@@ -1,68 +1,81 @@
-import numpy as np
 import os
+
+import numpy as np
 import torch
 import torchfile
 
 from .rnn_model import RNNModel
 
-class QueryHandler():
+
+class QueryHandler:
     def __init__(self, model, word_to_idx, mapto, device):
         self.model = model
         self.word_to_idx = word_to_idx
         self.mapto = mapto
         self.device = device
 
     def query(self, sentences, swapped_words, batch_size=32):
         """ Since we don't filter prefixes for OOV ahead of time, it's possible that
         some of them will have different lengths. When this is the case,
         we can't do RNN prediction in batch.
 
         This method _tries_ to do prediction in batch, and, when it fails,
         just does prediction sequentially and concatenates all of the results.
         """
         try:
             return self.try_query(sentences, swapped_words, batch_size=batch_size)
         except:
             probs = []
             for s, w in zip(sentences, swapped_words):
                 probs.append(self.try_query([s], [w], batch_size=1)[0])
             return probs
 
     def try_query(self, sentences, swapped_words, batch_size=32):
         # TODO use caching
         sentence_length = len(sentences[0])
         if any(len(s) != sentence_length for s in sentences):
-            raise ValueError('Only same length batches are allowed')
+            raise ValueError("Only same length batches are allowed")
 
         log_probs = []
         for start in range(0, len(sentences), batch_size):
-            swapped_words_batch = swapped_words[start:min(len(sentences), start + batch_size)]
-            batch = sentences[start:min(len(sentences), start + batch_size)]
-            raw_idx_list = [[] for i in range(sentence_length+1)]
+            swapped_words_batch = swapped_words[
+                start : min(len(sentences), start + batch_size)
+            ]
+            batch = sentences[start : min(len(sentences), start + batch_size)]
+            raw_idx_list = [[] for i in range(sentence_length + 1)]
             for i, s in enumerate(batch):
                 s = [word for word in s if word in self.word_to_idx]
-                words = ['<S>'] + s
+                words = ["<S>"] + s
                 word_idxs = [self.word_to_idx[w] for w in words]
-                for t in range(sentence_length+1):
+                for t in range(sentence_length + 1):
                     if t < len(word_idxs):
                         raw_idx_list[t].append(word_idxs[t])
             orig_num_idxs = len(raw_idx_list)
             raw_idx_list = [x for x in raw_idx_list if len(x)]
             num_idxs_dropped = orig_num_idxs - len(raw_idx_list)
-            all_raw_idxs = torch.tensor(raw_idx_list, device=self.device,
-                                        dtype=torch.long)
+            all_raw_idxs = torch.tensor(
+                raw_idx_list, device=self.device, dtype=torch.long
+            )
             word_idxs = self.mapto[all_raw_idxs]
             hidden = self.model.init_hidden(len(batch))
-            source = word_idxs[:-1,:]
-            target = word_idxs[1:,:]
+            source = word_idxs[:-1, :]
+            target = word_idxs[1:, :]
             decode, hidden = self.model(source, hidden)
             decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1)
             for i in range(len(batch)):
                 if swapped_words_batch[i] not in self.word_to_idx:
-                    log_probs.append(float('-inf'))
+                    log_probs.append(float("-inf"))
                 else:
-                    log_probs.append(sum([decode[t, i, target[t, i]].item() for t in range(sentence_length-num_idxs_dropped)]))
+                    log_probs.append(
+                        sum(
+                            [
+                                decode[t, i, target[t, i]].item()
+                                for t in range(sentence_length - num_idxs_dropped)
+                            ]
+                        )
+                    )
         return log_probs
 
+
 def util_reverse(item):
     new_item = np.zeros(len(item))
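An illustrative call, assuming a loaded handler `qh` (see `load_model` below); sentences are token lists, and each returned value is the summed log-probability of the sentence with the swapped word in place:

    sents = [["the", "movie", "was", "great"], ["the", "movie", "was", "superb"]]
    swaps = ["great", "superb"]
    log_probs = qh.query(sents, swaps)  # one float per sentence; -inf for OOV swaps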
@@ -70,20 +83,33 @@ def util_reverse(item):
         new_item[val] = idx
     return new_item
 
+
 def load_model(lm_folder_path, device):
-    word_map = torchfile.load(os.path.join(lm_folder_path, 'word_map.th7'))
-    word_map = [w.decode('utf-8') for w in word_map]
+    word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7"))
+    word_map = [w.decode("utf-8") for w in word_map]
     word_to_idx = {w: i for i, w in enumerate(word_map)}
-    word_freq = torchfile.load(os.path.join(os.path.join(lm_folder_path, 'word_freq.th7')))
+    word_freq = torchfile.load(
+        os.path.join(os.path.join(lm_folder_path, "word_freq.th7"))
+    )
     mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device)
 
-    model_file = open(os.path.join(lm_folder_path, 'lm-state-dict.pt'), 'rb')
+    model_file = open(os.path.join(lm_folder_path, "lm-state-dict.pt"), "rb")
 
-    model = RNNModel('GRU', 793471, 256, 2048, 1, [4200, 35000, 180000, 793471], dropout=0.01, proj=True, lm1b=True)
+    model = RNNModel(
+        "GRU",
+        793471,
+        256,
+        2048,
+        1,
+        [4200, 35000, 180000, 793471],
+        dropout=0.01,
+        proj=True,
+        lm1b=True,
+    )
 
     model.load_state_dict(torch.load(model_file))
     model.full = True  # Use real softmax--important!
     model.to(device)
     model.eval()
     model_file.close()
     return QueryHandler(model, word_to_idx, mapto, device)
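A hedged usage sketch (the folder must contain `word_map.th7`, `word_freq.th7`, and `lm-state-dict.pt`; the path is a placeholder):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    qh = load_model("path/to/learning-to-write", device)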
@@ -1,7 +1,9 @@
-import textattack
 import torch
 
-from textattack.constraints.grammaticality.language_models import LanguageModelConstraint
+import textattack
+from textattack.constraints.grammaticality.language_models import (
+    LanguageModelConstraint,
+)
 
 from .language_model_helpers import load_model
@@ -24,10 +26,13 @@ class LearningToWriteLanguageModel(LanguageModelConstraint):
     https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689
     """
 
-    CACHE_PATH = 'constraints/grammaticality/language-models/learning-to-write'
+    CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write"
 
     def __init__(self, window_size=5, **kwargs):
         self.window_size = window_size
-        lm_folder_path = textattack.shared.utils.download_if_needed(LearningToWriteLanguageModel.CACHE_PATH)
+        lm_folder_path = textattack.shared.utils.download_if_needed(
+            LearningToWriteLanguageModel.CACHE_PATH
+        )
         self.query_handler = load_model(lm_folder_path, textattack.shared.utils.device)
         super().__init__(**kwargs)
@@ -39,7 +44,9 @@ class LearningToWriteLanguageModel(LanguageModelConstraint):
         query_words = []
         for attacked_text in text_list:
             word = attacked_text.words[word_index]
-            window_text = attacked_text.text_window_around_index(word_index, self.window_size)
+            window_text = attacked_text.text_window_around_index(
+                word_index, self.window_size
+            )
             query = textattack.shared.utils.words_from_text(window_text)
             queries.append(query)
             query_words.append(word)
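For intuition about the window being scored, a small example of the `AttackedText` helper used above (assuming TextAttack's public API; the output shown is approximate):

    from textattack.shared import AttackedText

    t = AttackedText("the quick brown fox jumps over the lazy dog")
    # Roughly the 6 words centered on index 3 ("fox"):
    print(t.text_window_around_index(3, 6))  # e.g. "quick brown fox jumps over the"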
@@ -1,28 +1,44 @@
-import torch.nn as nn
 from torch.autograd import Variable
+import torch.nn as nn
 
 from .adaptive_softmax import AdaptiveSoftmax
 
 
 class RNNModel(nn.Module):
     """Container module with an encoder, a recurrent module, and a decoder. Based on official pytorch examples"""
 
-    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, cutoffs, proj=False, dropout=0.5, tie_weights=False,
-                 lm1b=False):
+    def __init__(
+        self,
+        rnn_type,
+        ntoken,
+        ninp,
+        nhid,
+        nlayers,
+        cutoffs,
+        proj=False,
+        dropout=0.5,
+        tie_weights=False,
+        lm1b=False,
+    ):
         super(RNNModel, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
 
         self.lm1b = lm1b
 
-        if rnn_type == 'GRU':
+        if rnn_type == "GRU":
             self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
         else:
             try:
-                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
+                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
             except KeyError:
-                raise ValueError( """An invalid option for `--model` was supplied,
-                                 options are ['GRU', 'RNN_TANH' or 'RNN_RELU']""")
-            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
+                raise ValueError(
+                    """An invalid option for `--model` was supplied,
+                                 options are ['GRU', 'RNN_TANH' or 'RNN_RELU']"""
+                )
+            self.rnn = nn.RNN(
+                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
+            )
 
         self.proj = proj
@@ -71,7 +87,7 @@ class RNNModel(nn.Module):
         if self.proj:
             output = self.proj_layer(output)
 
-        output = output.view(output.size(0)*output.size(1), output.size(2))
+        output = output.view(output.size(0) * output.size(1), output.size(2))
 
         if self.full:
             decode = self.softmax.log_prob(output)
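This flatten is what lets `AdaptiveSoftmax.log_prob` (which asserts a 2-D input) decode every timestep of every batch element in one call; `QueryHandler.try_query` above then reshapes the result back with `decode.view(sentence_length - num_idxs_dropped, len(batch), -1)`.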
@@ -82,4 +98,4 @@ class RNNModel(nn.Module):
     def init_hidden(self, bsz):
         weight = next(self.parameters()).data
-        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
+        return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
@@ -75,7 +75,7 @@ class AttackLogManager:
         failed_attacks = 0
         skipped_attacks = 0
         successful_attacks = 0
-        max_words_changed = 0
+        max_words_changed = 0
         for i, result in enumerate(self.results):
             all_num_words[i] = len(result.original_result.attacked_text.words)
             if isinstance(result, FailedAttackResult):