mirror of https://github.com/QData/TextAttack.git

textattack-nlp-transformer/textattack/commands/train_model/train_model_command.py

import datetime
import os
from argparse import ArgumentParser

import torch

from textattack.commands import TextAttackCommand


class TrainModelCommand(TextAttackCommand):
    """
    The TextAttack train module:

        A command line parser to train a model from user specifications.
    """

    def run(self, args):
        if not args.model_dir:
            args.model_dir = 'bert-base-cased' if args.cased else 'bert-base-uncased'
        if args.output_prefix:
            args.output_prefix += '-'
        cased_str = '-' + ('cased' if args.cased else 'uncased')
        date_now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
        root_output_dir = 'outputs'
        args.output_dir = os.path.join(
            root_output_dir, f'{args.output_prefix}{args.dataset}{cased_str}-{date_now}/'
        )
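        # Illustrative (not from the source): with `--output_prefix test`,
        # `--dataset yelp`, and the default uncased model, `args.output_dir`
        # would be something like 'outputs/test-yelp-uncased-2020-06-20-19:31/'.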
        # Use multiple GPUs if we can!
        args.num_gpus = torch.cuda.device_count()
        raise NotImplementedError("Cannot train models yet - stay tuned!!")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        parser = main_parser.add_parser("train", help="train a model")
        parser.add_argument('--model_dir', type=str, default=None,
                            help='directory of model to train')
        parser.add_argument('--logging_steps', type=int, default=500,
                            help='log model after this many steps')
        parser.add_argument('--checkpoint_steps', type=int, default=5000,
                            help='save model after this many steps')
        parser.add_argument('--checkpoint_every_epoch', action='store_true',
                            default=False, help='save a model checkpoint at the end of every epoch')
        parser.add_argument('--output_prefix', type=str, default='',
                            help='prefix for model saved output')
        parser.add_argument('--dataset', type=str, default='yelp',
                            help='dataset for training, like \'yelp\'')
        parser.add_argument('--cased', action='store_true', default=False,
                            help='if true, bert is cased; if false, bert is uncased')
        parser.add_argument('--debug_cuda_memory', action='store_true',
                            default=False, help='Print CUDA memory info periodically')
        parser.add_argument('--num_train_epochs', '--epochs', type=int, default=100,
                            help='Total number of epochs to train for')
        parser.add_argument('--early_stopping_epochs', type=int, default=-1,
                            help='Number of epochs without a validation improvement '
                                 'before stopping early')
        parser.add_argument('--batch_size', type=int, default=128,
                            help='Batch size for training')
        parser.add_argument('--max_seq_len', type=int, default=128,
                            help='Maximum length of a sequence (anything beyond this '
                                 'will be truncated); BERT\'s max sequence length is '
                                 '512, so this cannot go higher than that')
        parser.add_argument('--learning_rate', '--lr', type=float, default=2e-5,
                            help='Learning rate for the Adam optimizer')
        parser.add_argument('--tb_writer_step', type=int, default=1000,
                            help='Number of steps before writing to tensorboard')
        parser.add_argument('--grad_accum_steps', type=int, default=1,
                            help='Number of steps to accumulate gradients before '
                                 'optimizing, advancing scheduler, etc.')
        parser.add_argument('--warmup_proportion', type=float, default=0.1,
                            help='Warmup proportion for linear scheduling')
        parser.add_argument('--config_name', type=str, default='config.json',
                            help='Filename to save BERT config as')
        parser.add_argument('--weights_name', type=str, default='pytorch_model.bin',
                            help='Filename to save model weights as')
        parser.set_defaults(func=TrainModelCommand())
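
For context, a minimal sketch of how a top-level CLI could register and dispatch this subcommand is shown below. Only TrainModelCommand and its import path come from the file above; the main() wrapper, program name, and usage string are assumptions for illustration, not code from this repository.

from argparse import ArgumentParser

from textattack.commands.train_model.train_model_command import TrainModelCommand


def main():
    # Assumed top-level parser; program name and usage are illustrative.
    parser = ArgumentParser(
        'TextAttack CLI', usage='[python -m] textattack <command> [<args>]'
    )
    subparsers = parser.add_subparsers(help='textattack command helpers')

    # register_subcommand attaches the 'train' subparser and, via
    # set_defaults, stores a TrainModelCommand instance as the handler.
    TrainModelCommand.register_subcommand(subparsers)

    args = parser.parse_args()
    args.func.run(args)


if __name__ == '__main__':
    main()

With this wiring, running `textattack train --dataset yelp` would parse the arguments defined above and dispatch to TrainModelCommand.run, which currently raises NotImplementedError.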