diff --git a/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py b/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py index 8af27838..fe09c62d 100644 --- a/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py +++ b/textattack/attack_recipes/faster_genetic_algorithm_jia_2019.py @@ -104,11 +104,14 @@ def FasterGeneticAlgorithmJia2019(model): # # Language Model # - # - # - constraints.append(LearningToWriteLanguageModel(window_size=6, max_log_prob_diff=5., - compare_against_original=True)) - # constraints.append(LearningToWriteLanguageModel(window_size=5)) + # + # + constraints.append( + LearningToWriteLanguageModel( + window_size=6, max_log_prob_diff=5.0, compare_against_original=True + ) + ) + # constraints.append(LearningToWriteLanguageModel(window_size=5)) # # Goal is untargeted classification # diff --git a/textattack/commands/attack/attack_args.py b/textattack/commands/attack/attack_args.py index ae53c279..73329911 100644 --- a/textattack/commands/attack/attack_args.py +++ b/textattack/commands/attack/attack_args.py @@ -243,6 +243,7 @@ CONSTRAINT_CLASS_NAMES = { "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech", "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel", "gpt2": "textattack.constraints.grammaticality.language_models.GPT2", + "gpt2": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel", # # Overlap constraints # diff --git a/textattack/constraints/grammaticality/language_models/__init__.py b/textattack/constraints/grammaticality/language_models/__init__.py index 0eafd144..5423d23c 100644 --- a/textattack/constraints/grammaticality/language_models/__init__.py +++ b/textattack/constraints/grammaticality/language_models/__init__.py @@ -2,4 +2,4 @@ from .language_model_constraint import LanguageModelConstraint from .google_language_model import Google1BillionWordsLanguageModel from .gpt2 import GPT2 -from .learning_to_write import LearningToWriteLanguageModel \ No newline at end of file +from .learning_to_write import LearningToWriteLanguageModel diff --git a/textattack/constraints/grammaticality/language_models/language_model_constraint.py b/textattack/constraints/grammaticality/language_models/language_model_constraint.py index a0b0d463..f1fcd7c3 100644 --- a/textattack/constraints/grammaticality/language_models/language_model_constraint.py +++ b/textattack/constraints/grammaticality/language_models/language_model_constraint.py @@ -31,7 +31,7 @@ class LanguageModelConstraint(ABC, Constraint): def _check_constraint(self, transformed_text, current_text, original_text=None): if self.compare_against_original: current_text = original_text - + try: indices = transformed_text.attack_attrs["newly_modified_indices"] except KeyError: diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/__init__.py b/textattack/constraints/grammaticality/language_models/learning_to_write/__init__.py index ba2e91d7..3bf01ec6 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/__init__.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/__init__.py @@ -1 +1 @@ -from .learning_to_write import LearningToWriteLanguageModel \ No newline at end of file +from .learning_to_write import LearningToWriteLanguageModel diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/adaptive_softmax.py b/textattack/constraints/grammaticality/language_models/learning_to_write/adaptive_softmax.py index 841951ab..8a5e0e51 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/adaptive_softmax.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/adaptive_softmax.py @@ -1,7 +1,8 @@ import torch from torch import nn -from torch.nn.functional import log_softmax from torch.autograd import Variable +from torch.nn.functional import log_softmax + class AdaptiveSoftmax(nn.Module): def __init__(self, input_size, cutoffs, scale_down=4): @@ -11,11 +12,11 @@ class AdaptiveSoftmax(nn.Module): self.output_size = cutoffs[0] + len(cutoffs) - 1 self.head = nn.Linear(input_size, self.output_size) self.tail = nn.ModuleList() - for i in range(len(cutoffs)-1): - seq = nn.Sequential( - nn.Linear(input_size, input_size // scale_down, False), - nn.Linear(input_size // scale_down, cutoffs[i+1] - cutoffs[i], False) - ) + for i in range(len(cutoffs) - 1): + seq = nn.Sequential( + nn.Linear(input_size, input_size // scale_down, False), + nn.Linear(input_size // scale_down, cutoffs[i + 1] - cutoffs[i], False), + ) self.tail.append(seq) def reset(self, init=0.1): @@ -27,14 +28,14 @@ class AdaptiveSoftmax(nn.Module): def set_target(self, target): self.id = [] for i in range(len(self.cutoffs) - 1): - mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1])) + mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1])) if mask.sum() > 0: self.id.append(Variable(mask.float().nonzero().squeeze(1))) else: self.id.append(None) def forward(self, inp): - assert(len(inp.size()) == 2) + assert len(inp.size()) == 2 output = [self.head(inp)] for i in range(len(self.id)): if self.id[i] is not None: @@ -44,22 +45,25 @@ class AdaptiveSoftmax(nn.Module): return output def log_prob(self, inp): - assert(len(inp.size()) == 2) + assert len(inp.size()) == 2 head_out = self.head(inp) n = inp.size(0) prob = torch.zeros(n, self.cutoffs[-1]).cuda() - lsm_head = log_softmax(head_out, dim=head_out.dim()-1) - prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data) + lsm_head = log_softmax(head_out, dim=head_out.dim() - 1) + prob.narrow(1, 0, self.output_size).add_( + lsm_head.narrow(1, 0, self.output_size).data + ) for i in range(len(self.tail)): pos = self.cutoffs[i] - i_size = self.cutoffs[i+1] - pos + i_size = self.cutoffs[i + 1] - pos buff = lsm_head.narrow(1, self.cutoffs[0] + i, 1) buff = buff.expand(n, i_size) temp = self.tail[i](inp) - lsm_tail = log_softmax(temp, dim=temp.dim()-1) + lsm_tail = log_softmax(temp, dim=temp.dim() - 1) prob.narrow(1, pos, i_size).copy_(buff.data).add_(lsm_tail.data) return prob + class AdaptiveLoss(nn.Module): def __init__(self, cutoffs): super().__init__() @@ -74,8 +78,8 @@ class AdaptiveLoss(nn.Module): def remap_target(self, target): new_target = [target.clone()] - for i in range(len(self.cutoffs)-1): - mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i+1])) + for i in range(len(self.cutoffs) - 1): + mask = target.ge(self.cutoffs[i]).mul(target.lt(self.cutoffs[i + 1])) if mask.sum() > 0: new_target[0][mask] = self.cutoffs[0] + i @@ -86,12 +90,12 @@ class AdaptiveLoss(nn.Module): def forward(self, inp, target): n = inp[0].size(0) - target = self.remap_target(target.data) + target = self.remap_target(target.data) loss = 0 for i in range(len(inp)): if inp[i] is not None: - assert(target[i].min() >= 0 and target[i].max() <= inp[i].size(1)) + assert target[i].min() >= 0 and target[i].max() <= inp[i].size(1) criterion = self.criterions[i] loss += criterion(inp[i], Variable(target[i])) loss /= n - return loss \ No newline at end of file + return loss diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py index e381e3d7..1bf1b5f8 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/language_model_helpers.py @@ -1,68 +1,81 @@ -import numpy as np import os + +import numpy as np import torch import torchfile from .rnn_model import RNNModel -class QueryHandler(): - def __init__(self, model, word_to_idx, mapto, device): - self.model = model - self.word_to_idx = word_to_idx - self.mapto = mapto - self.device = device - def query(self, sentences, swapped_words, batch_size=32): - """ Since we don't filter prefixes for OOV ahead of time, it's possible that +class QueryHandler: + def __init__(self, model, word_to_idx, mapto, device): + self.model = model + self.word_to_idx = word_to_idx + self.mapto = mapto + self.device = device + + def query(self, sentences, swapped_words, batch_size=32): + """ Since we don't filter prefixes for OOV ahead of time, it's possible that some of them will have different lengths. When this is the case, we can't do RNN prediction in batch. This method _tries_ to do prediction in batch, and, when it fails, just does prediction sequentially and concatenates all of the results. """ - try: - return self.try_query(sentences, swapped_words, batch_size=batch_size) - except: - probs = [] - for s, w in zip(sentences, swapped_words): - probs.append(self.try_query([s], [w], batch_size=1)[0]) - return probs - - def try_query(self, sentences, swapped_words, batch_size=32): - # TODO use caching - sentence_length = len(sentences[0]) - if any(len(s) != sentence_length for s in sentences): - raise ValueError('Only same length batches are allowed') + try: + return self.try_query(sentences, swapped_words, batch_size=batch_size) + except: + probs = [] + for s, w in zip(sentences, swapped_words): + probs.append(self.try_query([s], [w], batch_size=1)[0]) + return probs + + def try_query(self, sentences, swapped_words, batch_size=32): + # TODO use caching + sentence_length = len(sentences[0]) + if any(len(s) != sentence_length for s in sentences): + raise ValueError("Only same length batches are allowed") + + log_probs = [] + for start in range(0, len(sentences), batch_size): + swapped_words_batch = swapped_words[ + start : min(len(sentences), start + batch_size) + ] + batch = sentences[start : min(len(sentences), start + batch_size)] + raw_idx_list = [[] for i in range(sentence_length + 1)] + for i, s in enumerate(batch): + s = [word for word in s if word in self.word_to_idx] + words = [""] + s + word_idxs = [self.word_to_idx[w] for w in words] + for t in range(sentence_length + 1): + if t < len(word_idxs): + raw_idx_list[t].append(word_idxs[t]) + orig_num_idxs = len(raw_idx_list) + raw_idx_list = [x for x in raw_idx_list if len(x)] + num_idxs_dropped = orig_num_idxs - len(raw_idx_list) + all_raw_idxs = torch.tensor( + raw_idx_list, device=self.device, dtype=torch.long + ) + word_idxs = self.mapto[all_raw_idxs] + hidden = self.model.init_hidden(len(batch)) + source = word_idxs[:-1, :] + target = word_idxs[1:, :] + decode, hidden = self.model(source, hidden) + decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1) + for i in range(len(batch)): + if swapped_words_batch[i] not in self.word_to_idx: + log_probs.append(float("-inf")) + else: + log_probs.append( + sum( + [ + decode[t, i, target[t, i]].item() + for t in range(sentence_length - num_idxs_dropped) + ] + ) + ) + return log_probs - log_probs = [] - for start in range(0, len(sentences), batch_size): - swapped_words_batch = swapped_words[start:min(len(sentences), start + batch_size)] - batch = sentences[start:min(len(sentences), start + batch_size)] - raw_idx_list = [[] for i in range(sentence_length+1)] - for i, s in enumerate(batch): - s = [word for word in s if word in self.word_to_idx] - words = [''] + s - word_idxs = [self.word_to_idx[w] for w in words] - for t in range(sentence_length+1): - if t < len(word_idxs): - raw_idx_list[t].append(word_idxs[t]) - orig_num_idxs = len(raw_idx_list) - raw_idx_list = [x for x in raw_idx_list if len(x)] - num_idxs_dropped = orig_num_idxs - len(raw_idx_list) - all_raw_idxs = torch.tensor(raw_idx_list, device=self.device, - dtype=torch.long) - word_idxs = self.mapto[all_raw_idxs] - hidden = self.model.init_hidden(len(batch)) - source = word_idxs[:-1,:] - target = word_idxs[1:,:] - decode, hidden = self.model(source, hidden) - decode = decode.view(sentence_length - num_idxs_dropped, len(batch), -1) - for i in range(len(batch)): - if swapped_words_batch[i] not in self.word_to_idx: - log_probs.append(float('-inf')) - else: - log_probs.append(sum([decode[t, i, target[t, i]].item() for t in range(sentence_length-num_idxs_dropped)])) - return log_probs def util_reverse(item): new_item = np.zeros(len(item)) @@ -70,20 +83,33 @@ def util_reverse(item): new_item[val] = idx return new_item -def load_model(lm_folder_path, device): - word_map = torchfile.load(os.path.join(lm_folder_path, 'word_map.th7')) - word_map = [w.decode('utf-8') for w in word_map] - word_to_idx = {w: i for i, w in enumerate(word_map)} - word_freq = torchfile.load(os.path.join(os.path.join(lm_folder_path, 'word_freq.th7'))) - mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device) - model_file = open(os.path.join(lm_folder_path, 'lm-state-dict.pt'), 'rb') - - model = RNNModel('GRU', 793471, 256, 2048, 1, [4200, 35000, 180000, 793471], dropout=0.01, proj=True, lm1b=True) - - model.load_state_dict(torch.load(model_file)) - model.full = True # Use real softmax--important! - model.to(device) - model.eval() - model_file.close() - return QueryHandler(model, word_to_idx, mapto, device) \ No newline at end of file +def load_model(lm_folder_path, device): + word_map = torchfile.load(os.path.join(lm_folder_path, "word_map.th7")) + word_map = [w.decode("utf-8") for w in word_map] + word_to_idx = {w: i for i, w in enumerate(word_map)} + word_freq = torchfile.load( + os.path.join(os.path.join(lm_folder_path, "word_freq.th7")) + ) + mapto = torch.from_numpy(util_reverse(np.argsort(-word_freq))).long().to(device) + + model_file = open(os.path.join(lm_folder_path, "lm-state-dict.pt"), "rb") + + model = RNNModel( + "GRU", + 793471, + 256, + 2048, + 1, + [4200, 35000, 180000, 793471], + dropout=0.01, + proj=True, + lm1b=True, + ) + + model.load_state_dict(torch.load(model_file)) + model.full = True # Use real softmax--important! + model.to(device) + model.eval() + model_file.close() + return QueryHandler(model, word_to_idx, mapto, device) diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py b/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py index dd17fa82..05e7ff91 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/learning_to_write.py @@ -1,7 +1,9 @@ -import textattack import torch -from textattack.constraints.grammaticality.language_models import LanguageModelConstraint +import textattack +from textattack.constraints.grammaticality.language_models import ( + LanguageModelConstraint, +) from .language_model_helpers import load_model @@ -24,10 +26,13 @@ class LearningToWriteLanguageModel(LanguageModelConstraint): https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689 """ - CACHE_PATH = 'constraints/grammaticality/language-models/learning-to-write' + CACHE_PATH = "constraints/grammaticality/language-models/learning-to-write" + def __init__(self, window_size=5, **kwargs): self.window_size = window_size - lm_folder_path = textattack.shared.utils.download_if_needed(LearningToWriteLanguageModel.CACHE_PATH) + lm_folder_path = textattack.shared.utils.download_if_needed( + LearningToWriteLanguageModel.CACHE_PATH + ) self.query_handler = load_model(lm_folder_path, textattack.shared.utils.device) super().__init__(**kwargs) @@ -39,7 +44,9 @@ class LearningToWriteLanguageModel(LanguageModelConstraint): query_words = [] for attacked_text in text_list: word = attacked_text.words[word_index] - window_text = attacked_text.text_window_around_index(word_index, self.window_size) + window_text = attacked_text.text_window_around_index( + word_index, self.window_size + ) query = textattack.shared.utils.words_from_text(window_text) queries.append(query) query_words.append(word) diff --git a/textattack/constraints/grammaticality/language_models/learning_to_write/rnn_model.py b/textattack/constraints/grammaticality/language_models/learning_to_write/rnn_model.py index f7fcda2b..31d1af17 100644 --- a/textattack/constraints/grammaticality/language_models/learning_to_write/rnn_model.py +++ b/textattack/constraints/grammaticality/language_models/learning_to_write/rnn_model.py @@ -1,28 +1,44 @@ -import torch.nn as nn from torch.autograd import Variable +import torch.nn as nn from .adaptive_softmax import AdaptiveSoftmax + class RNNModel(nn.Module): """Container module with an encoder, a recurrent module, and a decoder. Based on official pytorch examples""" - def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, cutoffs, proj=False, dropout=0.5, tie_weights=False, - lm1b=False): + def __init__( + self, + rnn_type, + ntoken, + ninp, + nhid, + nlayers, + cutoffs, + proj=False, + dropout=0.5, + tie_weights=False, + lm1b=False, + ): super(RNNModel, self).__init__() self.drop = nn.Dropout(dropout) self.encoder = nn.Embedding(ntoken, ninp) self.lm1b = lm1b - if rnn_type == 'GRU': + if rnn_type == "GRU": self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) else: try: - nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] + nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] except KeyError: - raise ValueError( """An invalid option for `--model` was supplied, - options are ['GRU', 'RNN_TANH' or 'RNN_RELU']""") - self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) + raise ValueError( + """An invalid option for `--model` was supplied, + options are ['GRU', 'RNN_TANH' or 'RNN_RELU']""" + ) + self.rnn = nn.RNN( + ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout + ) self.proj = proj @@ -71,7 +87,7 @@ class RNNModel(nn.Module): if self.proj: output = self.proj_layer(output) - output = output.view(output.size(0)*output.size(1), output.size(2)) + output = output.view(output.size(0) * output.size(1), output.size(2)) if self.full: decode = self.softmax.log_prob(output) @@ -82,4 +98,4 @@ class RNNModel(nn.Module): def init_hidden(self, bsz): weight = next(self.parameters()).data - return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()) \ No newline at end of file + return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()) diff --git a/textattack/loggers/attack_log_manager.py b/textattack/loggers/attack_log_manager.py index 49c05721..eef13188 100644 --- a/textattack/loggers/attack_log_manager.py +++ b/textattack/loggers/attack_log_manager.py @@ -75,7 +75,7 @@ class AttackLogManager: failed_attacks = 0 skipped_attacks = 0 successful_attacks = 0 - max_words_changed = 0 + max_words_changed = 0 for i, result in enumerate(self.results): all_num_words[i] = len(result.original_result.attacked_text.words) if isinstance(result, FailedAttackResult):