1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

Merge branch 'master' of github.com:UVA-MachineLearningBioinformatics/TextAttack into log-avg-num-words

This commit is contained in:
Jack Morris
2020-02-07 17:11:44 -05:00
11 changed files with 126 additions and 56 deletions

View File

@@ -1,3 +1,4 @@
filelock
language_check
nltk
numpy<1.17

View File

@@ -7,7 +7,7 @@ import time
import tqdm
import os
from textattack_args_helper import *
from run_attack_args_helper import *
def main():
args = get_args()

View File

@@ -85,6 +85,7 @@ CONSTRAINT_CLASS_NAMES = {
}
ATTACK_CLASS_NAMES = {
'beam-search': 'textattack.attack_methods.BeamSearch',
'greedy-word': 'textattack.attack_methods.GreedyWordSwap',
'ga-word': 'textattack.attack_methods.GeneticAlgorithm',
'greedy-word-wir': 'textattack.attack_methods.GreedyWordSwapWIR',
@@ -136,9 +137,10 @@ def get_args():
attack_group = parser.add_mutually_exclusive_group(required=False)
attack_choices = ','.join(ATTACK_CLASS_NAMES.keys())
attack_group.add_argument('--attack', '--attack_method', type=str,
required=False, default='greedy-word-wir',
help='The type of attack to run.', choices=ATTACK_CLASS_NAMES.keys())
help=f'The type of attack to run. choices: {attack_choices}')
attack_group.add_argument('--recipe', type=str, required=False, default=None,
help='full attack recipe (overrides provided transformation & constraints)',

View File

@@ -8,7 +8,7 @@ import time
import torch
import tqdm
from textattack_args_helper import *
from run_attack_args_helper import *
def set_env_variables(gpu_id):
# Only use one GPU, if we have one.
@@ -56,7 +56,7 @@ def main():
load_time = time.time()
if args.interactive:
raise RuntimeException('Cannot run in parallel if --interactive set')
raise RuntimeError('Cannot run in parallel if --interactive set')
in_queue = torch.multiprocessing.Queue()
out_queue = torch.multiprocessing.Queue()

View File

@@ -1,3 +1,4 @@
from .beam_search import BeamSearch
from .greedy_word_swap import GreedyWordSwap
from .greedy_word_swap_wir import GreedyWordSwapWIR
from .genetic_algorithm import GeneticAlgorithm

View File

@@ -0,0 +1,74 @@
from .attack import Attack
from textattack.attack_results import AttackResult, FailedAttackResult
class BeamSearch(Attack):
"""
An attack that greedily chooses from a list of possible
perturbations.
Args:
model (nn.Module): The model to attack.
transformation (Transformation): The type of transformation.
beam_width (int): the number of candidates to retain at each step
max_words_changed (:obj:`int`, optional): The maximum number of words
to change.
"""
def __init__(self, model, transformation, constraints=[], beam_width=8,
max_words_changed=32):
super().__init__(model, transformation, constraints=constraints)
self.beam_width = beam_width
self.max_words_changed = max_words_changed
def attack_one(self, original_label, original_tokenized_text):
max_words_changed = min(
self.max_words_changed,
len(original_tokenized_text.words)
)
original_prob = self._call_model([original_tokenized_text]).max()
default_unswapped_word_indices = list(range(len(original_tokenized_text.words)))
beam = [(original_tokenized_text, default_unswapped_word_indices)]
num_words_changed = 0
new_text_label = original_label
while num_words_changed < max_words_changed:
num_words_changed += 1
potential_next_beam = []
for text, unswapped_word_indices in beam:
transformations = self.get_transformations(
text, indices_to_replace=unswapped_word_indices
)
for next_text in transformations:
new_unswapped_word_indices = unswapped_word_indices.copy()
modified_word_index = next_text.attack_attrs['modified_word_index']
new_unswapped_word_indices.remove(modified_word_index)
potential_next_beam.append((next_text, new_unswapped_word_indices))
if len(potential_next_beam) == 0:
# If we did not find any possible perturbations, give up.
return FailedAttackResult(original_tokenized_text, original_label)
transformed_text_candidates = [text for (text,_) in potential_next_beam]
scores = self._call_model(transformed_text_candidates)
# The best choice is the one that minimizes the original class label.
best_index = scores[:, original_label].argmin()
new_tokenized_text = transformed_text_candidates[best_index]
# If we changed the label, break.
new_text_label = scores[best_index].argmax().item()
if new_text_label != original_label:
new_prob = scores[best_index].max()
break
# Otherwise, refill the beam. This works by sorting the scores from
# the original label in ascending order and filling the beam from
# there.
best_indices = scores[:, original_label].argsort()[:self.beam_width]
beam = [potential_next_beam[i] for i in best_indices]
if original_label == new_text_label:
return FailedAttackResult(original_tokenized_text, original_label)
else:
return AttackResult(
original_tokenized_text,
new_tokenized_text,
original_label,
new_text_label,
float(original_prob),
float(new_prob)
)

View File

@@ -1,59 +1,17 @@
from .attack import Attack
from .beam_search import BeamSearch
from textattack.attack_results import AttackResult, FailedAttackResult
class GreedyWordSwap(Attack):
class GreedyWordSwap(BeamSearch):
"""
An attack that greedily chooses from a list of possible
perturbations.
An attack that greedily chooses from a list of possible perturbations.
Args:
model: The model to attack.
transformation: The type of transformation.
max_depth (:obj:`int`, optional): The maximum number of words to change. Defaults to 32.
max_words_changed (:obj:`int`, optional): The maximum number of words
to change.
"""
def __init__(self, model, transformation, constraints=[], max_depth=32):
super().__init__(model, transformation, constraints=constraints)
self.max_depth = max_depth
def attack_one(self, original_label, tokenized_text):
original_tokenized_text = tokenized_text
original_prob = self._call_model([tokenized_text]).squeeze().max()
num_words_changed = 0
unswapped_word_indices = list(range(len(tokenized_text.words)))
new_tokenized_text = None
new_text_label = None
while num_words_changed <= self.max_depth and len(unswapped_word_indices):
num_words_changed += 1
transformed_text_candidates = self.get_transformations(
tokenized_text,
indices_to_replace=unswapped_word_indices)
if len(transformed_text_candidates) == 0:
# If we did not find any possible perturbations, give up.
break
scores = self._call_model(transformed_text_candidates)
# The best choice is the one that minimizes the original class label.
best_index = scores[:, original_label].argmin()
new_tokenized_text = transformed_text_candidates[best_index]
# If we changed the label, break.
new_text_label = scores[best_index].argmax().item()
if new_text_label != original_label:
new_prob = scores[best_index].max()
break
# Otherwise, remove this word from list of words to change and
# iterate.
word_swap_loc = tokenized_text.first_word_diff_index(new_tokenized_text)
tokenized_text = new_tokenized_text
unswapped_word_indices.remove(word_swap_loc)
if original_label == new_text_label:
return FailedAttackResult(original_tokenized_text, original_label)
else:
return AttackResult(
original_tokenized_text,
new_tokenized_text,
original_label,
new_text_label,
float(original_prob),
float(new_prob)
)
def __init__(self, model, transformation, constraints=[], max_words_changed=32):
super().__init__(model, transformation, constraints=constraints,
beam_width=1, max_words_changed=32)

View File

@@ -1,4 +1,5 @@
import torch
from copy import deepcopy
from .utils import get_device
class TokenizedText:
@@ -30,6 +31,12 @@ class TokenizedText:
self.words = raw_words(text)
self.text = text
self.attack_attrs = attack_attrs
def __eq__(self, other):
return (self.text == other.text) and (self.attack_attrs == other.attack_attrs)
def __hash__(self):
return hash(self.text)
def delete_tensors(self):
""" Delete tensors to clear up GPU space. Only should be called
@@ -150,7 +157,7 @@ class TokenizedText:
text = text[word_end:]
final_sentence += text # Add all of the ending punctuation.
return TokenizedText(final_sentence, self.tokenizer,
attack_attrs=self.attack_attrs)
attack_attrs=deepcopy(self.attack_attrs))
def clean_text(self):
""" Represents self in a clean, printable format. Joins text with multiple

View File

@@ -1,3 +1,4 @@
import filelock
import json
import logging
import os
@@ -34,9 +35,14 @@ def download_if_needed(folder_name):
@TODO: Prevent parallel downloads of the same file with a lock.
"""
# Check if already downloaded.
cache_dest_path = path_in_cache(folder_name)
# Use a lock to prevent concurrent downloads.
cache_dest_lock_path = cache_dest_path + '.lock'
cache_file_lock = filelock.FileLock(cache_dest_lock_path)
cache_file_lock.acquire()
# Check if already downloaded.
if os.path.exists(cache_dest_path):
cache_file_lock.release()
return cache_dest_path
# If the file isn't found yet, download the zip file to the cache.
downloaded_file = tempfile.NamedTemporaryFile(
@@ -51,6 +57,7 @@ def download_if_needed(folder_name):
print('Copying', downloaded_file.name, 'to', cache_dest_path + '.')
os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
shutil.copyfile(downloaded_file.name, cache_dest_path)
cache_file_lock.release()
# Remove the temporary file.
os.remove(downloaded_file.name)
print(f'Successfully saved {folder_name} to cache.')

View File

@@ -1,3 +1,4 @@
from .composite_transformation import CompositeTransformation
from .word_swap import WordSwap
from .word_swap_embedding import WordSwapEmbedding
from .word_swap_homoglyph import WordSwapHomoglyph

View File

@@ -0,0 +1,19 @@
import numpy as np
from .transformation import Transformation
class CompositeTransformation(Transformation):
def __init__(self, transformations):
if not isinstance(transformations, list):
raise TypeError('transformations must be a list')
if not len(transformations):
raise ValueError('transformations cannot be empty')
self.transformations = transformations
def __call__(self, *args, **kwargs):
new_tokenized_texts = set()
for transformation in self.transformations:
new_tokenized_texts.update(
transformation(*args, **kwargs)
)
return list(new_tokenized_texts)