# textattack-nlp-transformer/textattack/commands/attack/attack_args_helpers.py
import argparse
import copy
import importlib
import json
import os
import time

import textattack

from .attack_args import (
    ATTACK_RECIPE_NAMES,
    BLACK_BOX_TRANSFORMATION_CLASS_NAMES,
    CONSTRAINT_CLASS_NAMES,
    GOAL_FUNCTION_CLASS_NAMES,
    HUGGINGFACE_DATASET_BY_MODEL,
    SEARCH_METHOD_CLASS_NAMES,
    TEXTATTACK_DATASET_BY_MODEL,
    WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
)

# The split token allows users to optionally pass multiple arguments in a single
# parameter by separating them with the split token.
ARGS_SPLIT_TOKEN = "|"
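
# Illustrative sketch (not part of the original module): the split-token syntax lets a
# single CLI value carry both a component name and its keyword arguments. The recipe
# name and parameter below are examples only.
#
#   name, params = "deepwordbug|use_all_transformations=False".split(ARGS_SPLIT_TOKEN)
#   # name == "deepwordbug", params == "use_all_transformations=False"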


def add_model_args(parser):
    """Adds model-related arguments to an argparser.

    This is useful because we want to load pretrained models using
    several different parsers that share these arguments, but not all
    of their other arguments.
    """
    model_group = parser.add_mutually_exclusive_group()
    model_names = list(HUGGINGFACE_DATASET_BY_MODEL.keys()) + list(
        TEXTATTACK_DATASET_BY_MODEL.keys()
    )
    model_group.add_argument(
        "--model",
        type=str,
        required=False,
        default=None,
        help='Name of or path to a pre-trained model to attack. Usage: "--model {model}:{arg_1}={value_1},{arg_2}={value_2},...". Choices: '
        + str(model_names),
    )
    model_group.add_argument(
        "--model-from-file",
        type=str,
        required=False,
        help="File of model and tokenizer to import.",
    )
    model_group.add_argument(
        "--model-from-huggingface",
        type=str,
        required=False,
        help="huggingface.co ID of pre-trained model to load.",
    )
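
# Illustrative usage sketch (the model key below is an example; any registered name
# from HUGGINGFACE_DATASET_BY_MODEL or TEXTATTACK_DATASET_BY_MODEL works the same way):
#
#   parser = argparse.ArgumentParser()
#   add_model_args(parser)
#   args = parser.parse_args(["--model", "bert-base-uncased-imdb"])
#   # --model, --model-from-file, and --model-from-huggingface are mutually exclusive.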


def add_dataset_args(parser):
    """Adds dataset-related arguments to an argparser.

    This is useful because we want to load datasets using several
    different parsers that share these arguments, but not all of their
    other arguments.
    """
    dataset_group = parser.add_mutually_exclusive_group()
    dataset_group.add_argument(
        "--dataset-from-nlp",
        type=str,
        required=False,
        default=None,
        help="Dataset to load from the `nlp` repository.",
    )
    dataset_group.add_argument(
        "--dataset-from-file",
        type=str,
        required=False,
        default=None,
        help="Dataset to load from a file.",
    )
    parser.add_argument(
        "--shuffle",
        type=eval,
        required=False,
        choices=[True, False],
        default=True,
        help="Randomly shuffle the data before attacking.",
    )
    parser.add_argument(
        "--num-examples",
        "-n",
        type=int,
        required=False,
        default=5,
        help="The number of examples to process.",
    )
    parser.add_argument(
        "--num-examples-offset",
        "-o",
        type=int,
        required=False,
        default=0,
        help="The offset to start at in the dataset.",
    )
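
# Illustrative usage sketch (flag values are examples): the dataset flags combine with
# the shared example-count flags added above.
#
#   parser = argparse.ArgumentParser()
#   add_dataset_args(parser)
#   args = parser.parse_args(["--dataset-from-nlp", "imdb", "--num-examples", "10"])
#   # --shuffle is parsed with eval, so "--shuffle False" yields the boolean False.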


def load_module_from_file(file_path):
    """Uses ``importlib`` to dynamically open a file and load an object from
    it."""
    temp_module_name = f"temp_{time.time()}"
    colored_file_path = textattack.shared.utils.color_text(
        file_path, color="blue", method="ansi"
    )
    textattack.shared.logger.info(f"Loading module from `{colored_file_path}`.")
    spec = importlib.util.spec_from_file_location(temp_module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
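
# Illustrative usage sketch (the path is hypothetical):
#
#   module = load_module_from_file("my_attack.py")
#   attack = module.attack  # any top-level name defined in the file is an attribute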


def parse_transformation_from_args(args, model_wrapper):
    transformation_name = args.transformation
    if ARGS_SPLIT_TOKEN in transformation_name:
        transformation_name, params = transformation_name.split(ARGS_SPLIT_TOKEN)

        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    else:
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    return transformation
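
# Illustrative sketch of the expected CLI value (the name and keyword argument below
# are examples; valid names come from the *_TRANSFORMATION_CLASS_NAMES dicts imported
# above, and white-box transformations additionally receive model_wrapper.model):
#
#   args.transformation = "word-swap-embedding|max_candidates=10"
#   transformation = parse_transformation_from_args(args, model_wrapper)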


def parse_goal_function_from_args(args, model):
    goal_function = args.goal_function
    if ARGS_SPLIT_TOKEN in goal_function:
        goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN)
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f"Error: unsupported goal_function {goal_function_name}")
        goal_function = eval(
            f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model, {params})"
        )
    elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
        goal_function = eval(f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model)")
    else:
        raise ValueError(f"Error: unsupported goal_function {goal_function}")
    goal_function.query_budget = args.query_budget
    goal_function.model_batch_size = args.model_batch_size
    goal_function.model_cache_size = args.model_cache_size
    return goal_function
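
# Illustrative sketch (assumes "untargeted-classification" is a registered key in
# GOAL_FUNCTION_CLASS_NAMES; args must also carry query_budget, model_batch_size, and
# model_cache_size, which are copied onto the goal function):
#
#   args.goal_function = "untargeted-classification"
#   goal_function = parse_goal_function_from_args(args, model)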


def parse_constraints_from_args(args):
    if not args.constraints:
        return []

    _constraints = []
    for constraint in args.constraints:
        if ARGS_SPLIT_TOKEN in constraint:
            constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN)
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f"Error: unsupported constraint {constraint_name}")
            _constraints.append(
                eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
            )
        elif constraint in CONSTRAINT_CLASS_NAMES:
            _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
        else:
            raise ValueError(f"Error: unsupported constraint {constraint}")

    return _constraints
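
# Illustrative sketch (constraint names and the keyword argument are examples; valid
# keys come from CONSTRAINT_CLASS_NAMES, with parameters following the split token):
#
#   args.constraints = ["repeat", "stopword", "max-words-perturbed|max_num_words=3"]
#   constraints = parse_constraints_from_args(args)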


def parse_attack_from_args(args):
    model = parse_model_from_args(args)
    if args.recipe:
        if ARGS_SPLIT_TOKEN in args.recipe:
            recipe_name, params = args.recipe.split(ARGS_SPLIT_TOKEN)
            if recipe_name not in ATTACK_RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            recipe = eval(f"{ATTACK_RECIPE_NAMES[recipe_name]}(model, {params})")
        elif args.recipe in ATTACK_RECIPE_NAMES:
            recipe = eval(f"{ATTACK_RECIPE_NAMES[args.recipe]}(model)")
        else:
            raise ValueError(f"Invalid recipe {args.recipe}")
        recipe.goal_function.query_budget = args.query_budget
        recipe.goal_function.model_batch_size = args.model_batch_size
        recipe.goal_function.model_cache_size = args.model_cache_size
        recipe.constraint_cache_size = args.constraint_cache_size
        return recipe
    elif args.attack_from_file:
        if ARGS_SPLIT_TOKEN in args.attack_from_file:
            attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
        else:
            attack_file, attack_name = args.attack_from_file, "attack"
        attack_module = load_module_from_file(attack_file)
        if not hasattr(attack_module, attack_name):
            raise ValueError(
                f"Loaded `{attack_file}` but could not find `{attack_name}`."
            )
        attack_func = getattr(attack_module, attack_name)
        return attack_func(model)
    else:
        goal_function = parse_goal_function_from_args(args, model)
        transformation = parse_transformation_from_args(args, model)
        constraints = parse_constraints_from_args(args)
        if ARGS_SPLIT_TOKEN in args.search:
            search_name, params = args.search.split(ARGS_SPLIT_TOKEN)
            if search_name not in SEARCH_METHOD_CLASS_NAMES:
                raise ValueError(f"Error: unsupported search {search_name}")
            search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})")
        elif args.search in SEARCH_METHOD_CLASS_NAMES:
            search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[args.search]}()")
        else:
            raise ValueError(f"Error: unsupported search method {args.search}")
        return textattack.shared.Attack(
            goal_function,
            constraints,
            transformation,
            search_method,
            constraint_cache_size=args.constraint_cache_size,
        )
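
# Illustrative sketch (values are examples only): an attack is constructed in one of
# three ways, checked in this order.
#
#   args.recipe = "textfooler"              # 1) a named recipe from ATTACK_RECIPE_NAMES
#   args.attack_from_file = "my_attack.py"  # 2) otherwise, an `attack` callable in a file
#   # 3) otherwise, assembled from --goal-function, --transformation,
#   #    --constraints, and --search.
#   attack = parse_attack_from_args(args)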


def parse_model_from_args(args):
    if args.model_from_file:
        # Support loading the model from a .py file where a model wrapper
        # is instantiated.
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}"
        )
        if ARGS_SPLIT_TOKEN in args.model_from_file:
            model_file, model_name = args.model_from_file.split(ARGS_SPLIT_TOKEN)
        else:
            model_file, model_name = args.model_from_file, "model"
        try:
            model_module = load_module_from_file(model_file)
        except Exception:
            raise ValueError(f"Failed to import file {args.model_from_file}")
        try:
            model = getattr(model_module, model_name)
        except AttributeError:
            raise AttributeError(
                f"``{model_name}`` not found in module {args.model_from_file}"
            )
        if not isinstance(model, textattack.models.wrappers.ModelWrapper):
            raise TypeError(
                "Model must be of type "
                f"``textattack.models.ModelWrapper``, got type {type(model)}"
            )
    elif (args.model in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
        # Support loading models automatically from the HuggingFace model hub.
        import transformers

        model_name = (
            HUGGINGFACE_DATASET_BY_MODEL[args.model][0]
            if (args.model in HUGGINGFACE_DATASET_BY_MODEL)
            else args.model_from_huggingface
        )
        if ARGS_SPLIT_TOKEN in model_name:
            model_class, model_name = model_name.split(ARGS_SPLIT_TOKEN)
            model_class = eval(f"transformers.{model_class}")
        else:
            model_class, model_name = (
                transformers.AutoModelForSequenceClassification,
                model_name,
            )
        colored_model_name = textattack.shared.utils.color_text(
            model_name, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = model_class.from_pretrained(model_name)
        tokenizer = textattack.models.tokenizers.AutoTokenizer(model_name)
        model = textattack.models.wrappers.HuggingFaceModelWrapper(
            model, tokenizer, batch_size=args.model_batch_size
        )
    elif args.model in TEXTATTACK_DATASET_BY_MODEL:
        # Support loading TextAttack pre-trained models via just a keyword.
        model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
        model = textattack.shared.utils.load_textattack_model_from_path(
            args.model, model_path
        )
        model = textattack.models.wrappers.PyTorchModelWrapper(
            model, model.tokenizer, batch_size=args.model_batch_size
        )
    elif args.model and os.path.exists(args.model):
        # Support loading TextAttack-trained models via just their folder path.
        # If `args.model` is a path/directory, let's assume it was a model
        # trained with TextAttack, and try to load it.
        model_args_json_path = os.path.join(args.model, "train_args.json")
        if not os.path.exists(model_args_json_path):
            raise FileNotFoundError(
                f"Tried to load model from path {args.model} - could not find train_args.json."
            )
        with open(model_args_json_path) as f:
            model_train_args = json.loads(f.read())
        if model_train_args["model"] not in {"cnn", "lstm"}:
            # For huggingface models, set args.model to the path of the model.
            model_train_args["model"] = args.model
        num_labels = model_train_args["num_labels"]

        from textattack.commands.train_model.train_args_helpers import model_from_args

        model = model_from_args(
            argparse.Namespace(**model_train_args), num_labels, model_path=args.model,
        )
        model = textattack.models.wrappers.PyTorchModelWrapper(
            model, model.tokenizer, batch_size=args.model_batch_size
        )
    else:
        raise ValueError(f"Error: unsupported TextAttack model {args.model}")
    return model
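
# Illustrative sketch of the --model-from-file convention (file name and contents are
# hypothetical). The referenced file must define a variable named `model` (or a custom
# name given after the split token) holding a textattack ModelWrapper instance:
#
#   # contents of my_model.py:
#   #   model = textattack.models.wrappers.HuggingFaceModelWrapper(hf_model, tokenizer)
#
#   args.model_from_file = "my_model.py"    # or "my_model.py|my_wrapper_name"
#   model = parse_model_from_args(args)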


def parse_dataset_from_args(args):
    # Automatically detect dataset for huggingface & textattack models.
    # This allows us to use the --model shortcut without specifying a dataset.
    if args.model in HUGGINGFACE_DATASET_BY_MODEL:
        _, args.dataset_from_nlp = HUGGINGFACE_DATASET_BY_MODEL[args.model]
    elif args.model in TEXTATTACK_DATASET_BY_MODEL:
        _, dataset = TEXTATTACK_DATASET_BY_MODEL[args.model]
        if dataset[0].startswith("textattack"):
            # unsavory way to pass custom dataset classes
            # ex: dataset = ('textattack.datasets.translation.TedMultiTranslationDataset', 'en', 'de')
            dataset = eval(f"{dataset[0]}")(*dataset[1:])
            return dataset
        else:
            args.dataset_from_nlp = dataset
    # Automatically detect dataset for models trained with TextAttack.
    elif args.model and os.path.exists(args.model):
        model_args_json_path = os.path.join(args.model, "train_args.json")
        if not os.path.exists(model_args_json_path):
            raise FileNotFoundError(
                f"Tried to load model from path {args.model} - could not find train_args.json."
            )
        with open(model_args_json_path) as f:
            model_train_args = json.loads(f.read())
        try:
            if ARGS_SPLIT_TOKEN in model_train_args["dataset"]:
                name, subset = model_train_args["dataset"].split(ARGS_SPLIT_TOKEN)
            else:
                name, subset = model_train_args["dataset"], None
            args.dataset_from_nlp = (
                name,
                subset,
                model_train_args["dataset_dev_split"],
            )
        except KeyError:
            raise KeyError(
                f"Tried to load model from path {args.model} but can't initialize dataset from train_args.json."
            )
    # Get dataset from args.
    if args.dataset_from_file:
        textattack.shared.logger.info(
            f"Loading dataset from file: {args.dataset_from_file}"
        )
        if ARGS_SPLIT_TOKEN in args.dataset_from_file:
            dataset_file, dataset_name = args.dataset_from_file.split(ARGS_SPLIT_TOKEN)
        else:
            dataset_file, dataset_name = args.dataset_from_file, "dataset"
        try:
            dataset_module = load_module_from_file(dataset_file)
        except Exception:
            raise ValueError(
                f"Failed to import dataset from file {args.dataset_from_file}"
            )
        try:
            dataset = getattr(dataset_module, dataset_name)
        except AttributeError:
            raise AttributeError(
                f"``{dataset_name}`` not found in module {args.dataset_from_file}"
            )
    elif args.dataset_from_nlp:
        dataset_args = args.dataset_from_nlp
        if isinstance(dataset_args, str):
            if ARGS_SPLIT_TOKEN in dataset_args:
                dataset_args = dataset_args.split(ARGS_SPLIT_TOKEN)
            else:
                dataset_args = (dataset_args,)
        dataset = textattack.datasets.HuggingFaceNlpDataset(
            *dataset_args, shuffle=args.shuffle
        )
        dataset.examples = dataset.examples[args.num_examples_offset :]
    else:
        raise ValueError("Must supply pretrained model or dataset")
    return dataset
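
# Illustrative sketch (the dataset spec is an example): when --model names a registered
# model, the matching dataset is filled in automatically above; otherwise a dataset can
# be named explicitly and is split on "|" into (name, subset, split).
#
#   args.dataset_from_nlp = "glue|sst2|validation"
#   dataset = parse_dataset_from_args(args)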


def parse_logger_from_args(args):
    # Create the logger.
    attack_log_manager = textattack.loggers.AttackLogManager()

    # Timestamp used to name output files.
    out_time = int(time.time() * 1000)

    # Set default output directory to `textattack/outputs`.
    if not args.out_dir:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        outputs_dir = os.path.join(
            current_dir, os.pardir, os.pardir, os.pardir, "outputs", "attacks"
        )
        if not os.path.exists(outputs_dir):
            os.makedirs(outputs_dir)
        args.out_dir = os.path.normpath(outputs_dir)

    # If `--log-to-file` was passed on the command line, save results to a txt file.
    if args.log_to_file:
        outfile_name = "attack-{}.txt".format(out_time)
        attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))

    # CSV
    if args.enable_csv:
        outfile_name = "attack-{}.csv".format(out_time)
        color_method = None if args.enable_csv == "plain" else "file"
        csv_path = os.path.join(args.out_dir, outfile_name)
        attack_log_manager.add_output_csv(csv_path, color_method)
        textattack.shared.logger.info(f"Logging to CSV at path {csv_path}.")

    # Visdom
    if args.enable_visdom:
        attack_log_manager.enable_visdom()

    # Weights & Biases
    if args.enable_wandb:
        attack_log_manager.enable_wandb()

    # Stdout
    if not args.disable_stdout:
        attack_log_manager.enable_stdout()

    return attack_log_manager


def parse_checkpoint_from_args(args):
    file_name = os.path.basename(args.checkpoint_file)
    if file_name.lower() == "latest":
        dir_path = os.path.dirname(args.checkpoint_file)
        chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith(".ta.chkpt")]
        assert chkpt_file_names, "Checkpoint directory is empty"
        timestamps = [int(f.replace(".ta.chkpt", "")) for f in chkpt_file_names]
        latest_file = str(max(timestamps)) + ".ta.chkpt"
        checkpoint_path = os.path.join(dir_path, latest_file)
    else:
        checkpoint_path = args.checkpoint_file

    checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
    return checkpoint
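
# Illustrative sketch (the path is hypothetical): passing a file literally named
# "latest" resolves to the newest ``.ta.chkpt`` file in that directory.
#
#   args.checkpoint_file = "checkpoints/latest"
#   checkpoint = parse_checkpoint_from_args(args)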


def default_checkpoint_dir():
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(
        current_dir, os.pardir, os.pardir, os.pardir, "checkpoints"
    )
    return os.path.normpath(checkpoints_dir)


def merge_checkpoint_args(saved_args, cmdline_args):
    """Merge the arguments saved with a checkpoint with newly entered
    command-line arguments."""
    args = copy.deepcopy(saved_args)
    # Newly entered arguments take precedence.
    args.parallel = cmdline_args.parallel
    # If these are set on the command line, they replace the saved values.
    if cmdline_args.checkpoint_dir:
        args.checkpoint_dir = cmdline_args.checkpoint_dir
    if cmdline_args.checkpoint_interval:
        args.checkpoint_interval = cmdline_args.checkpoint_interval
    return args