import argparse
import importlib
import numpy as np
import os
import random
import sys
import textattack
import time
import torch
import pickle
import copy

from .attack_args import *


def set_seed(random_seed):
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)


def get_args():
    # Parser for regular arguments
    parser = argparse.ArgumentParser(
        description='A commandline parser for TextAttack',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys())
    parser.add_argument('--transformation', type=str, required=False,
        default='word-swap-embedding', choices=transformation_names,
        help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ' + str(transformation_names))

    model_group = parser.add_mutually_exclusive_group()
    model_names = list(TEXTATTACK_MODEL_CLASS_NAMES.keys()) + list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(TEXTATTACK_DATASET_BY_MODEL.keys())
    model_group.add_argument('--model', type=str, required=False,
        default='bert-base-uncased-yelp-sentiment', choices=model_names,
        help='The pre-trained model to attack.')
    model_group.add_argument('--model-from-file', type=str, required=False,
        help='File of model and tokenizer to import.')
    model_group.add_argument('--model-from-huggingface', type=str, required=False,
        help='huggingface.co ID of pre-trained model to load.')

    dataset_group = parser.add_mutually_exclusive_group()
    dataset_group.add_argument('--dataset-from-nlp', type=str, required=False, default=None,
        help='Dataset to load from the `nlp` repository.')
    dataset_group.add_argument('--dataset-from-file', type=str, required=False, default=None,
        help='Dataset to load from a file.')

    parser.add_argument('--constraints', type=str, required=False, nargs='*',
        default=['repeat', 'stopword'],
        help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: ' + str(CONSTRAINT_CLASS_NAMES.keys()))

    parser.add_argument('--out-dir', type=str, required=False, default=None,
        help='A directory to output results to.')
    parser.add_argument('--enable-visdom', action='store_true',
        help='Enable logging to Visdom.')
    parser.add_argument('--enable-wandb', action='store_true',
        help='Enable logging to Weights & Biases.')
    parser.add_argument('--disable-stdout', action='store_true',
        help='Disable logging to stdout.')
    parser.add_argument('--enable-csv', nargs='?', default=None, const='fancy', type=str,
        help='Enable logging to CSV. Use `--enable-csv plain` to remove [[]] around words.')

    parser.add_argument('--num-examples', '-n', type=int, required=False, default=5,
        help='The number of examples to process.')
    parser.add_argument('--num-examples-offset', '-o', type=int, required=False, default=0,
        help='The offset to start at in the dataset.')
    parser.add_argument('--shuffle', action='store_true', required=False, default=False,
        help='Randomly shuffle the data before attacking.')
    parser.add_argument('--interactive', action='store_true', default=False,
        help='Whether to run attacks interactively.')
    parser.add_argument('--attack-n', action='store_true', default=False,
        help='Whether to run the attack until `n` examples have been attacked (not skipped).')
    parser.add_argument('--parallel', action='store_true', default=False,
        help='Run attack using multiple GPUs.')

    goal_function_choices = ', '.join(GOAL_FUNCTION_CLASS_NAMES.keys())
    parser.add_argument('--goal-function', '-g', default='untargeted-classification',
        help=f'The goal function to use. Choices: {goal_function_choices}')

    def str_to_int(s):
        return sum(ord(c) for c in s)

    parser.add_argument('--random-seed', type=int, default=str_to_int('TEXTATTACK'))
    parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
        help='The directory to save checkpoint files.')
    parser.add_argument('--checkpoint-interval', required=False, type=int,
        help='If set, a checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
    parser.add_argument('--query-budget', '-q', type=int, default=float('inf'),
        help='The maximum number of model queries allowed per example attacked.')

    attack_group = parser.add_mutually_exclusive_group(required=False)
    search_choices = ', '.join(SEARCH_CLASS_NAMES.keys())
    attack_group.add_argument('--search', '--search-method', '-s', type=str, required=False,
        default='greedy-word-wir',
        help=f'The search method to use. Choices: {search_choices}')
    attack_group.add_argument('--recipe', '--attack-recipe', '-r', type=str, required=False, default=None,
        choices=RECIPE_NAMES.keys(),
        help='Full attack recipe (overrides provided goal function, transformation & constraints).')
    attack_group.add_argument('--attack-from-file', type=str, required=False, default=None,
        help='Attack to load from a file (overrides provided goal function, transformation & constraints).')

    # Parser for parsing args for resume
    resume_parser = argparse.ArgumentParser(
        description='A commandline parser for TextAttack',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=True,
        help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered, '
             'recover the latest checkpoint from either the current path or the specified directory.')
    resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=None,
        help='The directory to save checkpoint files. If not set, use the directory from the recovered arguments.')
    resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
        help='If set, a checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
    resume_parser.add_argument('--parallel', action='store_true', default=False,
        help='Run attack using multiple GPUs.')

    # Resume attack from checkpoint.
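    # If the first command-line token is `resume`, the remaining arguments are
    # parsed with `resume_parser` and the run is marked as a checkpoint resume;
    # otherwise the regular parser is used (and an empty invocation prints help).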
    if sys.argv[1:] and sys.argv[1].lower() == 'resume':
        args = resume_parser.parse_args(sys.argv[2:])
        setattr(args, 'checkpoint_resume', True)
    else:
        command_line_args = None if sys.argv[1:] else ['-h']  # Default to help with empty arguments.
        args = parser.parse_args(command_line_args)
        setattr(args, 'checkpoint_resume', False)

        if args.checkpoint_interval and args.shuffle:
            # Not allowed b/c we cannot recover the order of shuffled data.
            raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')

        set_seed(args.random_seed)

        # Shortcuts for huggingface models using --model.
        if args.model in HUGGINGFACE_DATASET_BY_MODEL:
            args.model_from_huggingface, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
            args.model = None

    return args


def parse_transformation_from_args(args, model):
    # Transformations
    transformation_name = args.transformation
    if ':' in transformation_name:
        transformation_name, params = transformation_name.split(':')
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(f'{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model, {params})')
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(f'{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})')
        else:
            raise ValueError(f'Error: unsupported transformation {transformation_name}')
    else:
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(f'{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model)')
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(f'{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()')
        else:
            raise ValueError(f'Error: unsupported transformation {transformation_name}')
    return transformation


def parse_goal_function_from_args(args, model):
    # Goal functions
    goal_function = args.goal_function
    if ':' in goal_function:
        goal_function_name, params = goal_function.split(':')
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f'Error: unsupported goal_function {goal_function_name}')
        goal_function = eval(f'{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model, {params})')
    elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
        goal_function = eval(f'{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model)')
    else:
        raise ValueError(f'Error: unsupported goal_function {goal_function}')
    goal_function.query_budget = args.query_budget
    return goal_function


def parse_constraints_from_args(args):
    # Constraints
    if not args.constraints:
        return []

    _constraints = []
    for constraint in args.constraints:
        if ':' in constraint:
            constraint_name, params = constraint.split(':')
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f'Error: unsupported constraint {constraint_name}')
            _constraints.append(eval(f'{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})'))
        elif constraint in CONSTRAINT_CLASS_NAMES:
            _constraints.append(eval(f'{CONSTRAINT_CLASS_NAMES[constraint]}()'))
        else:
            raise ValueError(f'Error: unsupported constraint {constraint}')

    return _constraints


def parse_attack_from_args(args):
    model = parse_model_from_args(args)
    if args.recipe:
        if ':' in args.recipe:
            recipe_name, params = args.recipe.split(':')
            if recipe_name not in RECIPE_NAMES:
                raise ValueError(f'Error: unsupported recipe {recipe_name}')
            recipe = eval(f'{RECIPE_NAMES[recipe_name]}(model, {params})')
        elif args.recipe in RECIPE_NAMES:
            recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)')
        else:
            raise ValueError(f'Invalid recipe {args.recipe}')
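        # A recipe bundles its own goal function, transformation, constraints, and
        # search method; only the query budget is taken from the command line.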
        recipe.goal_function.query_budget = args.query_budget
        return recipe
    elif args.attack_from_file:
        if ':' in args.attack_from_file:
            attack_file, attack_name = args.attack_from_file.split(':')
        else:
            attack_file, attack_name = args.attack_from_file, 'attack'
        attack_file = attack_file.replace('.py', '').replace('/', '.')
        attack_module = importlib.import_module(attack_file)
        attack_func = getattr(attack_module, attack_name)
        return attack_func(model)
    else:
        goal_function = parse_goal_function_from_args(args, model)
        transformation = parse_transformation_from_args(args, model)
        constraints = parse_constraints_from_args(args)
        if ':' in args.search:
            search_name, params = args.search.split(':')
            if search_name not in SEARCH_CLASS_NAMES:
                raise ValueError(f'Error: unsupported search {search_name}')
            search_method = eval(f'{SEARCH_CLASS_NAMES[search_name]}({params})')
        elif args.search in SEARCH_CLASS_NAMES:
            search_method = eval(f'{SEARCH_CLASS_NAMES[args.search]}()')
        else:
            raise ValueError(f'Error: unsupported search {args.search}')
        return textattack.shared.Attack(goal_function, constraints, transformation, search_method)


def parse_model_from_args(args):
    if args.model_from_file:
        colored_model_name = textattack.shared.utils.color_text(args.model_from_file, color='blue', method='ansi')
        textattack.shared.logger.info(f'Loading model and tokenizer from file: {colored_model_name}')
        if ':' in args.model_from_file:
            model_file, model_name, tokenizer_name = args.model_from_file.split(':')
        else:
            model_file, model_name, tokenizer_name = args.model_from_file, 'model', 'tokenizer'
        try:
            model_file = model_file.replace('.py', '').replace('/', '.')
            model_module = importlib.import_module(model_file)
        except Exception:
            raise ValueError(f'Failed to import model or tokenizer from file {args.model_from_file}')
        try:
            model = getattr(model_module, model_name)
        except AttributeError:
            raise AttributeError(f'``{model_name}`` not found in module {args.model_from_file}')
        try:
            tokenizer = getattr(model_module, tokenizer_name)
        except AttributeError:
            raise AttributeError(f'``{tokenizer_name}`` not found in module {args.model_from_file}')
        model = model.to(textattack.shared.utils.device)
        setattr(model, 'tokenizer', tokenizer)
    elif args.model_from_huggingface:
        import transformers
        if ':' in args.model_from_huggingface:
            model_class, model_name = args.model_from_huggingface.split(':')
            model_class = eval(f'transformers.{model_class}')
        else:
            model_class, model_name = transformers.AutoModelForSequenceClassification, args.model_from_huggingface
        colored_model_name = textattack.shared.utils.color_text(model_name, color='blue', method='ansi')
        textattack.shared.logger.info(f'Loading pre-trained model from HuggingFace model repository: {colored_model_name}')
        model = model_class.from_pretrained(model_name)
        model = model.to(textattack.shared.utils.device)
        try:
            tokenizer = textattack.tokenizers.AutoTokenizer(args.model_from_huggingface)
        except OSError:
            textattack.shared.logger.warn(f'AutoTokenizer {args.model_from_huggingface} not found. Defaulting to `bert-base-uncased`.')
            tokenizer = textattack.tokenizers.AutoTokenizer('bert-base-uncased')
        setattr(model, 'tokenizer', tokenizer)
    else:
        if ':' in args.model:
            model_name, params = args.model.split(':')
            colored_model_name = textattack.shared.utils.color_text(model_name, color='blue', method='ansi')
            textattack.shared.logger.info(f'Loading pre-trained TextAttack model: {colored_model_name}')
            if model_name not in TEXTATTACK_MODEL_CLASS_NAMES:
                raise ValueError(f'Error: unsupported model {model_name}')
            model = eval(f'{TEXTATTACK_MODEL_CLASS_NAMES[model_name]}({params})')
        elif args.model in TEXTATTACK_MODEL_CLASS_NAMES:
            colored_model_name = textattack.shared.utils.color_text(args.model, color='blue', method='ansi')
            textattack.shared.logger.info(f'Loading pre-trained TextAttack model: {colored_model_name}')
            model = eval(f'{TEXTATTACK_MODEL_CLASS_NAMES[args.model]}()')
        elif args.model in TEXTATTACK_DATASET_BY_MODEL:
            colored_model_name = textattack.shared.utils.color_text(args.model, color='blue', method='ansi')
            model_path, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
            if args.model.startswith('lstm'):
                textattack.shared.logger.info(f'Loading pre-trained TextAttack LSTM: {colored_model_name}')
                model = textattack.models.helpers.LSTMForClassification(model_path=model_path)
            elif args.model.startswith('cnn'):
                textattack.shared.logger.info(f'Loading pre-trained TextAttack CNN: {colored_model_name}')
                model = textattack.models.helpers.WordCNNForClassification(model_path=model_path)
            else:
                raise ValueError(f'Unknown TextAttack pretrained model {args.model}')
        else:
            raise ValueError(f'Error: unsupported model {args.model}')
    return model


def parse_dataset_from_args(args):
    if args.dataset_from_file:
        textattack.shared.logger.info(f'Loading dataset from file: {args.dataset_from_file}')
        if ':' in args.dataset_from_file:
            dataset_file, dataset_name = args.dataset_from_file.split(':')
        else:
            dataset_file, dataset_name = args.dataset_from_file, 'dataset'
        try:
            dataset_file = dataset_file.replace('.py', '').replace('/', '.')
            dataset_module = importlib.import_module(dataset_file)
        except Exception:
            raise ValueError(f'Failed to import dataset from file {args.dataset_from_file}')
        try:
            dataset = getattr(dataset_module, dataset_name)
        except AttributeError:
            raise AttributeError(f'``{dataset_name}`` not found in module {args.dataset_from_file}')
    elif args.dataset_from_nlp:
        # May be a tuple (set via the --model shortcut) or a string like
        # 'dataset' or 'dataset:subset:split'.
        dataset_args = args.dataset_from_nlp
        if isinstance(dataset_args, str):
            dataset_args = dataset_args.split(':') if ':' in dataset_args else (dataset_args,)
        dataset = textattack.datasets.HuggingFaceNLPDataset(*dataset_args, shuffle=args.shuffle)
    else:
        if not args.model:
            raise ValueError('Must supply pretrained model or dataset')
        elif args.model in DATASET_BY_MODEL:
            dataset = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
        else:
            raise ValueError(f'Error: unsupported model {args.model}')
    return dataset


def parse_logger_from_args(args):
    # Create logger
    attack_log_manager = textattack.loggers.AttackLogManager()

    # Set default output directory to `textattack/outputs`.
    if not args.out_dir:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
        args.out_dir = os.path.normpath(outputs_dir)

    # Output file.
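    # Results are written to `attack-<millisecond timestamp>.txt` under the output directory.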
    out_time = int(time.time() * 1000)
    outfile_name = 'attack-{}.txt'.format(out_time)
    attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))

    # CSV
    if args.enable_csv:
        outfile_name = 'attack-{}.csv'.format(out_time)
        color_method = None if args.enable_csv == 'plain' else 'file'
        csv_path = os.path.join(args.out_dir, outfile_name)
        attack_log_manager.add_output_csv(csv_path, color_method)
        print('Logging to CSV at path {}.'.format(csv_path))

    # Visdom
    if args.enable_visdom:
        attack_log_manager.enable_visdom()

    # Weights & Biases
    if args.enable_wandb:
        attack_log_manager.enable_wandb()

    # Stdout
    if not args.disable_stdout:
        attack_log_manager.enable_stdout()

    return attack_log_manager


def parse_checkpoint_from_args(args):
    file_name = os.path.basename(args.checkpoint_file)
    if file_name.lower() == 'latest':
        # A bare "latest" means: look for checkpoints in the current directory.
        dir_path = os.path.dirname(args.checkpoint_file) or '.'
        chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith('.ta.chkpt')]
        assert chkpt_file_names, 'Checkpoint directory is empty'
        timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
        latest_file = str(max(timestamps)) + '.ta.chkpt'
        checkpoint_path = os.path.join(dir_path, latest_file)
    else:
        checkpoint_path = args.checkpoint_file

    checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
    set_seed(checkpoint.args.random_seed)

    return checkpoint


def default_checkpoint_dir():
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
    return os.path.normpath(checkpoints_dir)


def merge_checkpoint_args(saved_args, cmdline_args):
    """Merge arguments saved in the checkpoint with newly entered command-line arguments."""
    args = copy.deepcopy(saved_args)
    # Newly entered arguments take precedence.
    args.checkpoint_resume = cmdline_args.checkpoint_resume
    args.parallel = cmdline_args.parallel
    # If set, replace the saved values.
    if cmdline_args.checkpoint_dir:
        args.checkpoint_dir = cmdline_args.checkpoint_dir
    if cmdline_args.checkpoint_interval:
        args.checkpoint_interval = cmdline_args.checkpoint_interval
    return args
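

# A rough sketch of how these helpers are typically composed by the attack entry
# point. The exact driver lives elsewhere in TextAttack; the flow below is
# illustrative only, not part of this module's API:
#
#     args = get_args()
#     if args.checkpoint_resume:
#         checkpoint = parse_checkpoint_from_args(args)
#         args = merge_checkpoint_args(checkpoint.args, args)
#     attack = parse_attack_from_args(args)
#     dataset = parse_dataset_from_args(args)
#     attack_log_manager = parse_logger_from_args(args)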