mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
prelimary working checkpoint for file_logger and csv_logger
This commit is contained in:
119
textattack/shared/checkpoint.py
Normal file
119
textattack/shared/checkpoint.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
import pickle
|
||||
import datetime
|
||||
from textattack.shared import utils
|
||||
from textattack.attack_results import SuccessfulAttackResult, FailedAttackResult, SkippedAttackResult
|
||||
|
||||
class CheckPoint:
|
||||
""" An object that stores necessary information for saving and loading checkpoints
|
||||
|
||||
Args:
|
||||
time (float): epoch time representing when checkpoint was made
|
||||
args: command line arguments of the original attack
|
||||
log_manager (AttackLogManager)
|
||||
"""
|
||||
def __init__(self, time, args, log_manager):
|
||||
self.time = time
|
||||
self.args = args
|
||||
self.log_manager = log_manager
|
||||
|
||||
def __repr__(self):
|
||||
main_str = 'Checkpoint('
|
||||
lines = []
|
||||
lines.append(
|
||||
utils.add_indent(f'(Time): {self.datetime}', 2)
|
||||
)
|
||||
|
||||
args_lines = []
|
||||
for key in self.args.__dict__:
|
||||
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
||||
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
||||
|
||||
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
|
||||
|
||||
attack_logger_lines = []
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of attacks performed: {self.results_count}', 2
|
||||
))
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of successful attacks: {self.num_successful_attacks}', 2
|
||||
))
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of failed attacks: {self.num_failed_attacks}', 2
|
||||
))
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of skipped attacks: {self.num_skipped_attacks}', 2
|
||||
))
|
||||
attack_logger_str = utils.add_indent('\n' + '\n'.join(attack_logger_lines), 2)
|
||||
lines.append(utils.add_indent(f'(Previous attack summary): {attack_logger_str}', 2))
|
||||
|
||||
main_str += '\n ' + '\n '.join(lines) + '\n'
|
||||
main_str += ')'
|
||||
return main_str
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@property
|
||||
def results_count(self):
|
||||
""" Return number of attacks made so far """
|
||||
return len(self.log_manager.results)
|
||||
|
||||
@property
|
||||
def num_skipped_attacks(self):
|
||||
count = 0
|
||||
for r in self.log_manager.results:
|
||||
if isinstance(r, SkippedAttackResult):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def num_failed_attacks(self):
|
||||
count = 0
|
||||
for r in self.log_manager.results:
|
||||
if isinstance(r, FailedAttackResult):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def num_successful_attacks(self):
|
||||
count = 0
|
||||
for r in self.log_manager.results:
|
||||
if isinstance(r, SuccessfulAttackResult):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def num_remaining_attacks(self):
|
||||
if self.args.attack_n:
|
||||
non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
|
||||
count = self.args.num_examples - non_skipped_attacks
|
||||
else:
|
||||
count = self.args.num_examples - self.results_count
|
||||
return count
|
||||
|
||||
@property
|
||||
def dataset_offset(self):
|
||||
""" Calculate offset into the dataset to start from """
|
||||
# Original offset + # of results processed so far
|
||||
return self.args.num_examples_offset + self.results_count
|
||||
|
||||
@property
|
||||
def datetime(self):
|
||||
return datetime.datetime.fromtimestamp(self.time).strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
def save(self):
|
||||
file_name = "{}.ta.chkpt".format(int(self.time*1000))
|
||||
if not os.path.exists(self.args.checkpoint_dir):
|
||||
os.makedirs(self.args.checkpoint_dir)
|
||||
path = os.path.join(self.args.checkpoint_dir, file_name)
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
checkpoint = pickle.load(f)
|
||||
assert isinstance(checkpoint, CheckPoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
Reference in New Issue
Block a user