mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
merge checkpoint changes in
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -34,3 +34,6 @@ dist/
|
||||
|
||||
# Weights & Biases outputs
|
||||
wandb/
|
||||
|
||||
# checkpoints
|
||||
checkpoints/
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
<a target="_blank" href="https://travis-ci.org/QData/TextAttack">
|
||||
<img src="https://travis-ci.org/QData/TextAttack.svg?branch=master" alt="Coverage Status">
|
||||
</a>
|
||||
<a href="https://badge.fury.io/py/textattack">
|
||||
<img src="https://badge.fury.io/py/textattack.svg" alt="PyPI version" height="18">
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
@@ -19,6 +19,9 @@ class GoalFunctionResult:
|
||||
|
||||
if isinstance(self.score, torch.Tensor):
|
||||
self.score = self.score.item()
|
||||
|
||||
if isinstance(self.succeeded, torch.Tensor):
|
||||
self.succeeded = self.succeeded.item()
|
||||
|
||||
def get_text_color_input(self):
|
||||
""" A string representing the color this result's changed
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import terminaltables
|
||||
|
||||
from .logger import Logger
|
||||
@@ -7,6 +8,7 @@ from .logger import Logger
|
||||
class FileLogger(Logger):
|
||||
def __init__(self, filename='', stdout=False):
|
||||
self.stdout = stdout
|
||||
self.filename = filename
|
||||
if stdout:
|
||||
self.fout = sys.stdout
|
||||
elif isinstance(filename, str):
|
||||
@@ -18,6 +20,18 @@ class FileLogger(Logger):
|
||||
self.fout = filename
|
||||
self.num_results = 0
|
||||
|
||||
def __getstate__(self):
|
||||
# Temporarily save file handle b/c we can't copy it
|
||||
state = {i: self.__dict__[i] for i in self.__dict__ if i !='fout'}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
if self.stdout:
|
||||
self.fout = sys.stdout
|
||||
else:
|
||||
self.fout = open(self.filename, 'a')
|
||||
|
||||
def log_attack_result(self, result):
|
||||
self.num_results += 1
|
||||
color_method = 'stdout' if self.stdout else 'file'
|
||||
@@ -36,4 +50,8 @@ class FileLogger(Logger):
|
||||
|
||||
def log_sep(self):
|
||||
self.fout.write('-' * 90 + '\n')
|
||||
|
||||
def flush(self):
|
||||
self.fout.flush()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import socket
|
||||
import copy
|
||||
from visdom import Visdom
|
||||
|
||||
from textattack.shared.utils import html_table_from_rows
|
||||
@@ -16,9 +17,20 @@ class VisdomLogger(Logger):
|
||||
if not port_is_open(port, hostname=hostname):
|
||||
raise socket.error(f'Visdom not running on {hostname}:{port}')
|
||||
self.vis = Visdom(port=port, server=hostname, env=env)
|
||||
self.env = env
|
||||
self.port = port
|
||||
self.hostname = hostname
|
||||
self.windows = {}
|
||||
self.sample_rows = []
|
||||
|
||||
def __getstate__(self):
|
||||
state = {i: self.__dict__[i] for i in self.__dict__ if i !='vis'}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.vis = Visdom(port=self.port, server=self.hostname, env=self.env)
|
||||
|
||||
def log_attack_result(self, result):
|
||||
text_a, text_b = result.diff_color(color_method='html')
|
||||
result_str = result.goal_function_result_str(color_method='html')
|
||||
@@ -51,7 +63,7 @@ class VisdomLogger(Logger):
|
||||
if not window_id: window_id = title # Can provide either of these,
|
||||
if not title: title = window_id # or both.
|
||||
table = html_table_from_rows(rows, title=title, header=header, style_dict=style)
|
||||
self.text(table_html, title=title, window_id=window_id)
|
||||
self.text(table, title=title, window_id=window_id)
|
||||
|
||||
def bar(self, X_data, numbins=10, title=None, window_id=None):
|
||||
window = None
|
||||
|
||||
@@ -4,9 +4,14 @@ from .logger import Logger
|
||||
class WeightsAndBiasesLogger(Logger):
|
||||
def __init__(self, filename='', stdout=False):
|
||||
import wandb
|
||||
wandb.init(project='textattack')
|
||||
wandb.init(project='textattack', resume=True)
|
||||
self._result_table_rows = []
|
||||
|
||||
def __setstate__(self, state):
|
||||
import wandb
|
||||
self.__dict__ = state
|
||||
wandb.init(project='textattack', resume=True)
|
||||
|
||||
def log_summary_rows(self, rows, title, window_id):
|
||||
table = wandb.Table(columns=['Attack Results', ''])
|
||||
for row in rows:
|
||||
|
||||
@@ -5,3 +5,4 @@ from . import validators
|
||||
from .tokenized_text import TokenizedText
|
||||
from .word_embedding import WordEmbedding
|
||||
from .attack import Attack
|
||||
from .checkpoint import Checkpoint
|
||||
|
||||
@@ -161,6 +161,10 @@ class Attack:
|
||||
|
||||
if shuffle:
|
||||
random.shuffle(dataset.examples)
|
||||
|
||||
if num_examples <= 0:
|
||||
return
|
||||
yield
|
||||
|
||||
for text, ground_truth_output in dataset:
|
||||
tokenized_text = TokenizedText(text, self.tokenizer)
|
||||
|
||||
126
textattack/shared/checkpoint.py
Normal file
126
textattack/shared/checkpoint.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import datetime
|
||||
from textattack.shared import utils
|
||||
from textattack.attack_results import SuccessfulAttackResult, FailedAttackResult, SkippedAttackResult
|
||||
|
||||
logger = utils.get_logger()
|
||||
|
||||
class Checkpoint:
|
||||
""" An object that stores necessary information for saving and loading checkpoints
|
||||
|
||||
Args:
|
||||
args: command line arguments of the original attack
|
||||
log_manager (AttackLogManager)
|
||||
chkpt_time (float): epoch time representing when checkpoint was made
|
||||
"""
|
||||
def __init__(self, args, log_manager, chkpt_time=None):
|
||||
self.args = args
|
||||
self.log_manager = log_manager
|
||||
if chkpt_time:
|
||||
self.time = chkpt_time
|
||||
else:
|
||||
self.time = time.time()
|
||||
|
||||
def __repr__(self):
|
||||
main_str = 'Checkpoint('
|
||||
lines = []
|
||||
lines.append(
|
||||
utils.add_indent(f'(Time): {self.datetime}', 2)
|
||||
)
|
||||
|
||||
args_lines = []
|
||||
for key in self.args.__dict__:
|
||||
args_lines.append(utils.add_indent(f'({key}): {self.args.__dict__[key]}', 2))
|
||||
args_str = utils.add_indent('\n' + '\n'.join(args_lines), 2)
|
||||
|
||||
lines.append(utils.add_indent(f'(Args): {args_str}', 2))
|
||||
|
||||
attack_logger_lines = []
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Total number of examples to attack): {self.args.num_examples}', 2
|
||||
))
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of attacks performed): {self.results_count}', 2
|
||||
))
|
||||
attack_logger_lines.append(utils.add_indent(
|
||||
f'(Number of remaining attacks): {self.num_remaining_attacks}', 2
|
||||
))
|
||||
breakdown_lines = []
|
||||
breakdown_lines.append(utils.add_indent(
|
||||
f'(Number of successful attacks): {self.num_successful_attacks}', 2
|
||||
))
|
||||
breakdown_lines.append(utils.add_indent(
|
||||
f'(Number of failed attacks): {self.num_failed_attacks}', 2
|
||||
))
|
||||
breakdown_lines.append(utils.add_indent(
|
||||
f'(Number of skipped attacks): {self.num_skipped_attacks}', 2
|
||||
))
|
||||
breakdown_str = utils.add_indent('\n' + '\n'.join(breakdown_lines), 2)
|
||||
attack_logger_lines.append(utils.add_indent(f'(Latest result breakdown): {breakdown_str}', 2))
|
||||
attack_logger_str = utils.add_indent('\n' + '\n'.join(attack_logger_lines), 2)
|
||||
lines.append(utils.add_indent(f'(Previous attack summary): {attack_logger_str}', 2))
|
||||
|
||||
main_str += '\n ' + '\n '.join(lines) + '\n'
|
||||
main_str += ')'
|
||||
return main_str
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@property
|
||||
def results_count(self):
|
||||
""" Return number of attacks made so far """
|
||||
return len(self.log_manager.results)
|
||||
|
||||
@property
|
||||
def num_skipped_attacks(self):
|
||||
return sum(isinstance(r, SkippedAttackResult) for r in self.log_manager.results)
|
||||
|
||||
@property
|
||||
def num_failed_attacks(self):
|
||||
return sum(isinstance(r, FailedAttackResult) for r in self.log_manager.results)
|
||||
|
||||
@property
|
||||
def num_successful_attacks(self):
|
||||
return sum(isinstance(r, SuccessfulAttackResult) for r in self.log_manager.results)
|
||||
|
||||
@property
|
||||
def num_remaining_attacks(self):
|
||||
if self.args.attack_n:
|
||||
non_skipped_attacks = self.num_successful_attacks + self.num_failed_attacks
|
||||
count = self.args.num_examples - non_skipped_attacks
|
||||
else:
|
||||
count = self.args.num_examples - self.results_count
|
||||
return count
|
||||
|
||||
@property
|
||||
def dataset_offset(self):
|
||||
""" Calculate offset into the dataset to start from """
|
||||
# Original offset + # of results processed so far
|
||||
return self.args.num_examples_offset + self.results_count
|
||||
|
||||
@property
|
||||
def datetime(self):
|
||||
return datetime.datetime.fromtimestamp(self.time).strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
def save(self, quiet=False):
|
||||
file_name = "{}.ta.chkpt".format(int(self.time*1000))
|
||||
if not os.path.exists(self.args.checkpoint_dir):
|
||||
os.makedirs(self.args.checkpoint_dir)
|
||||
path = os.path.join(self.args.checkpoint_dir, file_name)
|
||||
if not quiet:
|
||||
print('\n\n' + '=' * 125)
|
||||
logger.info('Saving checkpoint under "{}" at {} after {} attacks.'.format(path, self.datetime, self.results_count))
|
||||
print('=' * 125 + '\n')
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(self, path):
|
||||
with open(path, 'rb') as f:
|
||||
checkpoint = pickle.load(f)
|
||||
assert isinstance(checkpoint, Checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -6,6 +6,8 @@ import sys
|
||||
import textattack
|
||||
import time
|
||||
import torch
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
RECIPE_NAMES = {
|
||||
'alzantot': 'textattack.attack_recipes.Alzantot2018',
|
||||
@@ -144,6 +146,7 @@ def set_seed(random_seed):
|
||||
torch.manual_seed(random_seed)
|
||||
|
||||
def get_args():
|
||||
# Parser for regular arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description='A commandline parser for TextAttack',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
@@ -151,7 +154,7 @@ def get_args():
|
||||
parser.add_argument('--transformation', type=str, required=False,
|
||||
default='word-swap-embedding', choices=TRANSFORMATION_CLASS_NAMES.keys(),
|
||||
help='The transformations to apply.')
|
||||
|
||||
|
||||
parser.add_argument('--model', type=str, required=False, default='bert-yelp-sentiment',
|
||||
choices=MODEL_CLASS_NAMES.keys(), help='The classification model to attack.')
|
||||
|
||||
@@ -198,6 +201,12 @@ def get_args():
|
||||
|
||||
def str_to_int(s): return sum((ord(c) for c in s))
|
||||
parser.add_argument('--random-seed', default=str_to_int('TEXTATTACK'))
|
||||
|
||||
parser.add_argument('--checkpoint-dir', required=False, type=str, default=default_checkpoint_dir(),
|
||||
help='A directory to save/load checkpoint files.')
|
||||
|
||||
parser.add_argument('--checkpoint-interval', required=False, type=int,
|
||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
||||
|
||||
attack_group = parser.add_mutually_exclusive_group(required=False)
|
||||
|
||||
@@ -209,11 +218,36 @@ def get_args():
|
||||
attack_group.add_argument('--recipe', '-r', type=str, required=False, default=None,
|
||||
help='full attack recipe (overrides provided goal function, transformation & constraints)',
|
||||
choices=RECIPE_NAMES.keys())
|
||||
|
||||
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
|
||||
args = parser.parse_args(command_line_args)
|
||||
|
||||
set_seed(args.random_seed)
|
||||
|
||||
# Parser for parsing args for resume
|
||||
resume_parser = argparse.ArgumentParser(
|
||||
description='A commandline parser for TextAttack',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
resume_parser.add_argument('--checkpoint-file', '-f', type=str, required=False, default='latest',
|
||||
help='Name of checkpoint file to resume attack from. If "latest" is entered, recover latest checkpoint.')
|
||||
|
||||
resume_parser.add_argument('--checkpoint-dir', '-d', required=False, type=str, default=default_checkpoint_dir(),
|
||||
help='A directory to save/load checkpoint files.')
|
||||
|
||||
resume_parser.add_argument('--checkpoint-interval', '-i', required=False, type=int,
|
||||
help='Interval for saving checkpoints. If not set, no checkpoints will be saved.')
|
||||
|
||||
resume_parser.add_argument('--parallel', action='store_true', default=False,
|
||||
help='Run attack using multiple GPUs.')
|
||||
|
||||
if sys.argv[1:] and sys.argv[1].lower() == 'resume':
|
||||
args = resume_parser.parse_args(sys.argv[2:])
|
||||
setattr(args, 'checkpoint_resume', True)
|
||||
else:
|
||||
command_line_args = None if sys.argv[1:] else ['-h'] # Default to help with empty arguments.
|
||||
args = parser.parse_args(command_line_args)
|
||||
setattr(args, 'checkpoint_resume', False)
|
||||
|
||||
if args.checkpoint_interval and args.shuffle:
|
||||
# Not allowed b/c we cannot recover order of shuffled data
|
||||
raise ValueError('Cannot use `--checkpoint-interval` with `--shuffle=True`')
|
||||
|
||||
set_seed(args.random_seed)
|
||||
|
||||
return args
|
||||
|
||||
@@ -311,7 +345,7 @@ def parse_logger_from_args(args):# Create logger
|
||||
if not args.out_dir:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
outputs_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'outputs')
|
||||
args.out_dir = outputs_dir
|
||||
args.out_dir = os.path.normpath(outputs_dir)
|
||||
|
||||
# Output file.
|
||||
out_time = int(time.time()*1000) # Output file
|
||||
@@ -338,3 +372,36 @@ def parse_logger_from_args(args):# Create logger
|
||||
if not args.disable_stdout:
|
||||
attack_log_manager.enable_stdout()
|
||||
return attack_log_manager
|
||||
|
||||
def parse_checkpoint_from_args(args):
|
||||
if args.checkpoint_file.lower() == 'latest':
|
||||
chkpt_file_names = [f for f in os.listdir(args.checkpoint_dir) if f.endswith('.ta.chkpt')]
|
||||
assert chkpt_file_names, "Checkpoint directory is empty"
|
||||
timestamps = [int(f.replace('.ta.chkpt', '')) for f in chkpt_file_names]
|
||||
latest_file = str(max(timestamps)) + '.ta.chkpt'
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, latest_file)
|
||||
else:
|
||||
checkpoint_path = os.path.join(args.checkpoint_dir, args.checkpoint_file)
|
||||
|
||||
checkpoint = textattack.shared.Checkpoint.load(checkpoint_path)
|
||||
set_seed(checkpoint.args.random_seed)
|
||||
|
||||
return checkpoint
|
||||
|
||||
def default_checkpoint_dir():
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkpoints_dir = os.path.join(current_dir, os.pardir, os.pardir, os.pardir, 'checkpoints')
|
||||
return os.path.normpath(checkpoints_dir)
|
||||
|
||||
def merge_checkpoint_args(saved_args, cmdline_args):
|
||||
""" Merge previously saved arguments for checkpoint and newly entered arguments """
|
||||
args = copy.deepcopy(saved_args)
|
||||
# Newly entered arguments take precedence
|
||||
args.checkpoint_resume = cmdline_args.checkpoint_resume
|
||||
args.parallel = cmdline_args.parallel
|
||||
args.checkpoint_dir = cmdline_args.checkpoint_dir
|
||||
# If set, we replace
|
||||
if cmdline_args.checkpoint_interval:
|
||||
args.checkpoint_interval = cmdlineargs.checkpoint_interval
|
||||
|
||||
return args
|
||||
|
||||
@@ -41,6 +41,16 @@ def attack_from_queue(args, in_queue, out_queue):
|
||||
|
||||
def run(args):
|
||||
pytorch_multiprocessing_workaround()
|
||||
|
||||
if args.checkpoint_resume:
|
||||
# Override current args with checkpoint args
|
||||
resume_checkpoint = parse_checkpoint_from_args(args)
|
||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||
num_examples_offset = resume_checkpoint.dataset_offset
|
||||
num_examples = resume_checkpoint.num_remaining_attack
|
||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
||||
print(resume_checkpoint, '\n')
|
||||
|
||||
# This makes `args` a namespace that's sharable between processes.
|
||||
# We could do the same thing with the model, but it's actually faster
|
||||
# to let each thread have their own copy of the model.
|
||||
@@ -49,11 +59,14 @@ def run(args):
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
if args.checkpoint_resume:
|
||||
attack_log_manager = resume_checkpoint.log_manager
|
||||
else:
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
# We reserve the first GPU for coordinating workers.
|
||||
num_gpus = torch.cuda.device_count()
|
||||
dataset = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
|
||||
dataset = DATASET_BY_MODEL[args.model](offset=num_examples_offset)
|
||||
|
||||
print(f'Running on {num_gpus} GPUs')
|
||||
load_time = time.time()
|
||||
@@ -64,7 +77,7 @@ def run(args):
|
||||
in_queue = torch.multiprocessing.Queue()
|
||||
out_queue = torch.multiprocessing.Queue()
|
||||
# Add stuff to queue.
|
||||
for _ in range(args.num_examples):
|
||||
for _ in range(num_examples):
|
||||
label, text = next(dataset)
|
||||
in_queue.put((label, text))
|
||||
# Start workers.
|
||||
@@ -74,11 +87,16 @@ def run(args):
|
||||
(args, in_queue, out_queue)
|
||||
)
|
||||
# Log results asynchronously and update progress bar.
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
|
||||
while num_results < args.num_examples:
|
||||
if args.checkpoint_resume:
|
||||
num_results = resume_checkpoint.results_count
|
||||
num_failures = resume_checkpoint.num_failed_attacks
|
||||
num_successes = resume_checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
|
||||
while num_results < num_examples:
|
||||
result = out_queue.get(block=True)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
@@ -94,6 +112,12 @@ def run(args):
|
||||
else:
|
||||
label, text = next(dataset)
|
||||
in_queue.put((label, text))
|
||||
|
||||
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
|
||||
checkpoint = textattack.shared.Checkpoint(chkpt_time, args, attack_log_manager)
|
||||
checkpoint.save()
|
||||
attack_log_manager.flush()
|
||||
|
||||
pbar.close()
|
||||
print()
|
||||
# Enable summary stdout.
|
||||
|
||||
@@ -6,9 +6,12 @@ import textattack
|
||||
import time
|
||||
import tqdm
|
||||
import os
|
||||
import datetime
|
||||
|
||||
from .run_attack_args_helper import *
|
||||
|
||||
logger = textattack.shared.utils.get_logger()
|
||||
|
||||
def run(args):
|
||||
# Only use one GPU, if we have one.
|
||||
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
|
||||
@@ -19,6 +22,18 @@ def run(args):
|
||||
# Cache TensorFlow Hub models here, if not otherwise specified.
|
||||
if 'TFHUB_CACHE_DIR' not in os.environ:
|
||||
os.environ['TFHUB_CACHE_DIR'] = os.path.expanduser('~/.cache/tensorflow-hub')
|
||||
|
||||
if args.checkpoint_resume:
|
||||
# Override current args with checkpoint args
|
||||
resume_checkpoint = parse_checkpoint_from_args(args)
|
||||
args = merge_checkpoint_args(resume_checkpoint.args, args)
|
||||
num_examples_offset = resume_checkpoint.dataset_offset
|
||||
num_examples = resume_checkpoint.num_remaining_attacks
|
||||
logger.info('Recovered from previously saved checkpoint at {}'.format(resume_checkpoint.datetime))
|
||||
print(resume_checkpoint, '\n')
|
||||
else:
|
||||
num_examples_offset = args.num_examples_offset
|
||||
num_examples = args.num_examples
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -27,7 +42,10 @@ def run(args):
|
||||
print(attack, '\n')
|
||||
|
||||
# Logger
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
if args.checkpoint_resume:
|
||||
attack_log_manager = resume_checkpoint.log_manager
|
||||
else:
|
||||
attack_log_manager = parse_logger_from_args(args)
|
||||
|
||||
load_time = time.time()
|
||||
print(f'Load time: {load_time - start_time}s')
|
||||
@@ -57,16 +75,21 @@ def run(args):
|
||||
else:
|
||||
# Not interactive? Use default dataset.
|
||||
if args.model in DATASET_BY_MODEL:
|
||||
data = DATASET_BY_MODEL[args.model](offset=args.num_examples_offset)
|
||||
data = DATASET_BY_MODEL[args.model](offset=num_examples_offset)
|
||||
else:
|
||||
raise ValueError(f'Error: unsupported model {args.model}')
|
||||
|
||||
pbar = tqdm.tqdm(total=args.num_examples, smoothing=0)
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
pbar = tqdm.tqdm(total=num_examples, smoothing=0)
|
||||
if args.checkpoint_resume:
|
||||
num_results = resume_checkpoint.results_count
|
||||
num_failures = resume_checkpoint.num_failed_attacks
|
||||
num_successes = resume_checkpoint.num_successful_attacks
|
||||
else:
|
||||
num_results = 0
|
||||
num_failures = 0
|
||||
num_successes = 0
|
||||
for result in attack.attack_dataset(data,
|
||||
num_examples=args.num_examples,
|
||||
num_examples=num_examples,
|
||||
shuffle=args.shuffle,
|
||||
attack_n=args.attack_n):
|
||||
attack_log_manager.log_result(result)
|
||||
@@ -80,6 +103,12 @@ def run(args):
|
||||
if type(result) == textattack.attack_results.FailedAttackResult:
|
||||
num_failures += 1
|
||||
pbar.set_description('[Succeeded / Failed / Total] {} / {} / {}'.format(num_successes, num_failures, num_results))
|
||||
|
||||
if args.checkpoint_interval and num_results % args.checkpoint_interval == 0:
|
||||
checkpoint = textattack.shared.Checkpoint(args, attack_log_manager)
|
||||
checkpoint.save()
|
||||
attack_log_manager.flush()
|
||||
|
||||
pbar.close()
|
||||
print()
|
||||
# Enable summary stdout
|
||||
|
||||
Reference in New Issue
Block a user