From 37c5495322eba631ad4b66ea6d1eb8d573ce4bc3 Mon Sep 17 00:00:00 2001
From: Nicolas1203
Date: Fri, 2 Aug 2019 17:20:43 +0200
Subject: [PATCH] add last epoch argument

---
 train_ssd.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/train_ssd.py b/train_ssd.py
index 43f6a050..4c8eb024 100644
--- a/train_ssd.py
+++ b/train_ssd.py
@@ -66,6 +66,8 @@
 parser.add_argument('--pretrained_ssd', help='Pre-trained base model')
 parser.add_argument('--resume', default=None, type=str,
                     help='Checkpoint state_dict file to resume training from')
+parser.add_argument('--last-epoch', default=-1, type=int, help='Epoch to start from when resuming training.')
+
 
 # Scheduler
 parser.add_argument('--scheduler', default="multi-step", type=str,
@@ -240,8 +242,7 @@ def test(loader, net, criterion, device):
     logging.info("Build network.")
     net = create_net(num_classes)
     min_loss = -10000.0
-    last_epoch = -1
-
+
     base_net_lr = args.base_net_lr if args.base_net_lr is not None else args.lr
     extra_layers_lr = args.extra_layers_lr if args.extra_layers_lr is not None else args.lr
     if args.freeze_base_net:
@@ -302,18 +303,18 @@ def test(loader, net, criterion, device):
     if args.scheduler == 'multi-step':
         logging.info("Uses MultiStepLR scheduler.")
         milestones = [int(v.strip()) for v in args.milestones.split(",")]
-        scheduler = MultiStepLR(optimizer, milestones=milestones,
-                                gamma=0.1, last_epoch=last_epoch)
+        scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
+
     elif args.scheduler == 'cosine':
         logging.info("Uses CosineAnnealingLR scheduler.")
-        scheduler = CosineAnnealingLR(optimizer, args.t_max, last_epoch=last_epoch)
+        scheduler = CosineAnnealingLR(optimizer, args.t_max)
     else:
         logging.fatal(f"Unsupported Scheduler: {args.scheduler}.")
         parser.print_help(sys.stderr)
         sys.exit(1)
 
-    logging.info(f"Start training from epoch {last_epoch + 1}.")
-    for epoch in range(last_epoch + 1, args.num_epochs):
+    logging.info(f"Start training from epoch {args.last_epoch + 1}.")
+    for epoch in range(args.last_epoch + 1, args.num_epochs):
         scheduler.step()
         train(train_loader, net, criterion, optimizer, device=DEVICE,
               debug_steps=args.debug_steps, epoch=epoch)
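
Usage sketch (not part of the patch): with --last-epoch wired into the epoch
loop, a resumed run might look like the following. The checkpoint path is
illustrative, and the dataset/network flags train_ssd.py also requires are
omitted; argparse maps --last-epoch to args.last_epoch, which is what the
loop reads. With the default of -1, behavior is unchanged and training
starts at epoch 0.

    # Resume from a checkpoint saved at the end of epoch 99 (path is hypothetical);
    # the loop then continues from args.last_epoch + 1 = 100.
    python train_ssd.py --resume models/ssd-epoch-99.pth --last-epoch 99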