diff --git a/textattack/shared/checkpoint.py b/textattack/shared/checkpoint.py index 0a0be0c5..ee1d731e 100644 --- a/textattack/shared/checkpoint.py +++ b/textattack/shared/checkpoint.py @@ -32,10 +32,20 @@ class Checkpoint: ) args_lines = [] + recipe_set = True if 'recipe' in self.args.__dict__ and self.args.__dict__['recipe'] else False + mutually_exclusive_args = ['search', 'transformation', 'constraints', 'recipe'] + if recipe_set: + args_lines.append(utils.add_indent(f'(recipe): {self.args.__dict__["recipe"]}', 2)) + else: + args_lines.append(utils.add_indent(f'(search): {self.args.__dict__["search"]}', 2)) + args_lines.append(utils.add_indent(f'(transformation): {self.args.__dict__["transformation"]}', 2)) + args_lines.append(utils.add_indent(f'(constraints): {self.args.__dict__["constraints"]}', 2)) + for key in self.args.__dict__: - args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2)) + if key not in mutually_exclusive_args: + args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2)) + args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2) - lines.append(utils.add_indent(f'(Args): {args_str}', 2)) attack_logger_lines = [] diff --git a/textattack/shared/scripts/run_attack_args_helper.py b/textattack/shared/scripts/run_attack_args_helper.py index 46bf3ca7..50855094 100644 --- a/textattack/shared/scripts/run_attack_args_helper.py +++ b/textattack/shared/scripts/run_attack_args_helper.py @@ -206,10 +206,10 @@ def get_args(): parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK')) parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(), - help='A directory to save/load checkpoint files.') + help='The directory to save checkpoint files.') parser.add_argument('--checkpoint-interval', required=False, type=int, - help='Interval for saving checkpoints. If not set, no checkpoints will be saved.') + help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.') attack_group = parser.add_mutually_exclusive_group(required=False) @@ -226,14 +226,15 @@ def get_args(): resume_parser = argparse.ArgumentParser( description='A commandline parser for TextAttack', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest', - help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.') + resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=True, + help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'\ + 'recover latest checkpoint from either current path or specified directory.') - resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(), - help='A directory to save/load checkpoint files.') + resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=None, + help='The directory to save checkpoint files. If not set, use directory from recovered arguments.') resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int, - help='Interval for saving checkpoints. If not set, no checkpoints will be saved.') + help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.') resume_parser.add_argument('--parallel', action='store_true', default=False, help='Run attack using multiple GPUs.') @@ -310,7 +311,7 @@ def parse_recipe_from_args(model, args): elif args.recipe in RECIPE_NAMES: recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)') else: - raise ValueError('Invalid recipe {args.recipe}') + raise ValueError(f'Invalid recipe {args.recipe}') return recipe def parse_goal_function_and_attack_from_args(args): @@ -377,14 +378,16 @@ def parse_logger_from_args(args):# Create logger return attack_log_manager def parse_checkpoint_from_args(args): - if args.checkpoint_file.lower() == 'latest': - chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')] + file_name = os.path.basename(args.checkpoint_file) + if file_name.lower() == 'latest': + dir_path = os.path.dirname(args.checkpoint_file) + chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith('.ta.chkpt')] assert chkpt_file_names, "Checkpoint directory is empty" timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names] latest_file = str(max(timestamps)) + '.ta.chkpt' - checkpoint_path = os.path.join(args.checkpoint_dir, latest_file) + checkpoint_path = os.path.join(dir_path, latest_file) else: - checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file) + checkpoint_path = args.checkpoint_file checkpoint = textattack.shared.Checkpoint.load(checkpoint_path) set_seed(checkpoint.args.random_seed) @@ -402,8 +405,9 @@ def merge_checkpoint_args(saved_args, cmdline_args): # Newly entered arguments take precedence args.checkpoint_resume = cmdline_args.checkpoint_resume args.parallel = cmdline_args.parallel - args.checkpoint_dir = cmdline_args.checkpoint_dir - # If set, we replace + # If set, replace + if cmdline_args.checkpoint_dir: + args.checkpoint_dir = cmdline_args.checkpoint_dir if cmdline_args.checkpoint_interval: args.checkpoint_interval = cmdline_args.checkpoint_interval diff --git a/textattack/shared/scripts/run_attack_parallel.py b/textattack/shared/scripts/run_attack_parallel.py index 9a3e68c6..4f78a8f0 100644 --- a/textattack/shared/scripts/run_attack_parallel.py +++ b/textattack/shared/scripts/run_attack_parallel.py @@ -51,7 +51,7 @@ def run(args): num_examples_offset = resume_checkpoint.dataset_offset num_remaining_examples = resume_checkpoint.num_remaining_attacks num_total_examples = args.num_examples - logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) + logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime)) print(resume_checkpoint, '\n') else: num_examples_offset = args.num_examples_offset diff --git a/textattack/shared/scripts/run_attack_single_threaded.py b/textattack/shared/scripts/run_attack_single_threaded.py index c052e8da..4f864d35 100644 --- a/textattack/shared/scripts/run_attack_single_threaded.py +++ b/textattack/shared/scripts/run_attack_single_threaded.py @@ -29,7 +29,7 @@ def run(args): args = merge_checkpoint_args(resume_checkpoint.args, args) num_examples_offset = resume_checkpoint.dataset_offset num_examples = resume_checkpoint.num_remaining_attacks - logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) + logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime)) print(resume_checkpoint, '\n') else: num_examples_offset = args.num_examples_offset