From 740756640fb21c581cd829e489b7a6a78f514d7c Mon Sep 17 00:00:00 2001 From: Jin Yong Yoo Date: Fri, 15 May 2020 07:22:33 -0400 Subject: [PATCH] prelimary working checkpoint for file_logger and csv_logger --- .gitignore | 3 + .../goal_function_result.py | 3 + textattack/loggers/file_logger.py | 17 +++ textattack/search_methods/attack.py | 4 + textattack/shared/__init__.py | 1 + textattack/shared/checkpoint.py | 119 ++++++++++++++++++ .../shared/scripts/run_attack_args_helper.py | 37 +++++- .../shared/scripts/run_attack_parallel.py | 45 +++++-- .../scripts/run_attack_single_threaded.py | 50 ++++++-- 9 files changed, 263 insertions(+), 16 deletions(-) create mode 100644 textattack/shared/checkpoint.py diff --git a/.gitignore b/.gitignore index dc8225e6..777f73be 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ dist/ # Weights & Biases outputs wandb/ + +# checkpoints +checkpoints/ diff --git a/textattack/goal_function_results/goal_function_result.py b/textattack/goal_function_results/goal_function_result.py index 4b22ce98..05afb4dd 100644 --- a/textattack/goal_function_results/goal_function_result.py +++ b/textattack/goal_function_results/goal_function_result.py @@ -17,6 +17,9 @@ class GoalFunctionResult: if isinstance(self.score, torch.Tensor): self.score = self.score.item() + + if isinstance(self.succeeded, torch.Tensor): + self.succeeded = self.succeeded.item() def get_text_color_input(self): """ A string representing the color this result's changed diff --git a/textattack/loggers/file_logger.py b/textattack/loggers/file_logger.py index d77020ea..b8cbf57b 100644 --- a/textattack/loggers/file_logger.py +++ b/textattack/loggers/file_logger.py @@ -1,5 +1,6 @@ import os import sys +import copy import terminaltables from .logger import Logger @@ -7,6 +8,7 @@ from .logger import Logger class FileLogger(Logger): def __init__(self, filename='', stdout=False): self.stdout = stdout + self.filename = filename if stdout: self.fout = sys.stdout elif 
isinstance(filename, str): @@ -18,6 +20,21 @@ class FileLogger(Logger): self.fout = filename self.num_results = 0 + def __getstate__(self): + # Temporarily save file handle b/c we can't copy it + tmp = self.fout + self.fout = None + state = copy.deepcopy(self.__dict__) + self.fout = tmp + return state + + def __setstate__(self, state): + self.__dict__ = state + if self.stdout: + self.fout = sys.stdout + else: + self.fout = open(self.filename, 'a') + def log_attack_result(self, result): self.num_results += 1 color_method = 'stdout' if self.stdout else 'file' diff --git a/textattack/search_methods/attack.py b/textattack/search_methods/attack.py index 939f651b..4a95ae66 100644 --- a/textattack/search_methods/attack.py +++ b/textattack/search_methods/attack.py @@ -139,6 +139,10 @@ class Attack: if shuffle: random.shuffle(dataset.examples) + + if num_examples <= 0: + return + yield for text, ground_truth_output in dataset: tokenized_text = TokenizedText(text, self.tokenizer) diff --git a/textattack/shared/__init__.py b/textattack/shared/__init__.py index cec31a1b..2dbe616a 100644 --- a/textattack/shared/__init__.py +++ b/textattack/shared/__init__.py @@ -5,3 +5,4 @@ from . 
import validators from .tokenized_text import TokenizedText from .word_embedding import WordEmbedding +from .checkpoint import CheckPoint \ No newline at end of file diff --git a/textattack/shared/checkpoint.py b/textattack/shared/checkpoint.py new file mode 100644 index 00000000..4404346c --- /dev/null +++ b/textattack/shared/checkpoint.py @@ -0,0 +1,119 @@ +import os +import pickle +import datetime +from textattack.shared import utils +from textattack.attack_results import SuccessfulAttackResult, FailedAttackResult, SkippedAttackResult + +class CheckPoint: + """ An object that stores necessary information for saving and loading checkpoints + + Args: + time (float): epoch time representing when checkpoint was made + args: command line arguments of the original attack + log_manager (AttackLogManager) + """ + def __init__(self, time, args, log_manager): + self.time = time + self.args = args + self.log_manager = log_manager + + def __repr__(self): + main_str = 'Checkpoint(' + lines = [] + lines.append( + utils.add_indent(f'(Time): {self.datetime}', 2) + ) + + args_lines = [] + for key in self.args.__dict__: + args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2)) + args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2) + + lines.append(utils.add_indent(f'(Args): {args_str}', 2)) + + attack_logger_lines = [] + attack_logger_lines.append(utils.add_indent( + f'(Number of attacks performed): {self.results_count}', 2 + )) + attack_logger_lines.append(utils.add_indent( + f'(Number of successful attacks): {self.num_successful_attacks}', 2 + )) + attack_logger_lines.append(utils.add_indent( + f'(Number of failed attacks): {self.num_failed_attacks}', 2 + )) + attack_logger_lines.append(utils.add_indent( + f'(Number of skipped attacks): {self.num_skipped_attacks}', 2 + )) + attack_logger_str = utils.add_indent('\n' + '\n'.join(attack_logger_lines), 2) + lines.append(utils.add_indent(f'(Previous attack summary): {attack_logger_str}', 2)) + + main_str += 
'\n ' + '\n '.join(lines) + '\n' + main_str += ')' + return main_str + + __str__ = __repr__ + + @property + def results_count(self): + """ Return number of attacks made so far """ + return len(self.log_manager.results) + + @property + def num_skipped_attacks(self): + count = 0 + for r in self.log_manager.results: + if isinstance(r, SkippedAttackResult): + count += 1 + return count + + @property + def num_failed_attacks(self): + count = 0 + for r in self.log_manager.results: + if isinstance(r, FailedAttackResult): + count += 1 + return count + + @property + def num_successful_attacks(self): + count = 0 + for r in self.log_manager.results: + if isinstance(r, SuccessfulAttackResult): + count += 1 + return count + + @property + def num_remaining_attacks(self): + if self.args.attack_n: + non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks + count = self.args.num_examples - non_skipped_attacks + else: + count = self.args.num_examples - self.results_count + return count + + @property + def dataset_offset(self): + """ Calculate offset into the dataset to start from """ + # Original offset + # of results processed so far + return self.args.num_examples_offset + self.results_count + + @property + def datetime(self): + return datetime.datetime.fromtimestamp(self.time).strftime('%Y-%m-%d %H:%M:%S') + + def save(self): + file_name = "{}.ta.chkpt".format(int(self.time*1000)) + if not os.path.exists(self.args.checkpoint_dir): + os.makedirs(self.args.checkpoint_dir) + path = os.path.join(self.args.checkpoint_dir, file_name) + with open(path, 'wb') as f: + pickle.dump(self, f) + + @classmethod + def load(cls, path): + with open(path, 'rb') as f: + checkpoint = pickle.load(f) + assert isinstance(checkpoint, CheckPoint) + + return checkpoint + \ No newline at end of file diff --git a/textattack/shared/scripts/run_attack_args_helper.py b/textattack/shared/scripts/run_attack_args_helper.py index 36db42e3..619a29a2 100644 --- 
a/textattack/shared/scripts/run_attack_args_helper.py +++ b/textattack/shared/scripts/run_attack_args_helper.py @@ -6,6 +6,7 @@ import sys import textattack import time import torch +import pickle RECIPE_NAMES = { 'alzantot': 'textattack.attack_recipes.Alzantot2018', @@ -196,6 +197,15 @@ def get_args(): def str_to_int(s): return sum((ord(c) for c in s)) parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK')) + + parser.add_argument('--checkpoint-resume', required=False, type=str, + help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint. Overrides any other arguments.') + + parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(), + help='A directory to save/load checkpoint files.') + + parser.add_argument('--checkpoint-interval', required=False, type=int, + help='Interval for saving checkpoints. If not set, no checkpoints will be saved.') attack_group = parser.add_mutually_exclusive_group(required=False) @@ -210,6 +220,10 @@ def get_args(): command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments. args = parser.parse_args(command_line_args) + + if args.checkpoint_interval and args.shuffle: + # Not allowed b/c we cannot recover order of shuffled data + raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`') set_seed(args.random_seed) @@ -308,7 +322,7 @@ def parse_logger_from_args(args):# Create logger if not args.out_dir: current_dir = os.path.dirname(os.path.realpath(__file__)) outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs') - args.out_dir = outputs_dir + args.out_dir = os.path.normpath(outputs_dir) # Output file. 
out_time = int(time.time()*1000) # Output file @@ -335,3 +349,24 @@ def parse_logger_from_args(args):# Create logger if not args.disable_stdout: attack_log_manager.enable_stdout() return attack_log_manager + +def parse_checkpoint_from_args(args): + if args.checkpoint_resume: + if args.checkpoint_resume.lower() == 'latest': + chkpt_files = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')] + latest_file = max(chkpt_files) + checkpoint_path = os.path.join(args.checkpoint_dir, latest_file) + else: + checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_resume) + + checkpoint = textattack.shared.CheckPoint.load(checkpoint_path) + else: + checkpoint = None + + return checkpoint + +def default_checkpoint_dir(): + current_dir = os.path.dirname(os.path.realpath(__file__)) + checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints') + return os.path.normpath(checkpoints_dir) + diff --git a/textattack/shared/scripts/run_attack_parallel.py b/textattack/shared/scripts/run_attack_parallel.py index 65a33633..1deb8c08 100644 --- a/textattack/shared/scripts/run_attack_parallel.py +++ b/textattack/shared/scripts/run_attack_parallel.py @@ -41,6 +41,17 @@ def attack_from_queue(args, in_queue, out_queue): def run(args): pytorch_multiprocessing_workaround() + num_examples_offset, num_examples, checkpoint_resume = args.num_examples_offset, args.num_examples, False + if args.checkpoint_resume: + # Override current args with checkpoint args + resume_checkpoint = parse_checkpoint_from_args(args) + args = resume_checkpoint.args + num_examples_offset = resume_checkpoint.dataset_offset + num_examples = resume_checkpoint.num_remaining_attacks + checkpoint_resume = True + logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) + print(resume_checkpoint, '\n') + # This makes `args` a namespace that's sharable between processes. # We could do the same thing with the model, but it's actually faster # to let each thread have their own copy of the model. 
@@ -49,11 +60,14 @@ def run(args): ) start_time = time.time() - attack_log_manager = parse_logger_from_args(args) + if checkpoint_resume: + attack_log_manager = resume_checkpoint.log_manager + else: + attack_log_manager = parse_logger_from_args(args) # We reserve the first GPU for coordinating workers. num_gpus = torch.cuda.device_count() - dataset = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset) + dataset = DATASET_BY_MODEL[args.model](offset=num_examples_offset) print(f'Running on {num_gpus} GPUs') load_time = time.time() @@ -64,7 +78,7 @@ def run(args): in_queue = torch.multiprocessing.Queue() out_queue = torch.multiprocessing.Queue() # Add stuff to queue. - for _ in range(args.num_examples): + for _ in range(num_examples): label, text = next(dataset) in_queue.put((label, text)) # Start workers. @@ -74,11 +88,16 @@ def run(args): (args, in_queue, out_queue) ) # Log results asynchronously and update progress bar. - num_results = 0 - num_failures = 0 - num_successes = 0 - pbar = tqdm.tqdm(total=args.num_examples, smoothing=0) - while num_results < args.num_examples: + if checkpoint_resume: + num_results = resume_checkpoint.results_count + num_failures = resume_checkpoint.num_failed_attacks + num_successes = resume_checkpoint.num_successful_attacks + else: + num_results = 0 + num_failures = 0 + num_successes = 0 + pbar = tqdm.tqdm(total=num_examples, smoothing=0) + while num_results < num_examples: result = out_queue.get(block=True) if isinstance(result, Exception): raise result @@ -94,6 +113,16 @@ def run(args): else: label, text = next(dataset) in_queue.put((label, text)) + + if args.checkpoint_interval and num_results % args.checkpoint_interval == 0: + chkpt_time = time.time() + date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S') + print('\n\n' + '=' * 100) + logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results)) + print('=' * 100 + '\n') + checkpoint = 
textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager) + checkpoint.save() + pbar.close() print() # Enable summary stdout. diff --git a/textattack/shared/scripts/run_attack_single_threaded.py b/textattack/shared/scripts/run_attack_single_threaded.py index 9b73b471..03a42acc 100644 --- a/textattack/shared/scripts/run_attack_single_threaded.py +++ b/textattack/shared/scripts/run_attack_single_threaded.py @@ -6,9 +6,12 @@ import textattack import time import tqdm import os +import datetime from .run_attack_args_helper import * +logger = textattack.shared.utils.get_logger() + def run(args): # Only use one GPU, if we have one. if 'CUDA_VISIBLE_DEVICES' not in os.environ: @@ -19,6 +22,21 @@ def run(args): # Cache TensorFlow Hub models here, if not otherwise specified. if 'TFHUB_CACHE_DIR' not in os.environ: os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/.cache/tensorflow-hub') + + if args.checkpoint_resume: + # Override current args with checkpoint args + resume_checkpoint = parse_checkpoint_from_args(args) + args = resume_checkpoint.args + num_examples_offset = resume_checkpoint.dataset_offset + num_examples = resume_checkpoint.num_remaining_attacks + checkpoint_resume = True + logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime)) + print(resume_checkpoint, '\n') + + else: + num_examples_offset = args.num_examples_offset + num_examples = args.num_examples + checkpoint_resume = False start_time = time.time() @@ -27,7 +45,10 @@ def run(args): print(attack, '\n') # Logger - attack_log_manager = parse_logger_from_args(args) + if checkpoint_resume: + attack_log_manager = resume_checkpoint.log_manager + else: + attack_log_manager = parse_logger_from_args(args) load_time = time.time() print(f'Load time: {load_time - start_time}s') @@ -57,16 +78,21 @@ def run(args): else: # Not interactive? Use default dataset. 
if args.model in DATASET_BY_MODEL: - data = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset) + data = DATASET_BY_MODEL[args.model](offset=num_examples_offset) else: raise ValueError(f'Error: unsupported model {args.model}') - pbar = tqdm.tqdm(total=args.num_examples, smoothing=0) - num_results = 0 - num_failures = 0 - num_successes = 0 + pbar = tqdm.tqdm(total=num_examples, smoothing=0) + if checkpoint_resume: + num_results = resume_checkpoint.results_count + num_failures = resume_checkpoint.num_failed_attacks + num_successes = resume_checkpoint.num_successful_attacks + else: + num_results = 0 + num_failures = 0 + num_successes = 0 for result in attack.attack_dataset(data, - num_examples=args.num_examples, + num_examples=num_examples, shuffle=args.shuffle, attack_n=args.attack_n): attack_log_manager.log_result(result) @@ -80,6 +106,16 @@ def run(args): if type(result) == textattack.attack_results.FailedAttackResult: num_failures += 1 pbar.set_description('[Succeeded / Failed / Total] {} / {} / {}'.format(num_successes, num_failures, num_results)) + + if args.checkpoint_interval and num_results % args.checkpoint_interval == 0: + chkpt_time = time.time() + date_time = datetime.datetime.fromtimestamp(chkpt_time).strftime('%Y-%m-%d %H:%M:%S') + print('\n\n' + '=' * 100) + logger.info('Saving checkpoint at {} after {} attacks.'.format(date_time, num_results)) + print('=' * 100 + '\n') + checkpoint = textattack.shared.CheckPoint(chkpt_time, args, attack_log_manager) + checkpoint.save() + pbar.close() print() # Enable summary stdout