# textattack/shared/scripts/attack_args_parser.py
# (from a mirror of https://github.com/QData/TextAttack.git)
import argparse
import copy
import importlib
import os
import pickle
import random
import sys
import time
import numpy as np
import torch
import textattack
from .attack_args import *
def set_seed(random_seed):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
def get_args():
# Parser for regular arguments
parser = argparse.ArgumentParser(
description="A commandline parser for TextAttack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
)
parser.add_argument(
"--transformation",
type=str,
required=False,
default="word-swap-embedding",
choices=transformation_names,
        help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_2}={value_2}". Choices: '
        + str(transformation_names),
)
model_group = parser.add_mutually_exclusive_group()
model_names = (
list(TEXTATTACK_MODEL_CLASS_NAMES.keys())
+ list(HUGGINGFACE_DATASET_BY_MODEL.keys())
+ list(TEXTATTACK_DATASET_BY_MODEL.keys())
)
model_group.add_argument(
"--model",
type=str,
required=False,
default="bert-base-uncased-yelp-sentiment",
choices=model_names,
help="The pre-trained model to attack.",
)
model_group.add_argument(
"--model-from-file",
type=str,
required=False,
help="File of model and tokenizer to import.",
)
model_group.add_argument(
"--model-from-huggingface",
type=str,
required=False,
help="huggingface.co ID of pre-trained model to load",
)
dataset_group = parser.add_mutually_exclusive_group()
dataset_group.add_argument(
"--dataset-from-nlp",
type=str,
required=False,
default=None,
help="Dataset to load from `nlp` repository.",
)
dataset_group.add_argument(
"--dataset-from-file",
type=str,
required=False,
default=None,
help="Dataset to load from a file.",
)
parser.add_argument(
"--constraints",
type=str,
required=False,
nargs="*",
default=["repeat", "stopword"],
        help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_2}={value_2}". Choices: '
+ str(CONSTRAINT_CLASS_NAMES.keys()),
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
default=None,
help="A directory to output results to.",
)
parser.add_argument(
"--enable-visdom", action="store_true", help="Enable logging to visdom."
)
parser.add_argument(
"--enable-wandb",
action="store_true",
help="Enable logging to Weights & Biases.",
)
parser.add_argument(
"--disable-stdout", action="store_true", help="Disable logging to stdout"
)
parser.add_argument(
"--enable-csv",
nargs="?",
default=None,
const="fancy",
type=str,
help="Enable logging to csv. Use --enable-csv plain to remove [[]] around words.",
)
parser.add_argument(
"--num-examples",
"-n",
type=int,
required=False,
default="5",
help="The number of examples to process.",
)
parser.add_argument(
"--num-examples-offset",
"-o",
type=int,
required=False,
default=0,
help="The offset to start at in the dataset.",
)
parser.add_argument(
"--shuffle",
action="store_true",
required=False,
default=False,
help="Randomly shuffle the data before attacking",
)
parser.add_argument(
"--interactive",
action="store_true",
default=False,
help="Whether to run attacks interactively.",
)
parser.add_argument(
"--attack-n",
action="store_true",
default=False,
help="Whether to run attack until `n` examples have been attacked (not skipped).",
)
parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
parser.add_argument(
"--goal-function",
"-g",
default="untargeted-classification",
help=f"The goal function to use. choices: {goal_function_choices}",
)
def str_to_int(s):
return sum((ord(c) for c in s))
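    # For example, str_to_int("TEXTATTACK") == 765 (the sum of the ASCII codes
    # of its characters), giving a deterministic default seed.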
parser.add_argument("--random-seed", default=str_to_int("TEXTATTACK"))
parser.add_argument(
"--checkpoint-dir",
required=False,
type=str,
default=default_checkpoint_dir(),
help="The directory to save checkpoint files.",
)
parser.add_argument(
"--checkpoint-interval",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
parser.add_argument(
"--query-budget",
"-q",
type=int,
default=float("inf"),
help="The maximum number of model queries allowed per example attacked.",
)
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ", ".join(SEARCH_CLASS_NAMES.keys())
attack_group.add_argument(
"--search",
"--search-method",
"-s",
type=str,
required=False,
default="greedy-word-wir",
help=f"The search method to use. choices: {search_choices}",
)
attack_group.add_argument(
"--recipe",
"--attack-recipe",
"-r",
type=str,
required=False,
default=None,
help="full attack recipe (overrides provided goal function, transformation & constraints)",
choices=RECIPE_NAMES.keys(),
)
attack_group.add_argument(
"--attack-from-file",
type=str,
required=False,
default=None,
help="attack to load from file (overrides provided goal function, transformation & constraints)",
)
# Parser for parsing args for resume
resume_parser = argparse.ArgumentParser(
description="A commandline parser for TextAttack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
resume_parser.add_argument(
"--checkpoint-file",
"-f",
type=str,
required=True,
        help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered, '
        "recover the latest checkpoint from either the current path or the specified directory.",
)
resume_parser.add_argument(
"--checkpoint-dir",
"-d",
required=False,
type=str,
default=None,
help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
)
resume_parser.add_argument(
"--checkpoint-interval",
"-i",
required=False,
type=int,
help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
)
resume_parser.add_argument(
"--parallel",
action="store_true",
default=False,
help="Run attack using multiple GPUs.",
)
# Resume attack from checkpoint.
if sys.argv[1:] and sys.argv[1].lower() == "resume":
args = resume_parser.parse_args(sys.argv[2:])
setattr(args, "checkpoint_resume", True)
else:
command_line_args = (
None if sys.argv[1:] else ["-h"]
) # Default to help with empty arguments.
args = parser.parse_args(command_line_args)
setattr(args, "checkpoint_resume", False)
    if args.checkpoint_interval and args.shuffle:
        # Not allowed because the order of shuffled data cannot be recovered when resuming.
        raise ValueError("Cannot use `--checkpoint-interval` with `--shuffle`")
set_seed(args.random_seed)
# Shortcuts for huggingface models using --model.
if args.model in HUGGINGFACE_DATASET_BY_MODEL:
(
args.model_from_huggingface,
args.dataset_from_nlp,
) = HUGGINGFACE_DATASET_BY_MODEL[args.model]
args.model = None
return args
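
# A minimal invocation sketch (hypothetical command line; the exact entry
# point depends on how TextAttack is installed):
#
#   python -m textattack.shared.scripts.attack \
#       --recipe textfooler \
#       --model bert-base-uncased-yelp-sentiment \
#       --num-examples 10
#
# `textfooler` is assumed here to be a key of RECIPE_NAMES.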
def parse_transformation_from_args(args, model):
# Transformations
transformation_name = args.transformation
if ":" in transformation_name:
transformation_name, params = transformation_name.split(":")
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model, {params})"
)
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
)
else:
raise ValueError(f"Error: unsupported transformation {transformation_name}")
else:
if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model)"
)
elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
transformation = eval(
f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
)
else:
raise ValueError(f"Error: unsupported transformation {transformation_name}")
return transformation
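
# The "name:params" convention, sketched: a string such as
# "word-swap-embedding:max_candidates=10" is split on ":" and the remainder is
# passed verbatim into the class constructor via eval(), i.e. roughly
# WordSwapEmbedding(max_candidates=10). (`max_candidates` is illustrative and
# assumes the underlying class accepts that keyword.)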
def parse_goal_function_from_args(args, model):
# Goal Functions
goal_function = args.goal_function
if ":" in goal_function:
goal_function_name, params = goal_function.split(":")
if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
raise ValueError(f"Error: unsupported goal_function {goal_function_name}")
goal_function = eval(
f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model, {params})"
)
elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
goal_function = eval(f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model)")
else:
raise ValueError(f"Error: unsupported goal_function {goal_function}")
goal_function.query_budget = args.query_budget
return goal_function
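
# Goal functions follow the same "name:params" convention, e.g. (illustrative):
#
#   --goal-function targeted-classification:target_class=2
#
# assuming "targeted-classification" is a key of GOAL_FUNCTION_CLASS_NAMES and
# that the class accepts a `target_class` keyword.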
def parse_constraints_from_args(args):
# Constraints
if not args.constraints:
return []
_constraints = []
for constraint in args.constraints:
if ":" in constraint:
constraint_name, params = constraint.split(":")
if constraint_name not in CONSTRAINT_CLASS_NAMES:
raise ValueError(f"Error: unsupported constraint {constraint_name}")
_constraints.append(
eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
)
elif constraint in CONSTRAINT_CLASS_NAMES:
_constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
else:
raise ValueError(f"Error: unsupported constraint {constraint}")
return _constraints
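
# Multiple constraints can be mixed, each optionally parameterized, e.g.
# (illustrative values):
#
#   --constraints repeat stopword use:threshold=0.8
#
# Whether a keyword such as `threshold` is accepted depends on the constraint
# class behind each name in CONSTRAINT_CLASS_NAMES.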
def parse_attack_from_args(args):
model = parse_model_from_args(args)
if args.recipe:
if ":" in args.recipe:
recipe_name, params = args.recipe.split(":")
if recipe_name not in RECIPE_NAMES:
raise ValueError(f"Error: unsupported recipe {recipe_name}")
recipe = eval(f"{RECIPE_NAMES[recipe_name]}(model, {params})")
elif args.recipe in RECIPE_NAMES:
recipe = eval(f"{RECIPE_NAMES[args.recipe]}(model)")
else:
raise ValueError(f"Invalid recipe {args.recipe}")
recipe.goal_function.query_budget = args.query_budget
return recipe
elif args.attack_from_file:
if ":" in args.attack_from_file:
attack_file, attack_name = args.attack_from_file.split(":")
else:
attack_file, attack_name = args.attack_from_file, "attack"
attack_file = attack_file.replace(".py", "").replace("/", ".")
attack_module = importlib.import_module(attack_file)
attack_func = getattr(attack_module, attack_name)
return attack_func(model)
else:
goal_function = parse_goal_function_from_args(args, model)
transformation = parse_transformation_from_args(args, model)
constraints = parse_constraints_from_args(args)
if ":" in args.search:
search_name, params = args.search.split(":")
if search_name not in SEARCH_CLASS_NAMES:
raise ValueError(f"Error: unsupported search {search_name}")
search_method = eval(f"{SEARCH_CLASS_NAMES[search_name]}({params})")
elif args.search in SEARCH_CLASS_NAMES:
search_method = eval(f"{SEARCH_CLASS_NAMES[args.search]}()")
else:
raise ValueError(f"Error: unsupported attack {args.search}")
return textattack.shared.Attack(
goal_function, constraints, transformation, search_method
)
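
# A sketch of the `--attack-from-file` convention: the file should expose a
# callable that takes a model and returns a textattack.shared.Attack. For
# example, a hypothetical my_attack.py:
#
#   import textattack
#
#   def attack(model):
#       # build goal function, constraints, transformation, search method here
#       return textattack.shared.Attack(
#           goal_function, constraints, transformation, search_method
#       )
#
# is loaded with `--attack-from-file my_attack.py` (default name "attack") or
# `--attack-from-file my_attack.py:attack`.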
def parse_model_from_args(args):
if args.model_from_file:
colored_model_name = textattack.shared.utils.color_text(
args.model_from_file, color="blue", method="ansi"
)
textattack.shared.logger.info(
f"Loading model and tokenizer from file: {colored_model_name}"
)
if ":" in args.model_from_file:
model_file, model_name, tokenizer_name = args.model_from_file.split(":")
else:
model_file, model_name, tokenizer_name = (
args.model_from_file,
"model",
"tokenizer",
)
        try:
            # Use the file path from the split above, so a ":model:tokenizer"
            # suffix is not folded into the module name.
            model_file = model_file.replace(".py", "").replace("/", ".")
            model_module = importlib.import_module(model_file)
        except Exception:
raise ValueError(
f"Failed to import model or tokenizer from file {args.model_from_file}"
)
try:
model = getattr(model_module, model_name)
except AttributeError:
raise AttributeError(
f"``{model_name}`` not found in module {args.model_from_file}"
)
try:
tokenizer = getattr(model_module, tokenizer_name)
except AttributeError:
raise AttributeError(
f"``{tokenizer_name}`` not found in module {args.model_from_file}"
)
model = model.to(textattack.shared.utils.device)
setattr(model, "tokenizer", tokenizer)
elif args.model_from_huggingface:
import transformers
if ":" in args.model_from_huggingface:
model_class, model_name = args.model_from_huggingface.split(":")
model_class = eval(f"transformers.{model_class}")
else:
model_class, model_name = (
transformers.AutoModelForSequenceClassification,
args.model_from_huggingface,
)
colored_model_name = textattack.shared.utils.color_text(
model_name, color="blue", method="ansi"
)
textattack.shared.logger.info(
f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
)
model = model_class.from_pretrained(model_name)
model = model.to(textattack.shared.utils.device)
try:
            tokenizer = textattack.models.tokenizers.AutoTokenizer(
                args.model_from_huggingface
            )
except OSError:
            textattack.shared.logger.warning(
                f"AutoTokenizer {args.model_from_huggingface} not found. Defaulting to `bert-base-uncased`."
            )
tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
setattr(model, "tokenizer", tokenizer)
else:
if ":" in args.model:
model_name, params = args.model.split(":")
colored_model_name = textattack.shared.utils.color_text(
model_name, color="blue", method="ansi"
)
textattack.shared.logger.info(
f"Loading pre-trained TextAttack model: {colored_model_name}"
)
if model_name not in TEXTATTACK_MODEL_CLASS_NAMES:
raise ValueError(f"Error: unsupported model {model_name}")
model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[model_name]}({params})")
elif args.model in TEXTATTACK_MODEL_CLASS_NAMES:
colored_model_name = textattack.shared.utils.color_text(
args.model, color="blue", method="ansi"
)
textattack.shared.logger.info(
f"Loading pre-trained TextAttack model: {colored_model_name}"
)
model = eval(f"{TEXTATTACK_MODEL_CLASS_NAMES[args.model]}()")
elif args.model in TEXTATTACK_DATASET_BY_MODEL:
colored_model_name = textattack.shared.utils.color_text(
args.model, color="blue", method="ansi"
)
model_path, args.dataset_from_nlp = TEXTATTACK_DATASET_BY_MODEL[args.model]
if args.model.startswith("lstm"):
textattack.shared.logger.info(
f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
)
model = textattack.models.helpers.LSTMForClassification(
model_path=model_path
)
elif args.model.startswith("cnn"):
textattack.shared.logger.info(
f"Loading pre-trained TextAttack CNN: {colored_model_name}"
)
model = textattack.models.helpers.WordCNNForClassification(
model_path=model_path
)
else:
raise ValueError(f"Unknown TextAttack pretrained model {args.model}")
else:
raise ValueError(f"Error: unsupported model {args.model}")
return model
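
# A sketch of the `--model-from-file` convention: the file must define
# module-level `model` and `tokenizer` objects (or the names given after ":").
# For example, a hypothetical my_model.py:
#
#   import transformers
#
#   model = transformers.AutoModelForSequenceClassification.from_pretrained(
#       "bert-base-uncased"
#   )
#   tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
#
# is loaded with `--model-from-file my_model.py` or
# `--model-from-file my_model.py:model:tokenizer`.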
def parse_dataset_from_args(args):
if args.dataset_from_file:
        textattack.shared.logger.info(
            f"Loading dataset from file: {args.dataset_from_file}"
        )
if ":" in args.dataset_from_file:
dataset_file, dataset_name = args.dataset_from_file.split(":")
else:
dataset_file, dataset_name = args.dataset_from_file, "dataset"
try:
dataset_file = dataset_file.replace(".py", "").replace("/", ".")
dataset_module = importlib.import_module(dataset_file)
        except Exception:
raise ValueError(
f"Failed to import dataset from file {args.dataset_from_file}"
)
try:
dataset = getattr(dataset_module, dataset_name)
except AttributeError:
            raise AttributeError(
                f"``{dataset_name}`` not found in module {args.dataset_from_file}"
            )
elif args.dataset_from_nlp:
        dataset_args = args.dataset_from_nlp
        if ":" in dataset_args:
            dataset_args = dataset_args.split(":")
        else:
            # Wrap the bare name so the star-unpacking below does not splat
            # the string into individual characters.
            dataset_args = [dataset_args]
dataset = textattack.datasets.HuggingFaceNLPDataset(
*dataset_args, shuffle=args.shuffle
)
else:
if not args.model:
raise ValueError("Must supply pretrained model or dataset")
elif args.model in DATASET_BY_MODEL:
dataset = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
else:
raise ValueError(f"Error: unsupported model {args.model}")
return dataset
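
# A sketch of the `--dataset-from-file` convention: the file must define a
# module-level `dataset` (or the name given after ":"), typically an iterable
# of (text, label) pairs. For example, a hypothetical my_dataset.py:
#
#   dataset = [("I loved this movie!", 1), ("A dreadful bore.", 0)]
#
# is loaded with `--dataset-from-file my_dataset.py`.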
def parse_logger_from_args(args):
# Create logger
attack_log_manager = textattack.loggers.AttackLogManager()
# Set default output directory to `textattack/outputs`.
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(
current_dir, os.pardir, os.pardir, os.pardir, "outputs"
)
args.out_dir = os.path.normpath(outputs_dir)
    # Output file.
    out_time = int(time.time() * 1000)  # Timestamp in milliseconds.
outfile_name = "attack-{}.txt".format(out_time)
attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))
# CSV
if args.enable_csv:
outfile_name = "attack-{}.csv".format(out_time)
color_method = None if args.enable_csv == "plain" else "file"
csv_path = os.path.join(args.out_dir, outfile_name)
attack_log_manager.add_output_csv(csv_path, color_method)
print("Logging to CSV at path {}.".format(csv_path))
# Visdom
if args.enable_visdom:
attack_log_manager.enable_visdom()
# Weights & Biases
if args.enable_wandb:
attack_log_manager.enable_wandb()
# Stdout
if not args.disable_stdout:
attack_log_manager.enable_stdout()
return attack_log_manager
def parse_checkpoint_from_args(args):
file_name = os.path.basename(args.checkpoint_file)
if file_name.lower() == "latest":
        dir_path = os.path.dirname(args.checkpoint_file) or "."
chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith(".ta.chkpt")]
assert chkpt_file_names, "Checkpoint directory is empty"
timestamps = [int(f.replace(".ta.chkpt", "")) for f in chkpt_file_names]
latest_file = str(max(timestamps)) + ".ta.chkpt"
checkpoint_path = os.path.join(dir_path, latest_file)
else:
checkpoint_path = args.checkpoint_file
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
set_seed(checkpoint.args.random_seed)
return checkpoint
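
# Resolving "latest", illustrated with hypothetical file names: given
# 1588000000000.ta.chkpt and 1588000100000.ta.chkpt in the checkpoint
# directory, `--checkpoint-file {dir}/latest` picks the larger millisecond
# timestamp, i.e. 1588000100000.ta.chkpt.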
def default_checkpoint_dir():
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(
current_dir, os.pardir, os.pardir, os.pardir, "checkpoints"
)
return os.path.normpath(checkpoints_dir)
def merge_checkpoint_args(saved_args, cmdline_args):
""" Merge previously saved arguments for checkpoint and newly entered arguments """
args = copy.deepcopy(saved_args)
# Newly entered arguments take precedence
args.checkpoint_resume = cmdline_args.checkpoint_resume
args.parallel = cmdline_args.parallel
# If set, replace
if cmdline_args.checkpoint_dir:
args.checkpoint_dir = cmdline_args.checkpoint_dir
if cmdline_args.checkpoint_interval:
args.checkpoint_interval = cmdline_args.checkpoint_interval
return args
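
# Precedence when resuming, illustrated: if the checkpoint was saved with
# checkpoint_interval=100 and the resume command passes
# `--checkpoint-interval 50`, the merged args use 50; `--parallel` and the
# resume flag always reflect the new invocation. (Values are illustrative.)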