mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
changes requested by eli
This commit is contained in:
@@ -32,10 +32,20 @@ class Checkpoint:
|
||||
)
|
||||
|
||||
args_lines = []
|
||||
for key in self.args.__dict__:
|
||||
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
||||
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
||||
recipe_set = True if 'recipe' in self.args.__dict__ and self.args.__dict__['recipe'] else False
|
||||
mutually_exclusive_args = ['search', 'transformation', 'constraints', 'recipe']
|
||||
if recipe_set:
|
||||
args_lines.append(utils.add_indent(f'(recipe): {self.args.__dict__["recipe"]}', 2))
|
||||
else:
|
||||
args_lines.append(utils.add_indent(f'(search): {self.args.__dict__["search"]}', 2))
|
||||
args_lines.append(utils.add_indent(f'(transformation): {self.args.__dict__["transformation"]}', 2))
|
||||
args_lines.append(utils.add_indent(f'(constraints): {self.args.__dict__["constraints"]}', 2))
|
||||
|
||||
for key in self.args.__dict__:
|
||||
if key not in mutually_exclusive_args:
|
||||
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
||||
|
||||
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
||||
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
|
||||
|
||||
attack_logger_lines = []
|
||||
|
||||
@@ -206,10 +206,10 @@ def get_args():
|
||||
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
||||
|
||||
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
|
||||
help='A directory to save/load checkpoint files.')
|
||||
help='The directory to save checkpoint files.')
|
||||
|
||||
parser.add_argument('--checkpoint-interval', required=False, type=int,
|
||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
||||
help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
|
||||
@@ -226,14 +226,15 @@ def get_args():
|
||||
resume_parser = argparse.ArgumentParser(
|
||||
description='A commandline parser for TextAttack',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest',
|
||||
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.')
|
||||
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=True,
|
||||
help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'\
|
||||
'recover latest checkpoint from either current path or specified directory.')
|
||||
|
||||
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(),
|
||||
help='A directory to save/load checkpoint files.')
|
||||
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=None,
|
||||
help='The directory to save checkpoint files. If not set, use directory from recovered arguments.')
|
||||
|
||||
resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
|
||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
||||
help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
|
||||
|
||||
resume_parser.add_argument('--parallel', action='store_true', default=False,
|
||||
help='Run attack using multiple GPUs.')
|
||||
@@ -310,7 +311,7 @@ def parse_recipe_from_args(model, args):
|
||||
elif args.recipe in RECIPE_NAMES:
|
||||
recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)')
|
||||
else:
|
||||
raise ValueError('Invalid recipe {args.recipe}')
|
||||
raise ValueError(f'Invalid recipe {args.recipe}')
|
||||
return recipe
|
||||
|
||||
def parse_goal_function_and_attack_from_args(args):
|
||||
@@ -377,14 +378,16 @@ def parse_logger_from_args(args):# Create logger
|
||||
return attack_log_manager
|
||||
|
||||
def parse_checkpoint_from_args(args):
|
||||
if args.checkpoint_file.lower() == 'latest':
|
||||
chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
|
||||
file_name = os.path.basename(args.checkpoint_file)
|
||||
if file_name.lower() == 'latest':
|
||||
dir_path = os.path.dirname(args.checkpoint_file)
|
||||
chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith('.ta.chkpt')]
|
||||
assert chkpt_file_names, "Checkpoint directory is empty"
|
||||
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
|
||||
latest_file = str(max(timestamps)) + '.ta.chkpt'
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
|
||||
checkpoint_path = os.path.join(dir_path, latest_file)
|
||||
else:
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file)
|
||||
checkpoint_path = args.checkpoint_file
|
||||
|
||||
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
|
||||
set_seed(checkpoint.args.random_seed)
|
||||
@@ -402,8 +405,9 @@ def merge_checkpoint_args(saved_args, cmdline_args):
|
||||
# Newly entered arguments take precedence
|
||||
args.checkpoint_resume = cmdline_args.checkpoint_resume
|
||||
args.parallel = cmdline_args.parallel
|
||||
args.checkpoint_dir = cmdline_args.checkpoint_dir
|
||||
# If set, we replace
|
||||
# If set, replace
|
||||
if cmdline_args.checkpoint_dir:
|
||||
args.checkpoint_dir = cmdline_args.checkpoint_dir
|
||||
if cmdline_args.checkpoint_interval:
|
||||
args.checkpoint_interval = cmdline_args.checkpoint_interval
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def run(args):
|
||||
num_examples_offset = resume_checkpoint.dataset_offset
|
||||
num_remaining_examples = resume_checkpoint.num_remaining_attacks
|
||||
num_total_examples = args.num_examples
|
||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
||||
logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
|
||||
print(resume_checkpoint, '\n')
|
||||
else:
|
||||
num_examples_offset = args.num_examples_offset
|
||||
|
||||
@@ -29,7 +29,7 @@ def run(args):
|
||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||
num_examples_offset = resume_checkpoint.dataset_offset
|
||||
num_examples = resume_checkpoint.num_remaining_attacks
|
||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
||||
logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
|
||||
print(resume_checkpoint, '\n')
|
||||
else:
|
||||
num_examples_offset = args.num_examples_offset
|
||||
|
||||
Reference in New Issue
Block a user