mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
changes requested by eli
This commit is contained in:
@@ -32,10 +32,20 @@ class Checkpoint:
|
|||||||
)
|
)
|
||||||
|
|
||||||
args_lines = []
|
args_lines = []
|
||||||
for key in self.args.__dict__:
|
recipe_set = True if 'recipe' in self.args.__dict__ and self.args.__dict__['recipe'] else False
|
||||||
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
mutually_exclusive_args = ['search', 'transformation', 'constraints', 'recipe']
|
||||||
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
if recipe_set:
|
||||||
|
args_lines.append(utils.add_indent(f'(recipe): {self.args.__dict__["recipe"]}', 2))
|
||||||
|
else:
|
||||||
|
args_lines.append(utils.add_indent(f'(search): {self.args.__dict__["search"]}', 2))
|
||||||
|
args_lines.append(utils.add_indent(f'(transformation): {self.args.__dict__["transformation"]}', 2))
|
||||||
|
args_lines.append(utils.add_indent(f'(constraints): {self.args.__dict__["constraints"]}', 2))
|
||||||
|
|
||||||
|
for key in self.args.__dict__:
|
||||||
|
if key not in mutually_exclusive_args:
|
||||||
|
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
||||||
|
|
||||||
|
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
||||||
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
|
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
|
||||||
|
|
||||||
attack_logger_lines = []
|
attack_logger_lines = []
|
||||||
|
|||||||
@@ -206,10 +206,10 @@ def get_args():
|
|||||||
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
||||||
|
|
||||||
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
|
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
|
||||||
help='A directory to save/load checkpoint files.')
|
help='The directory to save checkpoint files.')
|
||||||
|
|
||||||
parser.add_argument('--checkpoint-interval', required=False, type=int,
|
parser.add_argument('--checkpoint-interval', required=False, type=int,
|
||||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
|
||||||
|
|
||||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||||
|
|
||||||
@@ -226,14 +226,15 @@ def get_args():
|
|||||||
resume_parser = argparse.ArgumentParser(
|
resume_parser = argparse.ArgumentParser(
|
||||||
description='A commandline parser for TextAttack',
|
description='A commandline parser for TextAttack',
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest',
|
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=True,
|
||||||
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.')
|
help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered,'\
|
||||||
|
'recover latest checkpoint from either current path or specified directory.')
|
||||||
|
|
||||||
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(),
|
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=None,
|
||||||
help='A directory to save/load checkpoint files.')
|
help='The directory to save checkpoint files. If not set, use directory from recovered arguments.')
|
||||||
|
|
||||||
resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
|
resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
|
||||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
help='If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.')
|
||||||
|
|
||||||
resume_parser.add_argument('--parallel', action='store_true', default=False,
|
resume_parser.add_argument('--parallel', action='store_true', default=False,
|
||||||
help='Run attack using multiple GPUs.')
|
help='Run attack using multiple GPUs.')
|
||||||
@@ -310,7 +311,7 @@ def parse_recipe_from_args(model, args):
|
|||||||
elif args.recipe in RECIPE_NAMES:
|
elif args.recipe in RECIPE_NAMES:
|
||||||
recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)')
|
recipe = eval(f'{RECIPE_NAMES[args.recipe]}(model)')
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid recipe {args.recipe}')
|
raise ValueError(f'Invalid recipe {args.recipe}')
|
||||||
return recipe
|
return recipe
|
||||||
|
|
||||||
def parse_goal_function_and_attack_from_args(args):
|
def parse_goal_function_and_attack_from_args(args):
|
||||||
@@ -377,14 +378,16 @@ def parse_logger_from_args(args):# Create logger
|
|||||||
return attack_log_manager
|
return attack_log_manager
|
||||||
|
|
||||||
def parse_checkpoint_from_args(args):
|
def parse_checkpoint_from_args(args):
|
||||||
if args.checkpoint_file.lower() == 'latest':
|
file_name = os.path.basename(args.checkpoint_file)
|
||||||
chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
|
if file_name.lower() == 'latest':
|
||||||
|
dir_path = os.path.dirname(args.checkpoint_file)
|
||||||
|
chkpt_file_names = [f for f in os.listdir(dir_path) if f.endswith('.ta.chkpt')]
|
||||||
assert chkpt_file_names, "Checkpoint directory is empty"
|
assert chkpt_file_names, "Checkpoint directory is empty"
|
||||||
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
|
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
|
||||||
latest_file = str(max(timestamps)) + '.ta.chkpt'
|
latest_file = str(max(timestamps)) + '.ta.chkpt'
|
||||||
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
|
checkpoint_path = os.path.join(dir_path, latest_file)
|
||||||
else:
|
else:
|
||||||
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file)
|
checkpoint_path = args.checkpoint_file
|
||||||
|
|
||||||
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
|
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
|
||||||
set_seed(checkpoint.args.random_seed)
|
set_seed(checkpoint.args.random_seed)
|
||||||
@@ -402,8 +405,9 @@ def merge_checkpoint_args(saved_args, cmdline_args):
|
|||||||
# Newly entered arguments take precedence
|
# Newly entered arguments take precedence
|
||||||
args.checkpoint_resume = cmdline_args.checkpoint_resume
|
args.checkpoint_resume = cmdline_args.checkpoint_resume
|
||||||
args.parallel = cmdline_args.parallel
|
args.parallel = cmdline_args.parallel
|
||||||
args.checkpoint_dir = cmdline_args.checkpoint_dir
|
# If set, replace
|
||||||
# If set, we replace
|
if cmdline_args.checkpoint_dir:
|
||||||
|
args.checkpoint_dir = cmdline_args.checkpoint_dir
|
||||||
if cmdline_args.checkpoint_interval:
|
if cmdline_args.checkpoint_interval:
|
||||||
args.checkpoint_interval = cmdline_args.checkpoint_interval
|
args.checkpoint_interval = cmdline_args.checkpoint_interval
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def run(args):
|
|||||||
num_examples_offset = resume_checkpoint.dataset_offset
|
num_examples_offset = resume_checkpoint.dataset_offset
|
||||||
num_remaining_examples = resume_checkpoint.num_remaining_attacks
|
num_remaining_examples = resume_checkpoint.num_remaining_attacks
|
||||||
num_total_examples = args.num_examples
|
num_total_examples = args.num_examples
|
||||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
|
||||||
print(resume_checkpoint, '\n')
|
print(resume_checkpoint, '\n')
|
||||||
else:
|
else:
|
||||||
num_examples_offset = args.num_examples_offset
|
num_examples_offset = args.num_examples_offset
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def run(args):
|
|||||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||||
num_examples_offset = resume_checkpoint.dataset_offset
|
num_examples_offset = resume_checkpoint.dataset_offset
|
||||||
num_examples = resume_checkpoint.num_remaining_attacks
|
num_examples = resume_checkpoint.num_remaining_attacks
|
||||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
logger.info('Recovered from checkpoint previously saved at {}'.format(resume_checkpoint.datetime))
|
||||||
print(resume_checkpoint, '\n')
|
print(resume_checkpoint, '\n')
|
||||||
else:
|
else:
|
||||||
num_examples_offset = args.num_examples_offset
|
num_examples_offset = args.num_examples_offset
|
||||||
|
|||||||
Reference in New Issue
Block a user