Merge branch 'master' of github.com:UVA-MachineLearningBioinformatics/TextAttack into log-avg-num-words
@@ -1,3 +1,4 @@
+filelock
 language_check
 nltk
 numpy<1.17
@@ -7,7 +7,7 @@ import time
 import tqdm
 import os
 
-from textattack_args_helper import *
+from run_attack_args_helper import *
 
 def main():
     args = get_args()
@@ -85,6 +85,7 @@ CONSTRAINT_CLASS_NAMES = {
 }
 
 ATTACK_CLASS_NAMES = {
+    'beam-search': 'textattack.attack_methods.BeamSearch',
     'greedy-word': 'textattack.attack_methods.GreedyWordSwap',
     'ga-word': 'textattack.attack_methods.GeneticAlgorithm',
     'greedy-word-wir': 'textattack.attack_methods.GreedyWordSwapWIR',
@@ -136,9 +137,10 @@ def get_args():
 
     attack_group = parser.add_mutually_exclusive_group(required=False)
 
+    attack_choices = ','.join(ATTACK_CLASS_NAMES.keys())
     attack_group.add_argument('--attack', '--attack_method', type=str,
         required=False, default='greedy-word-wir',
-        help='The type of attack to run.', choices=ATTACK_CLASS_NAMES.keys())
+        help=f'The type of attack to run. choices: {attack_choices}')
 
     attack_group.add_argument('--recipe', type=str, required=False, default=None,
         help='full attack recipe (overrides provided transformation & constraints)',
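Moving the attack list from `choices=` into an f-string help message keeps argparse's usage line short, but argparse no longer rejects unknown attack names on its own. A minimal standalone sketch of the tradeoff (abbreviated dict, not the file's actual code):

import argparse

ATTACK_CLASS_NAMES = {'beam-search': '...', 'greedy-word': '...'}  # abbreviated

parser = argparse.ArgumentParser()
attack_choices = ','.join(ATTACK_CLASS_NAMES.keys())
parser.add_argument('--attack', type=str, default='greedy-word',
    help=f'The type of attack to run. choices: {attack_choices}')

args = parser.parse_args(['--attack', 'beam-search'])
if args.attack not in ATTACK_CLASS_NAMES:  # manual validation replaces choices=
    parser.error(f'unknown attack {args.attack}')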
@@ -8,7 +8,7 @@ import time
 import torch
 import tqdm
 
-from textattack_args_helper import *
+from run_attack_args_helper import *
 
 def set_env_variables(gpu_id):
     # Only use one GPU, if we have one.
@@ -56,7 +56,7 @@ def main():
     load_time = time.time()
 
     if args.interactive:
-        raise RuntimeException('Cannot run in parallel if --interactive set')
+        raise RuntimeError('Cannot run in parallel if --interactive set')
 
     in_queue = torch.multiprocessing.Queue()
     out_queue = torch.multiprocessing.Queue()
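These queues drive a standard producer/consumer worker pool over torch.multiprocessing. A self-contained sketch of the pattern (the worker body is a stand-in; the real script runs one attack per queue item):

import torch.multiprocessing as mp

def worker(in_queue, out_queue):
    while True:
        item = in_queue.get()
        if item is None:         # sentinel: no more work
            break
        out_queue.put(item * 2)  # stand-in for running one attack

if __name__ == '__main__':
    in_queue, out_queue = mp.Queue(), mp.Queue()
    processes = [mp.Process(target=worker, args=(in_queue, out_queue)) for _ in range(2)]
    for p in processes:
        p.start()
    for i in range(10):
        in_queue.put(i)
    results = [out_queue.get() for _ in range(10)]
    for _ in processes:
        in_queue.put(None)       # one sentinel per worker
    for p in processes:
        p.join()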
@@ -1,3 +1,4 @@
+from .beam_search import BeamSearch
 from .greedy_word_swap import GreedyWordSwap
 from .greedy_word_swap_wir import GreedyWordSwapWIR
 from .genetic_algorithm import GeneticAlgorithm
textattack/attack_methods/beam_search.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from .attack import Attack
+from textattack.attack_results import AttackResult, FailedAttackResult
+
+class BeamSearch(Attack):
+    """
+    An attack that greedily chooses from a list of possible
+    perturbations.
+    Args:
+        model (nn.Module): The model to attack.
+        transformation (Transformation): The type of transformation.
+        beam_width (int): the number of candidates to retain at each step
+        max_words_changed (:obj:`int`, optional): The maximum number of words
+            to change.
+
+    """
+    def __init__(self, model, transformation, constraints=[], beam_width=8,
+            max_words_changed=32):
+        super().__init__(model, transformation, constraints=constraints)
+        self.beam_width = beam_width
+        self.max_words_changed = max_words_changed
+
+    def attack_one(self, original_label, original_tokenized_text):
+        max_words_changed = min(
+            self.max_words_changed,
+            len(original_tokenized_text.words)
+        )
+        original_prob = self._call_model([original_tokenized_text]).max()
+        default_unswapped_word_indices = list(range(len(original_tokenized_text.words)))
+        beam = [(original_tokenized_text, default_unswapped_word_indices)]
+        num_words_changed = 0
+        new_text_label = original_label
+        while num_words_changed < max_words_changed:
+            num_words_changed += 1
+            potential_next_beam = []
+            for text, unswapped_word_indices in beam:
+                transformations = self.get_transformations(
+                    text, indices_to_replace=unswapped_word_indices
+                )
+                for next_text in transformations:
+                    new_unswapped_word_indices = unswapped_word_indices.copy()
+                    modified_word_index = next_text.attack_attrs['modified_word_index']
+                    new_unswapped_word_indices.remove(modified_word_index)
+                    potential_next_beam.append((next_text, new_unswapped_word_indices))
+            if len(potential_next_beam) == 0:
+                # If we did not find any possible perturbations, give up.
+                return FailedAttackResult(original_tokenized_text, original_label)
+            transformed_text_candidates = [text for (text, _) in potential_next_beam]
+            scores = self._call_model(transformed_text_candidates)
+            # The best choice is the one that minimizes the original class label.
+            best_index = scores[:, original_label].argmin()
+            new_tokenized_text = transformed_text_candidates[best_index]
+            # If we changed the label, break.
+            new_text_label = scores[best_index].argmax().item()
+            if new_text_label != original_label:
+                new_prob = scores[best_index].max()
+                break
+            # Otherwise, refill the beam. This works by sorting the scores from
+            # the original label in ascending order and filling the beam from
+            # there.
+            best_indices = scores[:, original_label].argsort()[:self.beam_width]
+            beam = [potential_next_beam[i] for i in best_indices]
+
+
+        if original_label == new_text_label:
+            return FailedAttackResult(original_tokenized_text, original_label)
+        else:
+            return AttackResult(
+                original_tokenized_text,
+                new_tokenized_text,
+                original_label,
+                new_text_label,
+                float(original_prob),
+                float(new_prob)
+            )
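A hedged usage sketch for the new attack (the model, label, and tokenized text below are placeholders, not part of this diff; WordSwapEmbedding is one of this repo's existing transformations):

from textattack.attack_methods import BeamSearch
from textattack.transformations import WordSwapEmbedding

model = ...  # placeholder: any classifier the Attack base class can call
transformation = WordSwapEmbedding()
attack = BeamSearch(model, transformation, beam_width=8, max_words_changed=32)
# attack_one takes the gold label and a TokenizedText, and returns an
# AttackResult on success or a FailedAttackResult otherwise.
result = attack.attack_one(original_label, original_tokenized_text)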
@@ -1,59 +1,17 @@
-from .attack import Attack
+from .beam_search import BeamSearch
 from textattack.attack_results import AttackResult, FailedAttackResult
 
-class GreedyWordSwap(Attack):
+class GreedyWordSwap(BeamSearch):
     """
-    An attack that greedily chooses from a list of possible
-    perturbations.
+    An attack that greedily chooses from a list of possible perturbations.
     Args:
         model: The model to attack.
         transformation: The type of transformation.
-        max_depth (:obj:`int`, optional): The maximum number of words to change. Defaults to 32.
+        max_words_changed (:obj:`int`, optional): The maximum number of words
+            to change.
 
     """
-    def __init__(self, model, transformation, constraints=[], max_depth=32):
-        super().__init__(model, transformation, constraints=constraints)
-        self.max_depth = max_depth
-
-    def attack_one(self, original_label, tokenized_text):
-        original_tokenized_text = tokenized_text
-        original_prob = self._call_model([tokenized_text]).squeeze().max()
-        num_words_changed = 0
-        unswapped_word_indices = list(range(len(tokenized_text.words)))
-        new_tokenized_text = None
-        new_text_label = None
-        while num_words_changed <= self.max_depth and len(unswapped_word_indices):
-            num_words_changed += 1
-            transformed_text_candidates = self.get_transformations(
-                tokenized_text,
-                indices_to_replace=unswapped_word_indices)
-            if len(transformed_text_candidates) == 0:
-                # If we did not find any possible perturbations, give up.
-                break
-            scores = self._call_model(transformed_text_candidates)
-            # The best choice is the one that minimizes the original class label.
-            best_index = scores[:, original_label].argmin()
-            new_tokenized_text = transformed_text_candidates[best_index]
-            # If we changed the label, break.
-            new_text_label = scores[best_index].argmax().item()
-            if new_text_label != original_label:
-                new_prob = scores[best_index].max()
-                break
-            # Otherwise, remove this word from list of words to change and
-            # iterate.
-            word_swap_loc = tokenized_text.first_word_diff_index(new_tokenized_text)
-            tokenized_text = new_tokenized_text
-            unswapped_word_indices.remove(word_swap_loc)
-
-
-        if original_label == new_text_label:
-            return FailedAttackResult(original_tokenized_text, original_label)
-        else:
-            return AttackResult(
-                original_tokenized_text,
-                new_tokenized_text,
-                original_label,
-                new_text_label,
-                float(original_prob),
-                float(new_prob)
-            )
+    def __init__(self, model, transformation, constraints=[], max_words_changed=32):
+        super().__init__(model, transformation, constraints=constraints,
+            beam_width=1, max_words_changed=max_words_changed)
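The refactor rests on a simple identity: greedy search is beam search with a beam of one. As a sketch (reusing the placeholder model and transformation from the earlier example), these two constructions should now behave identically:

from textattack.attack_methods import BeamSearch, GreedyWordSwap

greedy = GreedyWordSwap(model, transformation, max_words_changed=32)
width_one_beam = BeamSearch(model, transformation, beam_width=1, max_words_changed=32)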
@@ -1,4 +1,5 @@
 import torch
+from copy import deepcopy
 from .utils import get_device
 
 class TokenizedText:
@@ -30,6 +31,12 @@ class TokenizedText:
         self.words = raw_words(text)
         self.text = text
         self.attack_attrs = attack_attrs
 
+    def __eq__(self, other):
+        return (self.text == other.text) and (self.attack_attrs == other.attack_attrs)
+
+    def __hash__(self):
+        return hash(self.text)
+
     def delete_tensors(self):
         """ Delete tensors to clear up GPU space. Only should be called
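The new __eq__/__hash__ pair is what lets TokenizedText deduplicate in sets, which CompositeTransformation (added below) relies on. Hashing only self.text is consistent with the stricter __eq__: objects that compare equal always share the same text, hence the same hash; texts with equal strings but different attrs merely collide in a bucket and are kept apart by __eq__. A standalone sketch of the contract (Pair is a hypothetical stand-in for TokenizedText):

class Pair:
    def __init__(self, text, attrs):
        self.text, self.attrs = text, attrs
    def __eq__(self, other):
        return (self.text == other.text) and (self.attrs == other.attrs)
    def __hash__(self):
        return hash(self.text)

a, b = Pair('hi', {}), Pair('hi', {})
assert a == b and len({a, b}) == 1  # true duplicates collapse
c = Pair('hi', {'modified_word_index': 3})
assert a != c and len({a, c}) == 2  # same text, different attrs: both kept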
@@ -150,7 +157,7 @@ class TokenizedText:
             text = text[word_end:]
         final_sentence += text # Add all of the ending punctuation.
         return TokenizedText(final_sentence, self.tokenizer,
-            attack_attrs=self.attack_attrs)
+            attack_attrs=deepcopy(self.attack_attrs))
 
     def clean_text(self):
         """ Represents self in a clean, printable format. Joins text with multiple
@@ -1,3 +1,4 @@
+import filelock
 import json
 import logging
 import os
@@ -34,9 +35,14 @@ def download_if_needed(folder_name):
 
     @TODO: Prevent parallel downloads of the same file with a lock.
     """
-    # Check if already downloaded.
     cache_dest_path = path_in_cache(folder_name)
+    # Use a lock to prevent concurrent downloads.
+    cache_dest_lock_path = cache_dest_path + '.lock'
+    cache_file_lock = filelock.FileLock(cache_dest_lock_path)
+    cache_file_lock.acquire()
+    # Check if already downloaded.
     if os.path.exists(cache_dest_path):
+        cache_file_lock.release()
         return cache_dest_path
     # If the file isn't found yet, download the zip file to the cache.
     downloaded_file = tempfile.NamedTemporaryFile(
@@ -51,6 +57,7 @@ def download_if_needed(folder_name):
     print('Copying', downloaded_file.name, 'to', cache_dest_path + '.')
     os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True)
     shutil.copyfile(downloaded_file.name, cache_dest_path)
+    cache_file_lock.release()
    # Remove the temporary file.
     os.remove(downloaded_file.name)
     print(f'Successfully saved {folder_name} to cache.')
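The manual acquire()/release() pairing above is not released until process exit if the download raises partway. filelock.FileLock also works as a context manager, which releases the lock on any exit path; a sketch of the equivalent pattern (download_and_extract is a hypothetical helper, not this file's code):

import os
import filelock

def download_if_needed(cache_dest_path):
    # The lock guards the check-then-download sequence against races.
    with filelock.FileLock(cache_dest_path + '.lock'):
        if os.path.exists(cache_dest_path):
            return cache_dest_path
        download_and_extract(cache_dest_path)  # hypothetical helper
    return cache_dest_path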
@@ -1,3 +1,4 @@
+from .composite_transformation import CompositeTransformation
 from .word_swap import WordSwap
 from .word_swap_embedding import WordSwapEmbedding
 from .word_swap_homoglyph import WordSwapHomoglyph
textattack/transformations/composite_transformation.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import numpy as np
+from .transformation import Transformation
+
+class CompositeTransformation(Transformation):
+    def __init__(self, transformations):
+        if not isinstance(transformations, list):
+            raise TypeError('transformations must be a list')
+        if not len(transformations):
+            raise ValueError('transformations cannot be empty')
+        self.transformations = transformations
+
+    def __call__(self, *args, **kwargs):
+        new_tokenized_texts = set()
+        for transformation in self.transformations:
+            new_tokenized_texts.update(
+                transformation(*args, **kwargs)
+            )
+        return list(new_tokenized_texts)
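A hypothetical usage sketch for the new class: candidates from every sub-transformation are merged through a set, so the TokenizedText.__hash__/__eq__ methods added earlier drop duplicates across transformations.

from textattack.transformations import (
    CompositeTransformation, WordSwapEmbedding, WordSwapHomoglyph
)

transformation = CompositeTransformation([
    WordSwapEmbedding(),
    WordSwapHomoglyph(),
])
# Drop-in anywhere a single Transformation is accepted, e.g.:
# attack = BeamSearch(model, transformation)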