mirror of https://github.com/QData/TextAttack.git
318 textattack/commands/attack/attack_args_helpers.py Normal file
@@ -0,0 +1,318 @@
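"""Helpers that turn the arguments of TextAttack's ``attack`` command into the
objects an attack needs: transformation, goal function, constraints, search
method, recipe, victim model, dataset, loggers, and checkpoints.

Each component is specified either by a registered short name or as
"name:kwargs", where everything after the colon is forwarded to the class
constructor. (This summary describes the helpers below; the exact CLI flag
spellings are defined in the argument parser, not in this file.)
"""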
import argparse
import copy
import importlib
import os
import pickle
import random
import sys
import time

import numpy as np
import torch

import textattack
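
# NOTE: the registries used below (WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
# BLACK_BOX_TRANSFORMATION_CLASS_NAMES, GOAL_FUNCTION_CLASS_NAMES,
# CONSTRAINT_CLASS_NAMES, SEARCH_CLASS_NAMES, RECIPE_NAMES,
# HUGGINGFACE_DATASET_BY_MODEL, TEXTATTACK_DATASET_BY_MODEL) and set_seed()
# are not shown in this excerpt; they are assumed to be defined elsewhere in
# this module or imported from a sibling module in textattack/commands/attack.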


def parse_transformation_from_args(args, model):
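    """Build a Transformation from ``args.transformation``.

    The value is either a registered name (e.g. "word-swap-embedding"; the
    name here is illustrative) or "name:kwargs", where everything after the
    colon is passed to the class constructor. White-box transformations also
    receive the victim model.
    """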
    # Transformations
    transformation_name = args.transformation
    if ":" in transformation_name:
        transformation_name, params = transformation_name.split(":")

        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model, {params})"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    else:
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model)"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    return transformation


def parse_goal_function_from_args(args, model):
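    """Build a GoalFunction from ``args.goal_function`` and apply the query
    budget. Accepts a registered name or "name:kwargs", as above.
    """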
    # Goal Functions
    goal_function = args.goal_function
    if ":" in goal_function:
        goal_function_name, params = goal_function.split(":")
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f"Error: unsupported goal_function {goal_function_name}")
        goal_function = eval(
            f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model, {params})"
        )
    elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
        goal_function = eval(f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model)")
    else:
        raise ValueError(f"Error: unsupported goal_function {goal_function}")
    goal_function.query_budget = args.query_budget
    return goal_function


def parse_constraints_from_args(args):
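    """Build the list of Constraint objects named in ``args.constraints``;
    returns an empty list when none are given.
    """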
    # Constraints
    if not args.constraints:
        return []

    _constraints = []
    for constraint in args.constraints:
        if ":" in constraint:
            constraint_name, params = constraint.split(":")
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f"Error: unsupported constraint {constraint_name}")
            _constraints.append(
                eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
            )
        elif constraint in CONSTRAINT_CLASS_NAMES:
            _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
        else:
            raise ValueError(f"Error: unsupported constraint {constraint}")

    return _constraints


def parse_attack_from_args(args):
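    """Build an Attack, in order of precedence: from a named recipe
    (``args.recipe``), from a user-supplied Python file
    (``args.attack_from_file``), or from the individual goal function,
    transformation, constraint, and search arguments.
    """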
    model = parse_model_from_args(args)
    if args.recipe:
        if ":" in args.recipe:
            recipe_name, params = args.recipe.split(":")
            if recipe_name not in RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            recipe = eval(f"{RECIPE_NAMES[recipe_name]}(model, {params})")
        elif args.recipe in RECIPE_NAMES:
            recipe = eval(f"{RECIPE_NAMES[args.recipe]}(model)")
        else:
            raise ValueError(f"Error: unsupported recipe {args.recipe}")
        recipe.goal_function.query_budget = args.query_budget
        return recipe
    elif args.attack_from_file:
        if ":" in args.attack_from_file:
            attack_file, attack_name = args.attack_from_file.split(":")
        else:
            attack_file, attack_name = args.attack_from_file, "attack"
        attack_file = attack_file.replace(".py", "").replace("/", ".")
        attack_module = importlib.import_module(attack_file)
        attack_func = getattr(attack_module, attack_name)
        return attack_func(model)
    else:
        goal_function = parse_goal_function_from_args(args, model)
        transformation = parse_transformation_from_args(args, model)
        constraints = parse_constraints_from_args(args)
        if ":" in args.search:
            search_name, params = args.search.split(":")
            if search_name not in SEARCH_CLASS_NAMES:
                raise ValueError(f"Error: unsupported search {search_name}")
            search_method = eval(f"{SEARCH_CLASS_NAMES[search_name]}({params})")
        elif args.search in SEARCH_CLASS_NAMES:
            search_method = eval(f"{SEARCH_CLASS_NAMES[args.search]}()")
        else:
            raise ValueError(f"Error: unsupported search {args.search}")
        return textattack.shared.Attack(
            goal_function, constraints, transformation, search_method
        )


def parse_model_from_args(args):
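    """Load the victim model (and its tokenizer) from a user-supplied Python
    file, from the HuggingFace model hub, or from TextAttack's own
    pre-trained models, depending on which argument is set.
    """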
    if args.model_from_file:
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}"
        )
        if ":" in args.model_from_file:
            model_file, model_name, tokenizer_name = args.model_from_file.split(":")
        else:
            model_file, model_name, tokenizer_name = (
                args.model_from_file,
                "model",
                "tokenizer",
            )
        try:
            # Use the parsed file path (not the raw argument, which may still
            # contain ":model:tokenizer") to build the module name.
            model_file = model_file.replace(".py", "").replace("/", ".")
            model_module = importlib.import_module(model_file)
        except Exception:
            raise ValueError(
                f"Failed to import model or tokenizer from file {args.model_from_file}"
            )
        try:
            model = getattr(model_module, model_name)
        except AttributeError:
            raise AttributeError(
                f"``{model_name}`` not found in module {args.model_from_file}"
            )
        try:
            tokenizer = getattr(model_module, tokenizer_name)
        except AttributeError:
            raise AttributeError(
                f"``{tokenizer_name}`` not found in module {args.model_from_file}"
            )
        model = model.to(textattack.shared.utils.device)
        setattr(model, "tokenizer", tokenizer)
    elif (args.model in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
        import transformers

        model_name = (
            HUGGINGFACE_DATASET_BY_MODEL[args.model][0]
            if (args.model in HUGGINGFACE_DATASET_BY_MODEL)
            else args.model_from_huggingface
        )

        if ":" in model_name:
            model_class, model_name = model_name.split(":")
            model_class = eval(f"transformers.{model_class}")
        else:
            model_class, model_name = (
                transformers.AutoModelForSequenceClassification,
                model_name,
            )
        colored_model_name = textattack.shared.utils.color_text(
            model_name, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = model_class.from_pretrained(model_name)
        model = model.to(textattack.shared.utils.device)
        try:
            tokenizer = textattack.models.tokenizers.AutoTokenizer(model_name)
        except OSError:
            textattack.shared.logger.warn(
                f"AutoTokenizer {model_name} not found. Defaulting to `bert-base-uncased`"
            )
            tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
        setattr(model, "tokenizer", tokenizer)
    else:
        if args.model in TEXTATTACK_DATASET_BY_MODEL:
            model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
            model = textattack.shared.utils.load_textattack_model_from_path(
                args.model, model_path
            )
        else:
            raise ValueError(f"Error: unsupported TextAttack model {args.model}")
    return model


def parse_dataset_from_args(args):
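    """Load the dataset, either from a user-supplied Python file
    (``args.dataset_from_file``) or from the HuggingFace ``nlp`` library
    (``args.dataset_from_nlp``).
    """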
    if args.dataset_from_file:
        textattack.shared.logger.info(
            f"Loading dataset from file: {args.dataset_from_file}"
        )
        if ":" in args.dataset_from_file:
            dataset_file, dataset_name = args.dataset_from_file.split(":")
        else:
            dataset_file, dataset_name = args.dataset_from_file, "dataset"
        try:
            dataset_file = dataset_file.replace(".py", "").replace("/", ".")
            dataset_module = importlib.import_module(dataset_file)
        except Exception:
            raise ValueError(
                f"Failed to import dataset from file {args.dataset_from_file}"
            )
        try:
            dataset = getattr(dataset_module, dataset_name)
        except AttributeError:
            raise AttributeError(
                f"``{dataset_name}`` not found in module {args.dataset_from_file}"
            )
    elif args.dataset_from_nlp:
        dataset_args = args.dataset_from_nlp
        if ":" in dataset_args:
            dataset_args = dataset_args.split(":")
        else:
            # Wrap the single dataset name in a tuple so the `*` below does
            # not unpack the string character-by-character.
            dataset_args = (dataset_args,)
        dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, shuffle=args.shuffle
        )
        dataset.examples = dataset.examples[args.num_examples_offset :]
    else:
        raise ValueError("Must supply pretrained model or dataset")
    return dataset


def parse_logger_from_args(args):
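    """Create an AttackLogManager and attach the outputs requested by the
    arguments: a timestamped text file, and optionally CSV, Visdom,
    Weights & Biases, and stdout logging.
    """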
    # Create logger
    attack_log_manager = textattack.loggers.AttackLogManager()
    # Set default output directory to `textattack/outputs`.
    if not args.out_dir:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        outputs_dir = os.path.join(
            current_dir, os.pardir, os.pardir, os.pardir, "outputs"
        )
        args.out_dir = os.path.normpath(outputs_dir)

    # Output file.
    out_time = int(time.time() * 1000)  # current time, in milliseconds
    outfile_name = "attack-{}.txt".format(out_time)
    attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))

    # CSV
    if args.enable_csv:
        outfile_name = "attack-{}.csv".format(out_time)
        color_method = None if args.enable_csv == "plain" else "file"
        csv_path = os.path.join(args.out_dir, outfile_name)
        attack_log_manager.add_output_csv(csv_path, color_method)
        print("Logging to CSV at path {}.".format(csv_path))

    # Visdom
    if args.enable_visdom:
        attack_log_manager.enable_visdom()

    # Weights & Biases
    if args.enable_wandb:
        attack_log_manager.enable_wandb()

    # Stdout
    if not args.disable_stdout:
        attack_log_manager.enable_stdout()
    return attack_log_manager


def parse_checkpoint_from_args(args):
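    """Load a saved attack checkpoint. If ``args.checkpoint_file`` ends in
    "latest", pick the ``.ta.chkpt`` file with the newest timestamp in that
    directory; then restore the checkpoint's random seed.
    """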
    file_name = os.path.basename(args.checkpoint_file)
    if file_name.lower() == "latest":
        dir_path = os.path.dirname(args.checkpoint_file)
        chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith(".ta.chkpt")]
        assert chkpt_file_names, "Checkpoint directory is empty"
        timestamps = [int(f.replace(".ta.chkpt", "")) for f in chkpt_file_names]
        latest_file = str(max(timestamps)) + ".ta.chkpt"
        checkpoint_path = os.path.join(dir_path, latest_file)
    else:
        checkpoint_path = args.checkpoint_file

    checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
    set_seed(checkpoint.args.random_seed)

    return checkpoint


def default_checkpoint_dir():
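    """Return the default checkpoint directory: ``checkpoints/`` at the
    repository root, three levels above this file.
    """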
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(
        current_dir, os.pardir, os.pardir, os.pardir, "checkpoints"
    )
    return os.path.normpath(checkpoints_dir)


def merge_checkpoint_args(saved_args, cmdline_args):
    """Merge the arguments saved with a checkpoint and the newly entered
    command-line arguments."""
    args = copy.deepcopy(saved_args)
    # Newly entered arguments take precedence
    args.checkpoint_resume = cmdline_args.checkpoint_resume
    args.parallel = cmdline_args.parallel
    # If set on the command line, replace the saved values
    if cmdline_args.checkpoint_dir:
        args.checkpoint_dir = cmdline_args.checkpoint_dir
    if cmdline_args.checkpoint_interval:
        args.checkpoint_interval = cmdline_args.checkpoint_interval

    return args
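

# Illustrative argument strings these helpers accept (component names and
# flag spellings are examples, not an exhaustive list of the registered keys):
#   --recipe textfooler
#   --transformation word-swap-embedding:max_candidates=5
#   --goal-function untargeted-classification
#   --model-from-file my_model.py:model:tokenizer
#   --dataset-from-nlp glue:sst2:validation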