mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
prelimary working checkpoint for file_logger and csv_logger
This commit is contained in:
@@ -6,6 +6,7 @@ import sys
|
||||
import textattack
|
||||
import time
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
RECIPE_NAMES = {
|
||||
'alzantot': 'textattack.attack_recipes.Alzantot2018',
|
||||
@@ -196,6 +197,15 @@ def get_args():
|
||||
|
||||
def str_to_int(s): return sum((ord(c) for c in s))
|
||||
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
||||
|
||||
parser.add_argument('--checkpoint-resume', required=False, type=str,
|
||||
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint. Overrides any oth')
|
||||
|
||||
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
|
||||
help='A directory to save/load checkpoint files.')
|
||||
|
||||
parser.add_argument('--checkpoint-interval', required=False, type=int,
|
||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
|
||||
@@ -210,6 +220,10 @@ def get_args():
|
||||
|
||||
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
|
||||
args = parser.parse_args(command_line_args)
|
||||
|
||||
if args.checkpoint_interval and args.shuffle:
|
||||
# Not allowed b/c we cannot recover order of shuffled data
|
||||
raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')
|
||||
|
||||
set_seed(args.random_seed)
|
||||
|
||||
@@ -308,7 +322,7 @@ def parse_logger_from_args(args):# Create logger
|
||||
if not args.out_dir:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
|
||||
args.out_dir = outputs_dir
|
||||
args.out_dir = os.path.normpath(outputs_dir)
|
||||
|
||||
# Output file.
|
||||
out_time = int(time.time()*1000) # Output file
|
||||
@@ -335,3 +349,24 @@ def parse_logger_from_args(args):# Create logger
|
||||
if not args.disable_stdout:
|
||||
attack_log_manager.enable_stdout()
|
||||
return attack_log_manager
|
||||
|
||||
def parse_checkpoint_from_args(args):
|
||||
if args.checkpoint_resume:
|
||||
if args.checkpoint_resume.lower() == 'latest':
|
||||
chkpt_files = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
|
||||
latest_file = max(chkpt_files)
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
|
||||
else:
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_resume)
|
||||
|
||||
checkpoint = textattack.shared.CheckPoint.load(checkpoint_path)
|
||||
else:
|
||||
checkpoint = None
|
||||
|
||||
return checkpoint
|
||||
|
||||
def default_checkpoint_dir():
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
|
||||
return os.path.normpath(checkpoints_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user