1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

changes requested by eli

This commit is contained in:
Jin Yong Yoo
2020-05-31 03:59:10 -04:00
parent 18ffe94b32
commit f73509a942
4 changed files with 32 additions and 18 deletions

View File

@@ -32,10 +32,20 @@ class Checkpoint:
) )
args_lines = [] args_lines = []
for key in self.args.__dict__: recipe_set = True if 'recipe' in self.args.__dict__ and self.args.__dict__['recipe'] else False
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2)) mutually_exclusive_args = ['search', 'transformation', 'constraints', 'recipe']
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2) if recipe_set:
args_lines.append(utils.add_indent(f'(recipe): {self.args.__dict__["recipe"]}', 2))
else:
args_lines.append(utils.add_indent(f'(search): {self.args.__dict__["search"]}', 2))
args_lines.append(utils.add_indent(f'(transformation): {self.args.__dict__["transformation"]}', 2))
args_lines.append(utils.add_indent(f'(constraints): {self.args.__dict__["constraints"]}', 2))
for key in self.args.__dict__:
if key not in mutually_exclusive_args:
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
lines.append(utils.add_indent(f'(Args): {args_str}', 2)) lines.append(utils.add_indent(f'(Args): {args_str}', 2))
attack_logger_lines = [] attack_logger_lines = []

View File

@@ -206,10 +206,10 @@ def get_args():
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK')) parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(), parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
help='A directory to save/load checkpoint files.') help='The directory to save checkpoint files.')
parser.add_argument('--checkpoint-interval', required=False, type=int, parser.add_argument('--checkpoint-interval', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.') help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
attack_group = parser.add_mutually_exclusive_group(required=False) attack_group = parser.add_mutually_exclusive_group(required=False)
@@ -226,14 +226,15 @@ def get_args():
resume_parser = argparse.ArgumentParser( resume_parser = argparse.ArgumentParser(
description='A commandline parser for TextAttack', description='A commandline parser for TextAttack',
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest', resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=True,
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.') help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'\
'recover latest checkpoint from either current path or specified directory.')
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(), resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=None,
help='A directory to save/load checkpoint files.') help='The directory to save checkpoint files. If not set, use directory from recovered arguments.')
resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int, resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.') help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
resume_parser.add_argument('--parallel', action='store_true', default=False, resume_parser.add_argument('--parallel', action='store_true', default=False,
help='Run attack using multiple GPUs.') help='Run attack using multiple GPUs.')
@@ -310,7 +311,7 @@ def parse_recipe_from_args(model, args):
elif args.recipe in RECIPE_NAMES: elif args.recipe in RECIPE_NAMES:
recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)') recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)')
else: else:
raise ValueError('Invalid recipe {args.recipe}') raise ValueError(f'Invalid recipe {args.recipe}')
return recipe return recipe
def parse_goal_function_and_attack_from_args(args): def parse_goal_function_and_attack_from_args(args):
@@ -377,14 +378,16 @@ def parse_logger_from_args(args):# Create logger
return attack_log_manager return attack_log_manager
def parse_checkpoint_from_args(args): def parse_checkpoint_from_args(args):
if args.checkpoint_file.lower() == 'latest': file_name = os.path.basename(args.checkpoint_file)
chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')] if file_name.lower() == 'latest':
dir_path = os.path.dirname(args.checkpoint_file)
chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith('.ta.chkpt')]
assert chkpt_file_names, "Checkpoint directory is empty" assert chkpt_file_names, "Checkpoint directory is empty"
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names] timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
latest_file = str(max(timestamps)) + '.ta.chkpt' latest_file = str(max(timestamps)) + '.ta.chkpt'
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file) checkpoint_path = os.path.join(dir_path, latest_file)
else: else:
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file) checkpoint_path = args.checkpoint_file
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path) checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
set_seed(checkpoint.args.random_seed) set_seed(checkpoint.args.random_seed)
@@ -402,8 +405,9 @@ def merge_checkpoint_args(saved_args, cmdline_args):
# Newly entered arguments take precedence # Newly entered arguments take precedence
args.checkpoint_resume = cmdline_args.checkpoint_resume args.checkpoint_resume = cmdline_args.checkpoint_resume
args.parallel = cmdline_args.parallel args.parallel = cmdline_args.parallel
args.checkpoint_dir = cmdline_args.checkpoint_dir # If set, replace
# If set, we replace if cmdline_args.checkpoint_dir:
args.checkpoint_dir = cmdline_args.checkpoint_dir
if cmdline_args.checkpoint_interval: if cmdline_args.checkpoint_interval:
args.checkpoint_interval = cmdline_args.checkpoint_interval args.checkpoint_interval = cmdline_args.checkpoint_interval

View File

@@ -51,7 +51,7 @@ def run(args):
num_examples_offset = resume_checkpoint.dataset_offset num_examples_offset = resume_checkpoint.dataset_offset
num_remaining_examples = resume_checkpoint.num_remaining_attacks num_remaining_examples = resume_checkpoint.num_remaining_attacks
num_total_examples = args.num_examples num_total_examples = args.num_examples
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
print(resume_checkpoint, '\n') print(resume_checkpoint, '\n')
else: else:
num_examples_offset = args.num_examples_offset num_examples_offset = args.num_examples_offset

View File

@@ -29,7 +29,7 @@ def run(args):
args = merge_checkpoint_args(resume_checkpoint.args, args) args = merge_checkpoint_args(resume_checkpoint.args, args)
num_examples_offset = resume_checkpoint.dataset_offset num_examples_offset = resume_checkpoint.dataset_offset
num_examples = resume_checkpoint.num_remaining_attacks num_examples = resume_checkpoint.num_remaining_attacks
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
print(resume_checkpoint, '\n') print(resume_checkpoint, '\n')
else: else:
num_examples_offset = args.num_examples_offset num_examples_offset = args.num_examples_offset