1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

prelimary working checkpoint for file_logger and csv_logger

This commit is contained in:
Jin Yong Yoo
2020-05-15 07:22:33 -04:00
parent 685d0a0e56
commit 740756640f
9 changed files with 263 additions and 16 deletions

View File

@@ -6,6 +6,7 @@ import sys
import textattack
import time
import torch
import pickle
RECIPE_NAMES = {
'alzantot': 'textattack.attack_recipes.Alzantot2018',
@@ -196,6 +197,15 @@ def get_args():
def str_to_int(s): return sum((ord(c) for c in s))
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
parser.add_argument('--checkpoint-resume', required=False, type=str,
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint. Overrides any oth')
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
help='A directory to save/load checkpoint files.')
parser.add_argument('--checkpoint-interval', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
attack_group = parser.add_mutually_exclusive_group(required=False)
@@ -210,6 +220,10 @@ def get_args():
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
args = parser.parse_args(command_line_args)
if args.checkpoint_interval and args.shuffle:
# Not allowed b/c we cannot recover order of shuffled data
raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')
set_seed(args.random_seed)
@@ -308,7 +322,7 @@ def parse_logger_from_args(args):# Create logger
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
args.out_dir = outputs_dir
args.out_dir = os.path.normpath(outputs_dir)
# Output file.
out_time = int(time.time()*1000) # Output file
@@ -335,3 +349,24 @@ def parse_logger_from_args(args):# Create logger
if not args.disable_stdout:
attack_log_manager.enable_stdout()
return attack_log_manager
def parse_checkpoint_from_args(args):
if args.checkpoint_resume:
if args.checkpoint_resume.lower() == 'latest':
chkpt_files = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
latest_file = max(chkpt_files)
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
else:
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_resume)
checkpoint = textattack.shared.CheckPoint.load(checkpoint_path)
else:
checkpoint = None
return checkpoint
def default_checkpoint_dir():
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
return os.path.normpath(checkpoints_dir)