1
0
mirror of https://github.com/QData/TextAttack.git synced 2021-10-13 00:05:06 +03:00

prelimary working checkpoint for file_logger and csv_logger

This commit is contained in:
Jin Yong Yoo
2020-05-15 07:22:33 -04:00
parent 685d0a0e56
commit 740756640f
9 changed files with 263 additions and 16 deletions

3
.gitignore vendored
View File

@@ -34,3 +34,6 @@ dist/
# Weights & Biases outputs
wandb/
# checkpoints
checkpoints/

View File

@@ -18,6 +18,9 @@ class GoalFunctionResult:
if isinstance(self.score, torch.Tensor):
self.score = self.score.item()
if isinstance(self.succeeded, torch.Tensor):
self.succeeded = self.succeeded.item()
def get_text_color_input(self):
""" A string representing the color this result's changed
portion should be if it represents the original input.

View File

@@ -1,5 +1,6 @@
import os
import sys
import copy
import terminaltables
from .logger import Logger
@@ -7,6 +8,7 @@ from .logger import Logger
class FileLogger(Logger):
def __init__(self, filename='', stdout=False):
self.stdout = stdout
self.filename = filename
if stdout:
self.fout = sys.stdout
elif isinstance(filename, str):
@@ -18,6 +20,21 @@ class FileLogger(Logger):
self.fout = filename
self.num_results = 0
def __getstate__(self):
# Temporarily save file handle b/c we can't copy it
tmp = self.fout
self.fout = None
state = copy.deepcopy(self.__dict__)
self.fout = tmp
return state
def __setstate__(self, state):
self.__dict__ = state
if self.stdout:
self.fout = sys.stdout
else:
self.fout = open(self.filename, 'a')
def log_attack_result(self, result):
self.num_results += 1
color_method = 'stdout' if self.stdout else 'file'

View File

@@ -140,6 +140,10 @@ class Attack:
if shuffle:
random.shuffle(dataset.examples)
if num_examples <= 0:
return
yield
for text, ground_truth_output in dataset:
tokenized_text = TokenizedText(text, self.tokenizer)
goal_function_result = self.goal_function.get_result(tokenized_text, ground_truth_output)

View File

@@ -5,3 +5,4 @@ from . import validators
from .tokenized_text import TokenizedText
from .word_embedding import WordEmbedding
from .checkpoint import CheckPoint

View File

@@ -0,0 +1,119 @@
import os
import pickle
import datetime
from textattack.shared import utils
from textattack.attack_results import SuccessfulAttackResult, FailedAttackResult, SkippedAttackResult
class CheckPoint:
""" An object that stores necessary information for saving and loading checkpoints
Args:
time (float): epoch time representing when checkpoint was made
args: command line arguments of the original attack
log_manager (AttackLogManager)
"""
def __init__(self, time, args, log_manager):
self.time = time
self.args = args
self.log_manager = log_manager
def __repr__(self):
main_str = 'Checkpoint('
lines = []
lines.append(
utils.add_indent(f'(Time): {self.datetime}', 2)
)
args_lines = []
for key in self.args.__dict__:
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
attack_logger_lines = []
attack_logger_lines.append(utils.add_indent(
f'(Number of attacks performed: {self.results_count}', 2
))
attack_logger_lines.append(utils.add_indent(
f'(Number of successful attacks: {self.num_successful_attacks}', 2
))
attack_logger_lines.append(utils.add_indent(
f'(Number of failed attacks: {self.num_failed_attacks}', 2
))
attack_logger_lines.append(utils.add_indent(
f'(Number of skipped attacks: {self.num_skipped_attacks}', 2
))
attack_logger_str = utils.add_indent('\n' + '\n'.join(attack_logger_lines), 2)
lines.append(utils.add_indent(f'(Previous attack summary): {attack_logger_str}', 2))
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
__str__ = __repr__
@property
def results_count(self):
""" Return number of attacks made so far """
return len(self.log_manager.results)
@property
def num_skipped_attacks(self):
count = 0
for r in self.log_manager.results:
if isinstance(r, SkippedAttackResult):
count += 1
return count
@property
def num_failed_attacks(self):
count = 0
for r in self.log_manager.results:
if isinstance(r, FailedAttackResult):
count += 1
return count
@property
def num_successful_attacks(self):
count = 0
for r in self.log_manager.results:
if isinstance(r, SuccessfulAttackResult):
count += 1
return count
@property
def num_remaining_attacks(self):
if self.args.attack_n:
non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
count = self.args.num_examples - non_skipped_attacks
else:
count = self.args.num_examples - self.results_count
return count
@property
def dataset_offset(self):
""" Calculate offset into the dataset to start from """
# Original offset + # of results processed so far
return self.args.num_examples_offset + self.results_count
@property
def datetime(self):
return datetime.datetime.fromtimestamp(self.time).strftime('%Y-%m-%d %H:%M:%S')
def save(self):
file_name = "{}.ta.chkpt".format(int(self.time*1000))
if not os.path.exists(self.args.checkpoint_dir):
os.makedirs(self.args.checkpoint_dir)
path = os.path.join(self.args.checkpoint_dir, file_name)
with open(path, 'wb') as f:
pickle.dump(self, f)
@classmethod
def load(self, path):
with open(path, 'rb') as f:
checkpoint = pickle.load(f)
assert isinstance(checkpoint, CheckPoint)
return checkpoint

View File

@@ -6,6 +6,7 @@ import sys
import textattack
import time
import torch
import pickle
RECIPE_NAMES = {
'alzantot': 'textattack.attack_recipes.Alzantot2018',
@@ -197,6 +198,15 @@ def get_args():
def str_to_int(s): return sum((ord(c) for c in s))
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
parser.add_argument('--checkpoint-resume', required=False, type=str,
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint. Overrides any oth')
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
help='A directory to save/load checkpoint files.')
parser.add_argument('--checkpoint-interval', required=False, type=int,
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
attack_group = parser.add_mutually_exclusive_group(required=False)
search_choices = ', '.join(SEARCH_CLASS_NAMES.keys())
@@ -211,6 +221,10 @@ def get_args():
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
args = parser.parse_args(command_line_args)
if args.checkpoint_interval and args.shuffle:
# Not allowed b/c we cannot recover order of shuffled data
raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')
set_seed(args.random_seed)
return args
@@ -308,7 +322,7 @@ def parse_logger_from_args(args):# Create logger
if not args.out_dir:
current_dir = os.path.dirname(os.path.realpath(__file__))
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
args.out_dir = outputs_dir
args.out_dir = os.path.normpath(outputs_dir)
# Output file.
out_time = int(time.time()*1000) # Output file
@@ -335,3 +349,24 @@ def parse_logger_from_args(args):# Create logger
if not args.disable_stdout:
attack_log_manager.enable_stdout()
return attack_log_manager
def parse_checkpoint_from_args(args):
if args.checkpoint_resume:
if args.checkpoint_resume.lower() == 'latest':
chkpt_files = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
latest_file = max(chkpt_files)
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
else:
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_resume)
checkpoint = textattack.shared.CheckPoint.load(checkpoint_path)
else:
checkpoint = None
return checkpoint
def default_checkpoint_dir():
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
return os.path.normpath(checkpoints_dir)

View File

@@ -41,6 +41,17 @@ def attack_from_queue(args, in_queue, out_queue):
def run(args):
pytorch_multiprocessing_workaround()
if args.checkpoint_resume:
# Override current args with checkpoint args
resume_checkpoint = parse_checkpoint_from_args(args)
args = resume_checkpoint.args
num_examples_offset = resume_checkpoint.dataset_offset
num_examples = resume_checkpoint.num_remaining_attack
checkpoint_resume = True
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
print(resume_checkpoint, '\n')
# This makes `args` a namespace that's sharable between processes.
# We could do the same thing with the model, but it's actually faster
# to let each thread have their own copy of the model.
@@ -49,11 +60,14 @@ def run(args):
)
start_time = time.time()
if checkpoint_resume:
attack_log_manager = resume_checkpoint.log_manager
else:
attack_log_manager = parse_logger_from_args(args)
# We reserve the first GPU for coordinating workers.
num_gpus = torch.cuda.device_count()
dataset = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
dataset = DATASET_BY_MODEL[args.model](offset=num_examples_offset)
print(f'Running on {num_gpus} GPUs')
load_time = time.time()
@@ -64,7 +78,7 @@ def run(args):
in_queue = torch.multiprocessing.Queue()
out_queue = torch.multiprocessing.Queue()
# Add stuff to queue.
for _ in range(args.num_examples):
for _ in range(num_examples):
label, text = next(dataset)
in_queue.put((label, text))
# Start workers.
@@ -74,11 +88,16 @@ def run(args):
(args, in_queue, out_queue)
)
# Log results asynchronously and update progress bar.
if checkpoint_resume:
num_results = resume_checkpoint.results_count
num_failures = resume_checkpoint.num_failed_attacks
num_successes = resume_checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_successes = 0
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
while num_results < args.num_examples:
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
while num_results < num_examples:
result = out_queue.get(block=True)
if isinstance(result, Exception):
raise result
@@ -94,6 +113,16 @@ def run(args):
else:
label, text = next(dataset)
in_queue.put((label, text))
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
chkpt_time = time.time()
date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S')
print('\n\n' + '=' * 100)
logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results))
print('=' * 100 + '\n')
checkpoint = textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager)
checkpoint.save()
pbar.close()
print()
# Enable summary stdout.

View File

@@ -6,9 +6,12 @@ import textattack
import time
import tqdm
import os
import datetime
from .run_attack_args_helper import *
logger = textattack.shared.utils.get_logger()
def run(args):
# Only use one GPU, if we have one.
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
@@ -20,6 +23,21 @@ def run(args):
if 'TFHUB_CACHE_DIR' not in os.environ:
os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/.cache/tensorflow-hub')
if args.checkpoint_resume:
# Override current args with checkpoint args
resume_checkpoint = parse_checkpoint_from_args(args)
args = resume_checkpoint.args
num_examples_offset = resume_checkpoint.dataset_offset
num_examples = resume_checkpoint.num_remaining_attacks
checkpoint_resume = True
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
print(resume_checkpoint, '\n')
else:
num_examples_offset = args.num_examples_offset
num_examples = args.num_examples
checkpoint_resume = False
start_time = time.time()
# Models and Attack
@@ -27,6 +45,9 @@ def run(args):
print(attack, '\n')
# Logger
if checkpoint_resume:
attack_log_manager = resume_checkpoint.log_manager
else:
attack_log_manager = parse_logger_from_args(args)
load_time = time.time()
@@ -57,16 +78,21 @@ def run(args):
else:
# Not interactive? Use default dataset.
if args.model in DATASET_BY_MODEL:
data = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
data = DATASET_BY_MODEL[args.model](offset=num_examples_offset)
else:
raise ValueError(f'Error: unsupported model {args.model}')
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
if checkpoint_resume:
num_results = resume_checkpoint.results_count
num_failures = resume_checkpoint.num_failed_attacks
num_successes = resume_checkpoint.num_successful_attacks
else:
num_results = 0
num_failures = 0
num_successes = 0
for result in attack.attack_dataset(data,
num_examples=args.num_examples,
num_examples=num_examples,
shuffle=args.shuffle,
attack_n=args.attack_n):
attack_log_manager.log_result(result)
@@ -80,6 +106,16 @@ def run(args):
if type(result) == textattack.attack_results.FailedAttackResult:
num_failures += 1
pbar.set_description('[Succeeded / Failed / Total] {} / {} / {}'.format(num_successes, num_failures, num_results))
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
chkpt_time = time.time()
date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S')
print('\n\n' + '=' * 100)
logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results))
print('=' * 100 + '\n')
checkpoint = textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager)
checkpoint.save()
pbar.close()
print()
# Enable summary stdout