mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

textattack/commands

This commit is contained in:
Jack Morris
2020-06-19 16:48:59 -04:00
parent 8691088e65
commit d2c115d3c8
17 changed files with 430 additions and 332 deletions


@@ -0,0 +1,318 @@
import argparse
import copy
import importlib
import os
import pickle
import random
import sys
import time

import numpy as np
import torch

import textattack
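
# Note: the name-to-class registries referenced below (WHITE_BOX_TRANSFORMATION_CLASS_NAMES,
# BLACK_BOX_TRANSFORMATION_CLASS_NAMES, GOAL_FUNCTION_CLASS_NAMES, CONSTRAINT_CLASS_NAMES,
# RECIPE_NAMES, SEARCH_CLASS_NAMES, HUGGINGFACE_DATASET_BY_MODEL, TEXTATTACK_DATASET_BY_MODEL)
# and the set_seed() helper are module-level definitions elided from this excerpt of the diff.
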
def parse_transformation_from_args(args, model):
    # Transformations
    transformation_name = args.transformation
    if ":" in transformation_name:
        transformation_name, params = transformation_name.split(":")
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model, {params})"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    else:
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model)"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            transformation = eval(
                f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
            )
        else:
            raise ValueError(f"Error: unsupported transformation {transformation_name}")
    return transformation
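
# Illustrative sketch (registry key and parameter assumed, for the black-box branch above):
#     --transformation word-swap-embedding:max_candidates=10
# would be eval'd into something like
#     textattack.transformations.WordSwapEmbedding(max_candidates=10)
# while white-box entries additionally receive the model as their first argument.
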
def parse_goal_function_from_args(args, model):
    # Goal functions
    goal_function = args.goal_function
    if ":" in goal_function:
        goal_function_name, params = goal_function.split(":")
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f"Error: unsupported goal_function {goal_function_name}")
        goal_function = eval(
            f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model, {params})"
        )
    elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
        goal_function = eval(f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model)")
    else:
        raise ValueError(f"Error: unsupported goal_function {goal_function}")
    goal_function.query_budget = args.query_budget
    return goal_function
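
# Illustrative sketch (registry key assumed):
#     --goal-function untargeted-classification
# would resolve through GOAL_FUNCTION_CLASS_NAMES to something like
#     textattack.goal_functions.UntargetedClassification(model)
# after which the --query-budget value is attached to the instance.
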
def parse_constraints_from_args(args):
    # Constraints
    if not args.constraints:
        return []

    _constraints = []
    for constraint in args.constraints:
        if ":" in constraint:
            constraint_name, params = constraint.split(":")
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f"Error: unsupported constraint {constraint_name}")
            _constraints.append(
                eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
            )
        elif constraint in CONSTRAINT_CLASS_NAMES:
            _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
        else:
            raise ValueError(f"Error: unsupported constraint {constraint}")
    return _constraints
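
# Illustrative sketch (registry keys assumed): each entry in --constraints becomes one
# constraint object, e.g.
#     --constraints repeat stopword
# might append
#     textattack.constraints.pre_transformation.RepeatModification()
#     textattack.constraints.pre_transformation.StopwordModification()
# and a colon form such as max-words-perturbed:max_num_words=3 passes params through eval.
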
def parse_attack_from_args(args):
    model = parse_model_from_args(args)
    if args.recipe:
        if ":" in args.recipe:
            recipe_name, params = args.recipe.split(":")
            if recipe_name not in RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            recipe = eval(f"{RECIPE_NAMES[recipe_name]}(model, {params})")
        elif args.recipe in RECIPE_NAMES:
            recipe = eval(f"{RECIPE_NAMES[args.recipe]}(model)")
        else:
            raise ValueError(f"Error: invalid recipe {args.recipe}")
        recipe.goal_function.query_budget = args.query_budget
        return recipe
    elif args.attack_from_file:
        if ":" in args.attack_from_file:
            attack_file, attack_name = args.attack_from_file.split(":")
        else:
            attack_file, attack_name = args.attack_from_file, "attack"
        attack_file = attack_file.replace(".py", "").replace("/", ".")
        attack_module = importlib.import_module(attack_file)
        attack_func = getattr(attack_module, attack_name)
        return attack_func(model)
    else:
        goal_function = parse_goal_function_from_args(args, model)
        transformation = parse_transformation_from_args(args, model)
        constraints = parse_constraints_from_args(args)
        if ":" in args.search:
            search_name, params = args.search.split(":")
            if search_name not in SEARCH_CLASS_NAMES:
                raise ValueError(f"Error: unsupported search {search_name}")
            search_method = eval(f"{SEARCH_CLASS_NAMES[search_name]}({params})")
        elif args.search in SEARCH_CLASS_NAMES:
            search_method = eval(f"{SEARCH_CLASS_NAMES[args.search]}()")
        else:
            raise ValueError(f"Error: unsupported search {args.search}")
        return textattack.shared.Attack(
            goal_function, constraints, transformation, search_method
        )
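
# The three mutually exclusive ways to build an attack, in the precedence order above:
#   1. --recipe textfooler                     (a prebuilt recipe from RECIPE_NAMES;
#                                               key assumed for illustration)
#   2. --attack-from-file my_attacks.py:attack, where the (hypothetical) file defines
#          def attack(model):
#              return textattack.attack_recipes.TextFoolerJin2019(model)
#   3. separate --goal-function / --transformation / --constraints / --search pieces,
#      assembled into a textattack.shared.Attack.
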
def parse_model_from_args(args):
    if args.model_from_file:
        colored_model_name = textattack.shared.utils.color_text(
            args.model_from_file, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading model and tokenizer from file: {colored_model_name}"
        )
        if ":" in args.model_from_file:
            model_file, model_name, tokenizer_name = args.model_from_file.split(":")
        else:
            model_file, model_name, tokenizer_name = (
                args.model_from_file,
                "model",
                "tokenizer",
            )
        try:
            model_file = model_file.replace(".py", "").replace("/", ".")
            model_module = importlib.import_module(model_file)
        except Exception:
            raise ValueError(
                f"Failed to import model or tokenizer from file {args.model_from_file}"
            )
        try:
            model = getattr(model_module, model_name)
        except AttributeError:
            raise AttributeError(
                f"``{model_name}`` not found in module {args.model_from_file}"
            )
        try:
            tokenizer = getattr(model_module, tokenizer_name)
        except AttributeError:
            raise AttributeError(
                f"``{tokenizer_name}`` not found in module {args.model_from_file}"
            )
        model = model.to(textattack.shared.utils.device)
        setattr(model, "tokenizer", tokenizer)
    elif (args.model in HUGGINGFACE_DATASET_BY_MODEL) or args.model_from_huggingface:
        import transformers

        model_name = (
            HUGGINGFACE_DATASET_BY_MODEL[args.model][0]
            if (args.model in HUGGINGFACE_DATASET_BY_MODEL)
            else args.model_from_huggingface
        )
        if ":" in model_name:
            model_class, model_name = model_name.split(":")
            model_class = eval(f"transformers.{model_class}")
        else:
            model_class = transformers.AutoModelForSequenceClassification
        colored_model_name = textattack.shared.utils.color_text(
            model_name, color="blue", method="ansi"
        )
        textattack.shared.logger.info(
            f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
        )
        model = model_class.from_pretrained(model_name)
        model = model.to(textattack.shared.utils.device)
        try:
            tokenizer = textattack.models.tokenizers.AutoTokenizer(model_name)
        except OSError:
            textattack.shared.logger.warning(
                f"AutoTokenizer {model_name} not found. Defaulting to `bert-base-uncased`"
            )
            tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
        setattr(model, "tokenizer", tokenizer)
    else:
        if args.model in TEXTATTACK_DATASET_BY_MODEL:
            model_path, _ = TEXTATTACK_DATASET_BY_MODEL[args.model]
            model = textattack.shared.utils.load_textattack_model_from_path(
                args.model, model_path
            )
        else:
            raise ValueError(f"Error: unsupported TextAttack model {args.model}")
    return model
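
# A file passed via --model-from-file is expected to expose attributes named ``model``
# and ``tokenizer`` (or the names given as file.py:model_name:tokenizer_name). A minimal
# hypothetical example of such a file:
#
#     # my_model.py
#     import transformers
#     model = transformers.AutoModelForSequenceClassification.from_pretrained(
#         "bert-base-uncased"
#     )
#     tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
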
def parse_dataset_from_args(args):
    if args.dataset_from_file:
        textattack.shared.logger.info(
            f"Loading dataset from file: {args.dataset_from_file}"
        )
        if ":" in args.dataset_from_file:
            dataset_file, dataset_name = args.dataset_from_file.split(":")
        else:
            dataset_file, dataset_name = args.dataset_from_file, "dataset"
        try:
            dataset_file = dataset_file.replace(".py", "").replace("/", ".")
            dataset_module = importlib.import_module(dataset_file)
        except Exception:
            raise ValueError(
                f"Failed to import dataset from file {args.dataset_from_file}"
            )
        try:
            dataset = getattr(dataset_module, dataset_name)
        except AttributeError:
            raise AttributeError(
                f"``{dataset_name}`` not found in module {args.dataset_from_file}"
            )
    elif args.dataset_from_nlp:
        dataset_args = args.dataset_from_nlp
        if ":" in dataset_args:
            dataset_args = dataset_args.split(":")
        else:
            # Wrap a bare dataset name so the splat below passes it as one argument.
            dataset_args = (dataset_args,)
        dataset = textattack.datasets.HuggingFaceNLPDataset(
            *dataset_args, shuffle=args.shuffle
        )
        dataset.examples = dataset.examples[args.num_examples_offset :]
    else:
        raise ValueError("Must supply pretrained model or dataset")
    return dataset
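
# A file passed via --dataset-from-file must expose an iterable of (text, label) pairs
# named ``dataset`` (or the name given after a colon). A minimal hypothetical example:
#
#     # my_dataset.py
#     dataset = [("Example sentence one.", 1), ("Example sentence two.", 0)]
#
# --dataset-from-nlp instead takes a name like ``glue:sst2`` (subset assumed for
# illustration), split on ":" and forwarded to HuggingFaceNLPDataset.
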
def parse_logger_from_args(args):
    # Create logger
    attack_log_manager = textattack.loggers.AttackLogManager()

    # Set default output directory to `textattack/outputs`.
    if not args.out_dir:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        outputs_dir = os.path.join(
            current_dir, os.pardir, os.pardir, os.pardir, "outputs"
        )
        args.out_dir = os.path.normpath(outputs_dir)

    # Output file, named by millisecond timestamp.
    out_time = int(time.time() * 1000)
    outfile_name = "attack-{}.txt".format(out_time)
    attack_log_manager.add_output_file(os.path.join(args.out_dir, outfile_name))

    # CSV
    if args.enable_csv:
        outfile_name = "attack-{}.csv".format(out_time)
        color_method = None if args.enable_csv == "plain" else "file"
        csv_path = os.path.join(args.out_dir, outfile_name)
        attack_log_manager.add_output_csv(csv_path, color_method)
        print("Logging to CSV at path {}.".format(csv_path))

    # Visdom
    if args.enable_visdom:
        attack_log_manager.enable_visdom()

    # Weights & Biases
    if args.enable_wandb:
        attack_log_manager.enable_wandb()

    # Stdout
    if not args.disable_stdout:
        attack_log_manager.enable_stdout()
    return attack_log_manager
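
# Example (illustrative): running with --enable-csv plain writes
#     <out_dir>/attack-1592600000000.txt and <out_dir>/attack-1592600000000.csv
# (millisecond timestamp assumed); "plain" disables color codes in the CSV output.
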
def parse_checkpoint_from_args(args):
    file_name = os.path.basename(args.checkpoint_file)
    if file_name.lower() == "latest":
        dir_path = os.path.dirname(args.checkpoint_file)
        chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith(".ta.chkpt")]
        assert chkpt_file_names, "Checkpoint directory is empty"
        timestamps = [int(f.replace(".ta.chkpt", "")) for f in chkpt_file_names]
        latest_file = str(max(timestamps)) + ".ta.chkpt"
        checkpoint_path = os.path.join(dir_path, latest_file)
    else:
        checkpoint_path = args.checkpoint_file
    checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
    set_seed(checkpoint.args.random_seed)
    return checkpoint
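
# Example (illustrative): with --checkpoint-file checkpoints/latest and a directory
# holding 1592600000000.ta.chkpt and 1592700000000.ta.chkpt, the larger timestamp wins,
# so checkpoints/1592700000000.ta.chkpt is loaded and its saved random seed re-applied.
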
def default_checkpoint_dir():
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(
        current_dir, os.pardir, os.pardir, os.pardir, "checkpoints"
    )
    return os.path.normpath(checkpoints_dir)
def merge_checkpoint_args(saved_args, cmdline_args):
    """Merge arguments saved with a checkpoint with newly entered arguments."""
    args = copy.deepcopy(saved_args)
    # Newly entered arguments take precedence.
    args.checkpoint_resume = cmdline_args.checkpoint_resume
    args.parallel = cmdline_args.parallel
    # If set on the command line, replace the saved values.
    if cmdline_args.checkpoint_dir:
        args.checkpoint_dir = cmdline_args.checkpoint_dir
    if cmdline_args.checkpoint_interval:
        args.checkpoint_interval = cmdline_args.checkpoint_interval
    return args
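
# Example (illustrative): when resuming, passing --checkpoint-interval 500 overrides the
# interval stored in the checkpoint, while options left unspecified keep their saved values.
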