mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
prelimary working checkpoint for file_logger and csv_logger
This commit is contained in:
@@ -6,9 +6,12 @@ import textattack
|
||||
import time
|
||||
import tqdm
|
||||
import os
|
||||
import datetime
|
||||
|
||||
from .run_attack_args_helper import *
|
||||
|
||||
logger = textattack.shared.utils.get_logger()
|
||||
|
||||
def run(args):
|
||||
# Only use one GPU, if we have one.
|
||||
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
|
||||
@@ -19,6 +22,21 @@ def run(args):
|
||||
# Cache TensorFlow Hub models here, if not otherwise specified.
|
||||
if 'TFHUB_CACHE_DIR' not in os.environ:
|
||||
os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/.cache/tensorflow-hub')
|
||||
|
||||
if args.checkpoint_resume:
|
||||
# Override current args with checkpoint args
|
||||
resume_checkpoint = parse_checkpoint_from_args(args)
|
||||
args = resume_checkpoint.args
|
||||
num_examples_offset = resume_checkpoint.dataset_offset
|
||||
num_examples = resume_checkpoint.num_remaining_attacks
|
||||
checkpoint_resume = True
|
||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
||||
print(resume_checkpoint, '\n')
|
||||
|
||||
else:
|
||||
num_examples_offset = args.num_examples_offset
|
||||
num_examples = args.num_examples
|
||||
checkpoint_resume = False
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -27,7 +45,10 @@ def run(args):
|
||||
print(attack, '\n')
|
||||
|
||||
# Logger
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
if checkpoint_resume:
|
||||
attack_log_manager = resume_checkpoint.log_manager
|
||||
else:
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
load_time = time.time()
|
||||
print(f'Load time: {load_time - start_time}s')
|
||||
@@ -57,16 +78,21 @@ def run(args):
|
||||
else:
|
||||
# Not interactive? Use default dataset.
|
||||
if args.model in DATASET_BY_MODEL:
|
||||
data = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
|
||||
data = DATASET_BY_MODEL[args.model](offset=num_examples_offset)
|
||||
else:
|
||||
raise ValueError(f'Error: unsupported model {args.model}')
|
||||
|
||||
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
|
||||
if checkpoint_resume:
|
||||
num_results = resume_checkpoint.results_count
|
||||
num_failures = resume_checkpoint.num_failed_attacks
|
||||
num_successes = resume_checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
for result in attack.attack_dataset(data,
|
||||
num_examples=args.num_examples,
|
||||
num_examples=num_examples,
|
||||
shuffle=args.shuffle,
|
||||
attack_n=args.attack_n):
|
||||
attack_log_manager.log_result(result)
|
||||
@@ -80,6 +106,16 @@ def run(args):
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
pbar.set_description('[Succeeded / Failed / Total] {} / {} / {}'.format(num_successes, num_failures, num_results))
|
||||
|
||||
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
|
||||
chkpt_time = time.time()
|
||||
date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S')
|
||||
print('\n\n' + '=' * 100)
|
||||
logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results))
|
||||
print('=' * 100 + '\n')
|
||||
checkpoint = textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager)
|
||||
checkpoint.save()
|
||||
|
||||
pbar.close()
|
||||
print()
|
||||
# Enable summary stdout
|
||||
|
||||
Reference in New Issue
Block a user