mirror of
https://github.com/QData/TextAttack.git
synced 2021-10-13 00:05:06 +03:00
86 lines
2.6 KiB
Python
86 lines
2.6 KiB
Python
"""

TextAttack Command Class for Attack Resume
-------------------------------------------

"""
|
|
|
|
|
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
|
|
|
import textattack
|
|
from textattack.commands import TextAttackCommand
|
|
from textattack.commands.attack.attack_args_helpers import (
|
|
merge_checkpoint_args,
|
|
parse_checkpoint_from_args,
|
|
)
|
|
|
|
|
|
class AttackResumeCommand(TextAttackCommand):
    """The TextAttack attack resume recipe module:

    A command line parser to resume a checkpointed attack from user
    specifications.
    """

    def run(self, args):
        """Resume an attack from a saved checkpoint.

        Loads the checkpoint selected by ``args``. Merges the arguments
        stored in the checkpoint with the ones supplied on the command line.
        Seeds via ``textattack.shared.utils.set_seed`` using the merged
        ``random_seed``. Then dispatches to the parallel or single-threaded
        runner depending on ``args.parallel``.

        Args:
            args: the parsed namespace for the ``attack-resume`` subcommand.
        """
        checkpoint = parse_checkpoint_from_args(args)
        # Arguments recovered from the checkpoint are combined with the
        # command-line ones; the merged namespace drives the rest of the run.
        args = merge_checkpoint_args(checkpoint.args, args)
        textattack.shared.utils.set_seed(args.random_seed)
        # Mark the merged args as a resumed run. NOTE(review): presumably read
        # by the runners below to skip already-attacked examples — confirm.
        args.checkpoint_resume = True

        # Run attack from checkpoint.
        # Runners are imported lazily, only when the command actually runs.
        # NOTE(review): presumably to keep CLI start-up light — confirm.
        from textattack.commands.attack.run_attack_parallel import run as run_parallel
        from textattack.commands.attack.run_attack_single_threaded import (
            run as run_single_threaded,
        )

        if args.parallel:
            run_parallel(args, checkpoint=checkpoint)
        else:
            run_single_threaded(args, checkpoint=checkpoint)

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``attack-resume`` subcommand and its options.

        Args:
            main_parser: the object whose ``add_parser`` method creates the
                subcommand parser. NOTE(review): ``add_parser`` is provided by
                the action returned from ``ArgumentParser.add_subparsers()``,
                not by ``ArgumentParser`` itself, so the annotation is likely
                inaccurate — confirm against the caller.
        """
        resume_parser = main_parser.add_parser(
            "attack-resume",
            help="resume a checkpointed attack",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )

        # Parser for parsing args for resume
        resume_parser.add_argument(
            "--checkpoint-file",
            "-f",
            type=str,
            required=True,
            # The two literals are concatenated implicitly. The trailing space
            # after "entered," keeps the words from running together in --help.
            help='Path of checkpoint file to resume attack from. If "latest" (or "{directory path}/latest") is entered, '
            "recover latest checkpoint from either current path or specified directory.",
        )

        resume_parser.add_argument(
            "--checkpoint-dir",
            "-d",
            required=False,
            type=str,
            default=None,
            help="The directory to save checkpoint files. If not set, use directory from recovered arguments.",
        )

        resume_parser.add_argument(
            "--checkpoint-interval",
            "-i",
            required=False,
            type=int,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )

        resume_parser.add_argument(
            "--parallel",
            action="store_true",
            default=False,
            help="Run attack using multiple GPUs.",
        )

        # Attach a command instance to the parsed namespace as ``func``.
        # NOTE(review): presumably the CLI entry point calls ``args.func.run(args)``.
        resume_parser.set_defaults(func=AttackResumeCommand())
|