mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge attack_str with master
This commit is contained in:
@@ -75,6 +75,7 @@ def run(args):
|
||||
# Log results asynchronously and update progress bar.
|
||||
num_results = 0
|
||||
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
|
||||
num_successes = 0
|
||||
while num_results < args.num_examples:
|
||||
result = out_queue.get(block=True)
|
||||
if isinstance(result, Exception):
|
||||
@@ -83,6 +84,9 @@ def run(args):
|
||||
if (not args.attack_n) or (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
pbar.update()
|
||||
num_results += 1
|
||||
if (not isinstance(result, textattack.attack_results.FailedAttackResult)) and (not isinstance(result, textattack.attack_results.SkippedAttackResult)):
|
||||
num_successes += 1
|
||||
pbar.set_description('Successes: {} / {}'.format(num_successes, num_results))
|
||||
else:
|
||||
label, text = next(dataset)
|
||||
in_queue.put((label, text))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import lru
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
@@ -37,6 +38,7 @@ class Attack:
|
||||
self.transformation = transformation
|
||||
self.constraints = constraints
|
||||
self.is_black_box = is_black_box
|
||||
self.constraints_cache = lru.LRU(2**18)
|
||||
|
||||
def get_transformations(self, text, original_text=None,
|
||||
apply_constraints=True, **kwargs):
|
||||
@@ -61,7 +63,7 @@ class Attack:
|
||||
return self._filter_transformations(transformations, text, original_text)
|
||||
return transformations
|
||||
|
||||
def _filter_transformations(self, transformations, text, original_text=None):
|
||||
def _filter_transformations_uncached(self, original_transformations, text, original_text=None):
|
||||
""" Filters a list of potential perturbations based on a list of
|
||||
transformations.
|
||||
|
||||
@@ -71,11 +73,34 @@ class Attack:
|
||||
text (list: TokenizedText): a list of TokenizedText objects
|
||||
representing potential perturbations
|
||||
"""
|
||||
transformations = original_transformations[:]
|
||||
for C in self.constraints:
|
||||
if len(transformations) == 0: break
|
||||
transformations = C.call_many(text, transformations, original_text)
|
||||
transformations_mask = C.call_many(text, transformations, original_text)
|
||||
# Default to false for all original transformations.
|
||||
for original_transformation in original_transformations:
|
||||
self.constraints_cache[original_transformation] = False
|
||||
# Set unfiltered transformations to True in the cache.
|
||||
for successful_transformation in transformations:
|
||||
self.constraints_cache[successful_transformation] = True
|
||||
return transformations
|
||||
|
||||
def _filter_transformations(self, transformations, text, original_text=None):
|
||||
""" Filters a list of potential perturbations based on a list of
|
||||
transformations. Checks cache first.
|
||||
|
||||
Args:
|
||||
transformations (list: function): a list of transformations
|
||||
that filter a list of candidate perturbations
|
||||
text (list: TokenizedText): a list of TokenizedText objects
|
||||
representing potential perturbations
|
||||
"""
|
||||
# Populate cache with transformations.
|
||||
uncached_transformations = [t for t in transformations if (t not in self.constraints_cache)]
|
||||
self._filter_transformations_uncached(uncached_transformations, text, original_text=original_text)
|
||||
# Return transformations from cache.
|
||||
return [t for t in transformations if self.constraints_cache[t]]
|
||||
|
||||
def attack_one(self, tokenized_text):
|
||||
"""
|
||||
Perturbs `tokenized_text` until the goal is reached.
|
||||
|
||||
@@ -13,10 +13,14 @@ class GoalFunction:
|
||||
Args:
|
||||
model: The PyTorch or TensorFlow model used for evaluation.
|
||||
"""
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, use_cache=True):
|
||||
self.model = model
|
||||
self.use_cache = use_cache
|
||||
self.num_queries = 0
|
||||
if self.use_cache:
|
||||
self._call_model_cache = lru.LRU(2**18)
|
||||
else:
|
||||
self._call_model_cache = None
|
||||
|
||||
def should_skip(self, tokenized_text, correct_output):
|
||||
model_outputs = self._call_model([tokenized_text])
|
||||
@@ -54,7 +58,11 @@ class GoalFunction:
|
||||
if not len(tokenized_text_list):
|
||||
return torch.tensor([])
|
||||
ids = [t.ids for t in tokenized_text_list]
|
||||
ids = torch.tensor(ids).to(utils.get_device())
|
||||
if hasattr(self.model, 'model'):
|
||||
model_device = next(self.model.model.parameters()).device
|
||||
else:
|
||||
model_device = next(self.model.parameters()).device
|
||||
ids = torch.tensor(ids).to(model_device)
|
||||
#
|
||||
# shape of `ids` is (n, m, d)
|
||||
# - n: number of elements in `tokenized_text_list`
|
||||
@@ -73,6 +81,8 @@ class GoalFunction:
|
||||
batch = [batch_ids[:, x, :] for x in range(num_fields)]
|
||||
with torch.no_grad():
|
||||
preds = self.model(*batch)
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
scores.append(preds)
|
||||
scores = torch.cat(scores, dim=0)
|
||||
# Validation check on model score dimensions
|
||||
@@ -93,6 +103,8 @@ class GoalFunction:
|
||||
# set of numbers corresponding to probabilities, which should add
|
||||
# up to 1. Since they are `torch.float` values, allow a small
|
||||
# error in the summation.
|
||||
scores = torch.nn.functional.softmax(scores, dim=1)
|
||||
if not ((scores.sum(dim=1) - 1).abs() < 1e-6).all():
|
||||
raise ValueError('Model scores do not add up to 1.')
|
||||
return scores
|
||||
|
||||
@@ -109,6 +121,9 @@ class GoalFunction:
|
||||
# function, then `self.num_queries` will not have been initialized.
|
||||
# In this case, just continue.
|
||||
pass
|
||||
if not self.use_cache:
|
||||
return self._call_model_uncached(tokenized_text_list)
|
||||
else:
|
||||
uncached_list = [text for text in tokenized_text_list if text not in self._call_model_cache]
|
||||
scores = self._call_model_uncached(uncached_list)
|
||||
for text, score in zip(uncached_list, scores):
|
||||
|
||||
@@ -16,7 +16,6 @@ class BERTForClassification:
|
||||
|
||||
"""
|
||||
def __init__(self, model_path, num_labels=2, entailment=False):
|
||||
#model_file_path = utils.download_if_needed(model_path)
|
||||
model_file_path = utils.download_if_needed(model_path)
|
||||
self.model = BertForSequenceClassification.from_pretrained(
|
||||
model_file_path, num_labels=num_labels)
|
||||
|
||||
@@ -6,7 +6,7 @@ class BERTTokenizer(Tokenizer):
|
||||
any type of tokenization, be it word, wordpiece, or character-based.
|
||||
"""
|
||||
def __init__(self, model_path='bert-base-uncased', max_seq_length=256, fast=False):
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
if fast:
|
||||
# Faster tokenizer that is implemented in Rust
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
|
||||
|
||||
Reference in New Issue
Block a user