Mirror of https://github.com/QData/TextAttack.git, synced 2021-10-13 00:05:06 +03:00
README; get_logger()->logger
README.md
@@ -47,18 +47,48 @@ The [`examples/`](docs/examples/) folder contains notebooks walking through exam
We also have a command-line interface for running attacks. See help info and list of arguments with `python -m textattack --help`.

#### Sample Attack Commands

*TextFooler on BERT trained on the MR sentiment classification dataset*:

```
python -m textattack --recipe textfooler --model bert-mr --num-examples 100
```

*DeepWordBug on BERT trained on the SNLI entailment dataset*:

```
python -m textattack --model bert-snli --recipe deepwordbug --num-examples 100
```
*Beam search with beam width 4, a word embedding transformation, and an untargeted goal function on an LSTM*:

```
python -m textattack --model lstm-mr --num-examples 20 \
--search-method beam-search:beam_width=4 --transformation word-swap-embedding \
--constraints repeat stopword max-words-perturbed:max_num_words=2 embedding:min_cos_sim=0.8 part-of-speech \
--goal-function untargeted-classification
```

*Non-overlapping output attack using a greedy word swap and WordNet word substitutions on T5 English-to-German translation*:

```
python -m textattack --attack-n --goal-function non-overlapping-output \
--model t5-en2de --num-examples 10 --transformation word-swap-wordnet \
--constraints edit-distance:12 max-words-perturbed:max_percent=0.75 repeat stopword \
--search greedy
```
#### Attack Examples

### Attacks and Papers Implemented ("Attack Recipes")

We include attack recipes which build an attack such that only one command-line argument has to be passed. To run an attack recipe, run `python -m textattack --recipe [recipe_name]`.
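For example, pairing one of the recipes below with one of the built-in models (this particular combination is illustrative, not a command taken from the original README):

```
python -m textattack --recipe deepwordbug --model lstm-mr --num-examples 20
```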
The first are for classification and entailment attacks:
The first are for classification tasks, like sentiment classification and entailment:

- **textfooler**: Greedy attack with word importance ranking (["Is BERT Really Robust?" (Jin et al., 2019)](https://arxiv.org/abs/1907.11932)).
- **alzantot**: Genetic algorithm attack from (["Generating Natural Language Adversarial Examples" (Alzantot et al., 2018)](https://arxiv.org/abs/1804.07998)).
- **deepwordbug**: Replace-1 scoring and multi-transformation character-swap attack (["Black-box Generation of Adversarial Text Sequences to Evade Deep Learning Classifiers" (Gao et al., 2018)](https://arxiv.org/abs/1801.04354)).
- **hotflip**: Beam search and gradient-based word swap (["HotFlip: White-Box Adversarial Examples for Text Classification" (Ebrahimi et al., 2017)](https://arxiv.org/abs/1712.06751)).
- **kuleshov**: Greedy search and counterfitted embedding swap (["Adversarial Examples for Natural Language Classification Problems" (Kuleshov et al., 2018)](https://openreview.net/pdf?id=r1QZ3zbAZ)).

The final is for translation attacks:
The final is for sequence-to-sequence models:

- **seq2sick**: Greedy attack with the goal of changing every word in the output translation. Currently implemented as black-box, with plans to change to white-box as done in the paper (["Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples" (Cheng et al., 2018)](https://arxiv.org/abs/1803.01128)).

### Augmenting Text
@@ -85,23 +115,71 @@ of a string or a list of strings. Here's an example of how to use the `Embedding
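The usage example referenced in this hunk is not included in the excerpt above. As a rough, hypothetical sketch of the kind of usage the README describes (the import path, class name, and `augment` method are assumptions about the `textattack.augmentation` API, not taken from this diff):

```
# Hypothetical sketch -- import path and method name are assumptions.
from textattack.augmentation import EmbeddingAugmenter

augmenter = EmbeddingAugmenter()
# `augment` takes a string and returns a list of augmented variants.
print(augmenter.augment('I enjoyed the movie a lot.'))
```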
### TokenizedText

To allow for word replacement after a sequence has been tokenized, we include a `TokenizedText` object which maintains both a list of tokens and the original text, with punctuation. We use this object in favor of a list of words or just raw text.
### Models and Datasets

TextAttack is model-agnostic! Anything that overrides `__call__`, takes in `TokenizedText`, and correctly formats output works. However, TextAttack provides pre-trained models and samples for the following datasets:
TextAttack is model-agnostic! You can use `TextAttack` to analyze any model that outputs IDs, tensors, or strings.

#### Classification:
#### HuggingFace `transformers` and `nlp`

We now provide built-in support for [`transformers` pretrained models](https://huggingface.co/models) and datasets from the [`nlp` package](https://github.com/huggingface/nlp)! Here's an example of loading and attacking a pre-trained model and dataset:

```
python -m textattack --model_from_huggingface distilbert-base-uncased-finetuned-sst-2-english --dataset_from_nlp glue:sst2 --recipe deepwordbug --num-examples 10
```

You can explore other pre-trained models using the `--model_from_huggingface` argument, or other datasets by changing `--dataset_from_nlp`.
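For reference, these flags roughly correspond to loading the model with `transformers` and the dataset with `nlp` directly in Python. This is a sketch using standard HuggingFace APIs, not TextAttack internals:

```
# Sketch only: standard `transformers`/`nlp` calls, not TextAttack internals.
import transformers
import nlp

name = 'distilbert-base-uncased-finetuned-sst-2-english'
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)

# 'glue:sst2' on the command line names the GLUE SST-2 dataset in `nlp`.
dataset = nlp.load_dataset('glue', 'sst2', split='validation')
print(dataset[0])
```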
#### Loading a model or dataset from a file

You can easily try out an attack on a local model or dataset sample. To attack a pre-trained model, create a short file that loads the model and tokenizer as variables named `model` and `tokenizer`. The `tokenizer` must be able to transform string inputs to lists or tensors of IDs using a method called `encode()`. The model must take inputs via the `__call__` method.
##### Model from a file

For example, you could create the following file and name it `my_model.py`:

```
model = load_model()          # your model; must accept inputs via __call__
tokenizer = load_tokenizer()  # your tokenizer; must provide an encode() method
```

Then, run an attack with the argument `--model_from_file my_model.py`. The model and tokenizer will be loaded automatically.
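As a concrete (hypothetical) version of `my_model.py`, a HuggingFace model satisfies both requirements above: the tokenizer exposes `encode()` and the model is callable. A minimal sketch:

```
# my_model.py -- hypothetical example; any `model`/`tokenizer` pair meeting the
# requirements above would work.
import transformers

name = 'distilbert-base-uncased-finetuned-sst-2-english'

# Callable model: takes tensors of token IDs via __call__.
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)

# Tokenizer with an `encode()` method mapping a string to a list of token IDs.
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
```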
#### Dataset from a file

Loading a dataset from a file is very similar to loading a model from a file. A 'dataset' is any iterable of `(input, output)` pairs. The following example would load a sentiment classification dataset from file `my_dataset.py`:

```
dataset = [('Today was....', 1), ('This movie is...', 0), ...]
```

You can then run attacks on samples from this dataset by adding the argument `--dataset_from_file my_dataset.py`.
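Putting the two together, a command along these lines should attack the locally defined model on the locally defined dataset (the recipe and example count here are illustrative):

```
python -m textattack --model_from_file my_model.py --dataset_from_file my_dataset.py \
--recipe deepwordbug --num-examples 10
```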
#### TextAttack pre-trained models

TextAttack also comes with its own pre-trained models and dataset samples for various tasks:

##### Classification:
* AG News dataset topic classification
* IMDB dataset sentiment classification
* Movie Review dataset sentiment classification
* Yelp dataset sentiment classification

##### Entailment:
* SNLI dataset
* MNLI dataset (matched & unmatched)

##### Translation:
* newstest2013 English to German dataset

### Attacks
@@ -1,3 +1,4 @@
click
editdistance
filelock
language_tool_python

setup.py
@@ -16,7 +16,7 @@ setuptools.setup(
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/QData/textattack",
    packages=setuptools.find_namespace_packages(exclude=['wandb*', 'build*', 'docs*', 'dist*', 'outputs*', 'tests*', 'local_test*']),
    packages=setuptools.find_namespace_packages(exclude=['build*', 'docs*', 'dist*', 'outputs*', 'tests*', 'local_test*', 'wandb*']),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
@@ -8,7 +8,7 @@ The TextAttack main module:

from textattack.shared.scripts.run_attack_args_helper import get_args
from textattack.shared.scripts.attack_args_helper import get_args
from textattack.shared.scripts.run_attack_parallel import run as run_parallel
from textattack.shared.scripts.run_attack_single_threaded import run as run_single_threaded
@@ -14,7 +14,7 @@ class GPT2(LanguageModelConstraint):
    """
    def __init__(self, **kwargs):
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.to(utils.get_device())
        self.model.to(utils.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        super().__init__(**kwargs)

@@ -33,7 +33,7 @@ class GPT2(LanguageModelConstraint):

        token_ids = self.tokenizer.encode(prefix)
        tokens_tensor = torch.tensor([token_ids])
        tokens_tensor = tokens_tensor.to(utils.get_device())
        tokens_tensor = tokens_tensor.to(utils.device)

        with torch.no_grad():
            outputs = self.model(tokens_tensor)
@@ -12,7 +12,7 @@ class BERT(SentenceEncoder):
    def __init__(self, threshold=0.7, metric='cosine', **kwargs):
        super().__init__(threshold=threshold, metric=metric, **kwargs)
        self.model = SentenceTransformer('bert-base-nli-stsb-mean-tokens')
        self.model.to(utils.get_device())
        self.model.to(utils.device)

    def encode(self, sentences):
        return self.model.encode(sentences)

@@ -21,7 +21,7 @@ class InferSent(SentenceEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.get_infersent_model()
        self.model.to(utils.get_device())
        self.model.to(utils.device)

    def get_infersent_model(self):
        """
@@ -72,8 +72,8 @@ class SentenceEncoder(Constraint):
        starting_embedding, transformed_embedding = self.model.encode(
            [starting_text_window, transformed_text_window])

        starting_embedding = torch.tensor(starting_embedding).to(utils.get_device())
        transformed_embedding = torch.tensor(transformed_embedding).to(utils.get_device())
        starting_embedding = torch.tensor(starting_embedding).to(utils.device)
        transformed_embedding = torch.tensor(transformed_embedding).to(utils.device)

        starting_embedding = torch.unsqueeze(starting_embedding, dim=0)
        transformed_embedding = torch.unsqueeze(transformed_embedding, dim=0)

@@ -114,25 +114,25 @@ class SentenceEncoder(Constraint):
                transformed_text.text_window_around_index(modified_index, self.window_size))
            embeddings = self.encode(starting_text_windows + transformed_text_windows)
            starting_embeddings = torch.tensor(
                embeddings[:len(transformed_texts)]).to(utils.get_device())
                embeddings[:len(transformed_texts)]).to(utils.device)
            transformed_embeddings = torch.tensor(
                embeddings[len(transformed_texts):]).to(utils.get_device())
                embeddings[len(transformed_texts):]).to(utils.device)
        else:
            starting_raw_text = starting_text.text
            transformed_raw_texts = [t.text for t in transformed_texts]
            embeddings = self.encode([starting_raw_text] + transformed_raw_texts)
            if isinstance(embeddings[0], torch.Tensor):
                starting_embedding = embeddings[0].to(utils.get_device())
                starting_embedding = embeddings[0].to(utils.device)
            else:
                # If the embedding is not yet a tensor, make it one.
                starting_embedding = torch.tensor(embeddings[0]).to(utils.get_device())
                starting_embedding = torch.tensor(embeddings[0]).to(utils.device)

            if isinstance(embeddings, list):
                # If `encode` did not return a Tensor of all embeddings, combine
                # into a tensor.
                transformed_embeddings = torch.stack(embeddings[1:]).to(utils.get_device())
                transformed_embeddings = torch.stack(embeddings[1:]).to(utils.device)
            else:
                transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.get_device())
                transformed_embeddings = torch.tensor(embeddings[1:]).to(utils.device)

            # Repeat original embedding to size of perturbed embedding.
            starting_embeddings = starting_embedding.unsqueeze(dim=0).repeat(len(transformed_embeddings),1)
@@ -3,21 +3,20 @@ import tensorflow as tf
import tensorflow_hub as hub

from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared.utils import get_device

class UniversalSentenceEncoder(SentenceEncoder):
    """
    Constraint using similarity between sentence encodings of x and x_adv where
    the text embeddings are created using the Universal Sentence Encoder.
    """
    def __init__(self, use_version=4, threshold=0.8, large=True, metric='angular',
    def __init__(self, threshold=0.8, large=True, metric='angular',
            **kwargs):
        if use_version not in [3,4]:
            raise ValueError(f'Unsupported UniversalSentenceEncoder version {use_version}')
        super().__init__(threshold=threshold, metric=metric, **kwargs)
        tfhub_url = 'https://tfhub.dev/google/universal-sentence-encoder{}/{}'.format(
            '-large' if large else '', use_version)
        if large:
            tfhub_url = 'https://tfhub.dev/google/universal-sentence-encoder-large/5'
        else:
            tfhub_url = 'https://tfhub.dev/google/universal-sentence-encoder/4'
        self.model = hub.load(tfhub_url)

    def encode(self, sentences):
        return self.model(sentences)["outputs"].numpy()
        return self.model(sentences).numpy()
@@ -78,8 +78,8 @@ class WordEmbeddingDistance(Constraint):
        except KeyError:
            e1 = self.word_embeddings[a]
            e2 = self.word_embeddings[b]
            e1 = torch.tensor(e1).to(utils.get_device())
            e2 = torch.tensor(e2).to(utils.get_device())
            e1 = torch.tensor(e1).to(utils.device)
            e2 = torch.tensor(e2).to(utils.device)
            cos_sim = torch.nn.CosineSimilarity(dim=0)(e1, e2)
            self.cos_sim_mat[a][b] = cos_sim
        return cos_sim

@@ -92,8 +92,8 @@ class WordEmbeddingDistance(Constraint):
        except KeyError:
            e1 = self.word_embeddings[a]
            e2 = self.word_embeddings[b]
            e1 = torch.tensor(e1).to(utils.get_device())
            e2 = torch.tensor(e2).to(utils.get_device())
            e1 = torch.tensor(e1).to(utils.device)
            e2 = torch.tensor(e2).to(utils.device)
            mse_dist = torch.sum((e1 - e2) ** 2)
            self.mse_dist_mat[a][b] = mse_dist
        return mse_dist
@@ -18,7 +18,7 @@ class BERTForClassification:
        model_file_path = utils.download_if_needed(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_file_path, num_labels=num_labels)
        self.model.to(utils.get_device())
        self.model.to(utils.device)
        self.model.eval()
        if entailment:
            self.tokenizer = BERTEntailmentTokenizer()

@@ -40,7 +40,7 @@ class LSTMForClassification(nn.Module):
        self.load_state_dict(load_cached_state_dict(model_folder_path))
        self.word_embeddings = self.emb_layer.embedding
        self.lookup_table = self.emb_layer.embedding.weight.data
        self.to(utils.get_device())
        self.to(utils.device)
        self.eval()

    def forward(self, _input):
@@ -27,7 +27,7 @@ class T5ForTextToText:
    """
    def __init__(self, mode='english_to_german', max_length=20, num_beams=1, early_stopping=True):
        self.model = AutoModelWithLMHead.from_pretrained("t5-base")
        self.model.to(utils.get_device())
        self.model.to(utils.device)
        self.model.eval()
        self.tokenizer = T5Tokenizer(mode)
        self.max_length = max_length

@@ -6,5 +6,5 @@ from textattack.shared import utils
def load_cached_state_dict(model_folder_path):
    model_folder_path = utils.download_if_needed(model_folder_path)
    model_path = os.path.join(model_folder_path, 'model.bin')
    state_dict = torch.load(model_path, map_location=utils.get_device())
    state_dict = torch.load(model_path, map_location=utils.device)
    return state_dict

@@ -30,7 +30,7 @@ class WordCNNForClassification(nn.Module):

    def load_from_disk(self, model_folder_path):
        self.load_state_dict(load_cached_state_dict(model_folder_path))
        self.to(utils.get_device())
        self.to(utils.device)
        self.eval()

    def forward(self, _input):
@@ -179,8 +179,6 @@ def get_args():
    dataset_group = parser.add_mutually_exclusive_group()
    dataset_group.add_argument('--dataset_from_file', type=str, required=False, default=None,
        help='Dataset to load from a file.')
    # TODO edit model benchmarking script to support models/datasets loaded from command-line
    # TODO add README info about attacking models from file/pretrained huggingface, and dataset

    parser.add_argument('--constraints', type=str, required=False, nargs='*',
        default=['repeat', 'stopword'],

@@ -358,11 +356,13 @@ def parse_model_from_args(args):
            tokenizer = getattr(model_module, 'tokenizer')
        except AttributeError:
            raise AttributeError(f'``tokenizer`` not found in module {args.model_from_file}')
        model = model.to(textattack.shared.utils.device)
        setattr(model, 'tokenizer', tokenizer)
    elif args.model_from_huggingface:
        import transformers
        textattack.shared.logger.info(f'Loading pre-trained model from HuggingFace model repository: {args.model_from_huggingface}')
        model = transformers.AutoModelForSequenceClassification.from_pretrained(args.model_from_huggingface)
        model = model.to(textattack.shared.utils.device)
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_from_huggingface)
        setattr(model, 'tokenizer', tokenizer)
    else:
@@ -3,9 +3,7 @@ import textattack
import torch
import sys

from run_attack_args_helper import *
import textattack.models as models
from attack_args_helper import get_args, parse_model_from_args, parse_dataset_from_args

def _cb(s): return textattack.shared.utils.color_text(str(s), color='blue', method='stdout')
def _cg(s): return textattack.shared.utils.color_text(str(s), color='green', method='stdout')

@@ -15,24 +13,27 @@ def _pb(): print(_cg('-' * 60))
from collections import Counter

def get_num_successes(model, ids, true_labels):
    ids = textattack.shared.utils.preprocess_ids(ids)
    id_dim = torch.tensor(ids).ndim
    if id_dim == 2:
        # For models where the input is a single vector.
        ids = torch.tensor(ids).to(textattack.shared.utils.get_device())
        ids = torch.tensor(ids).to(textattack.shared.utils.device)
        preds = model(ids)
    elif id_dim == 3:
        # For models that take multiple vectors per input.
        ids = map(torch.tensor, zip(*ids))
        ids = (x.to(textattack.shared.utils.get_device()) for x in ids)
        ids = (x.to(textattack.shared.utils.device) for x in ids)
        preds = model(*ids)
    else:
        raise TypeError(f'Error: malformed id_dim ({id_dim})')
    true_labels = torch.tensor(true_labels).to(textattack.shared.utils.get_device())
    true_labels = torch.tensor(true_labels).to(textattack.shared.utils.device)
    if isinstance(preds, tuple):
        preds = preds[0]
    guess_labels = preds.argmax(dim=1)
    successes = (guess_labels == true_labels).sum().item()
    return successes, true_labels, guess_labels

def test_model_on_dataset(model, dataset, batch_size=16, num_examples=100):
def test_model_on_dataset(model, dataset, batch_size=16, num_examples=1000):
    succ = 0
    fail = 0
    batch_ids = []
@@ -66,23 +67,11 @@ def test_model_on_dataset(model, dataset, batch_size=16, num_examples=100):
    print(f'Successes {succ}/{succ+fail} ({_cb(perc)})')
    return perc

def test_all_models(num_examples):
    _pb()
    for model_name in MODEL_CLASS_NAMES:
        model = eval(MODEL_CLASS_NAMES[model_name])()
        dataset = DATASET_BY_MODEL[model_name]()
        print(f'Testing {_cr(model_name)} on {_cr(type(dataset))}...')
        test_model_on_dataset(model, dataset, num_examples=num_examples)
        _pb()
    # @TODO print the grid of models/dataset names with results in a nice table :)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=100,
        help="number of examples to test on")
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    args = get_args()

    model = parse_model_from_args(args)
    dataset = parse_dataset_from_args(args)

    with torch.no_grad():
        test_all_models(args.n)
        test_model_on_dataset(model, dataset, num_examples=args.num_examples)
@@ -8,7 +8,7 @@ import time
import torch
import tqdm

from .run_attack_args_helper import *
from .attack_args_helper import *

logger = textattack.shared.logger

@@ -8,7 +8,7 @@ import tqdm
import os
import datetime

from .run_attack_args_helper import *
from .attack_args_helper import *

logger = textattack.shared.logger
@@ -1,5 +1,5 @@
import torch
from .utils import get_device, words_from_text
from .utils import device, words_from_text

class TokenizedText:

@@ -1,7 +1,6 @@
import torch

def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def html_style_from_dict(style_dict):
    """ Turns
@@ -57,7 +57,7 @@ def validate_model_goal_function_compatibility(goal_function_class, model_class)

    # Otherwise, this is an unknown model -- perhaps user-provided, or we forgot to
    # update the corresponding dictionary. Warn user and return.
    logger.warn(f'Unknown if model of class {model_class} compatible with goal function {goal_function}.')
    logger.warn(f'Unknown if model of class {model_class} compatible with goal function {goal_function_class}.')
    return True

def validate_model_gradient_word_swap_compatibility(model):
@@ -47,7 +47,7 @@ class WordSwapGradientBased(Transformation):
        """
        self.model.train()

        lookup_table = self.model.lookup_table.to(utils.get_device())
        lookup_table = self.model.lookup_table.to(utils.device)
        lookup_table_transpose = lookup_table.transpose(0,1)

        # set backward hook on the word embeddings for input x

@@ -56,12 +56,12 @@ class WordSwapGradientBased(Transformation):
        self.model.zero_grad()
        predictions = self._call_model(text)
        original_label = predictions.argmax()
        y_true = torch.Tensor([original_label]).long().to(utils.get_device())
        y_true = torch.Tensor([original_label]).long().to(utils.device)
        loss = self.loss(predictions, y_true)
        loss.backward()

        # grad w.r.t to word embeddings
        emb_grad = emb_hook.output[0].to(utils.get_device()).squeeze()
        emb_grad = emb_hook.output[0].to(utils.device).squeeze()

        # grad differences between all flips and original word (eq. 1 from paper)
        vocab_size = lookup_table.size(0)

@@ -127,8 +127,8 @@ class Hook:
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.input = [x.to(utils.get_device()) for x in input]
        self.output = [x.to(utils.get_device()) for x in output]
        self.input = [x.to(utils.device) for x in input]
        self.output = [x.to(utils.device) for x in output]

    def close(self):
        self.hook.remove()