1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

address requested changes

This commit is contained in:
Jin Yong Yoo
2020-05-18 04:44:22 -04:00
parent 9d5ff072d9
commit c63add8dc2
7 changed files with 94 additions and 77 deletions

View File

@@ -45,10 +45,9 @@ def run(args):
if args.checkpoint_resume:
# Override current args with checkpoint args
resume_checkpoint = parse_checkpoint_from_args(args)
args = resume_checkpoint.args
args = merge_checkpoint_args(resume_checkpoint.args, args)
num_examples_offset = resume_checkpoint.dataset_offset
num_examples = resume_checkpoint.num_remaining_attack
checkpoint_resume = True
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
print(resume_checkpoint, '\n')
@@ -60,7 +59,7 @@ def run(args):
)
start_time = time.time()
if checkpoint_resume:
if args.checkpoint_resume:
attack_log_manager = resume_checkpoint.log_manager
else:
attack_log_manager = parse_logger_from_args(args)
@@ -88,7 +87,7 @@ def run(args):
(args, in_queue, out_queue)
)
# Log results asynchronously and update progress bar.
if checkpoint_resume:
if args.checkpoint_resume:
num_results = resume_checkpoint.results_count
num_failures = resume_checkpoint.num_failed_attacks
num_successes = resume_checkpoint.num_successful_attacks
@@ -115,13 +114,9 @@ def run(args):
in_queue.put((label, text))
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
chkpt_time = time.time()
date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S')
print('\n\n' + '=' * 100)
logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results))
print('=' * 100 + '\n')
checkpoint = textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager)
checkpoint = textattack.shared.Checkpoint(chkpt_time, args, attack_log_manager)
checkpoint.save()
attack_log_manager.flush()
pbar.close()
print()