-
Notifications
You must be signed in to change notification settings - Fork 384
/
train.py
114 lines (94 loc) · 4.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import logging
import os
import torch
import torch.distributed as dist
from ssd.engine.inference import do_evaluation
from ssd.config import cfg
from ssd.data.build import make_data_loader
from ssd.engine.trainer import do_train
from ssd.modeling.detector import build_detection_model
from ssd.solver.build import make_optimizer, make_lr_scheduler
from ssd.utils import dist_util, mkdir
from ssd.utils.checkpoint import CheckPointer
from ssd.utils.dist_util import synchronize
from ssd.utils.logger import setup_logger
from ssd.utils.misc import str2bool
def train(cfg, args):
    """Build the detector and its training machinery, then run training.

    Args:
        cfg: Project config node. This function reads ``MODEL.DEVICE``,
            ``SOLVER.LR``, ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and
            ``OUTPUT_DIR`` from it and forwards it to the builders.
        args: Parsed command-line namespace. This function reads
            ``distributed``, ``local_rank`` and ``num_gpus``; the whole
            namespace is also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns (presumably the trained model --
        confirm against ``ssd.engine.trainer``).
    """
    trainer_logger = logging.getLogger('SSD.trainer')
    num_gpus = args.num_gpus

    # Build the network and place it on the configured device.
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    # The learning rate is multiplied by the process count, while the LR
    # milestones and the iteration budget are floor-divided by it.
    scaled_lr = cfg.SOLVER.LR * num_gpus  # scale by num gpus
    lr_milestones = [milestone // num_gpus for milestone in cfg.SOLVER.LR_STEPS]
    iteration_budget = cfg.SOLVER.MAX_ITER // num_gpus

    optimizer = make_optimizer(cfg, model, scaled_lr)
    scheduler = make_lr_scheduler(cfg, optimizer, lr_milestones)

    # Only the rank-0 process is flagged to write checkpoints to disk.
    rank_zero = dist_util.get_rank() == 0
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR, rank_zero, trainer_logger)

    # Start at iteration 0; anything returned by checkpointer.load()
    # overrides these defaults (presumably the state of a previously saved
    # run -- confirm against CheckPointer).
    arguments = {"iteration": 0}
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=iteration_budget,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, scheduler, checkpointer, device, arguments, args)
def _build_arg_parser():
    """Create the command-line parser for the SSD training script."""
    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training With PyTorch')
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # The local rank can arrive three ways depending on the launcher:
    #   * ``--local_rank`` (torch.distributed.launch before PyTorch 2.0),
    #   * ``--local-rank`` (torch.distributed.launch from PyTorch 2.0 on),
    #   * the ``LOCAL_RANK`` environment variable (torchrun passes no flag).
    # Accept both spellings and fall back to the environment so that each
    # worker binds to its own GPU instead of all defaulting to rank 0.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    parser.add_argument('--log_step', default=10, type=int, help='Print logs every log_step')
    parser.add_argument('--save_step', default=2500, type=int, help='Save checkpoint every save_step')
    parser.add_argument('--eval_step', default=2500, type=int, help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--use_tensorboard', default=True, type=str2bool)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


def main():
    """Parse the command line, configure the run, train, and optionally evaluate.

    Steps, in order:
      1. Parse arguments and derive ``args.num_gpus`` / ``args.distributed``
         from the ``WORLD_SIZE`` environment variable (1 process when unset).
      2. Enable the cuDNN auto-tuner when CUDA is available and, for
         distributed runs, select the local GPU and join the NCCL process
         group.
      3. Merge the config file and command-line overrides into the global
         ``cfg`` and freeze it.
      4. Create the output directory (when one is configured), set up
         logging and log the effective configuration.
      5. Run ``train`` and, unless ``--skip-test`` was given, ``do_evaluation``.
    """
    args = _build_arg_parser().parse_args()

    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = num_gpus > 1
    args.num_gpus = num_gpus

    if torch.cuda.is_available():
        # This flag allows you to enable the inbuilt cudnn auto-tuner to
        # find the best algorithm to use for your hardware.
        torch.backends.cudnn.benchmark = True
    if args.distributed:
        # Bind this process to its own GPU before creating the process group.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Command-line overrides are merged after the file, then the config is
    # frozen.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if cfg.OUTPUT_DIR:
        mkdir(cfg.OUTPUT_DIR)

    logger = setup_logger("SSD", dist_util.get_rank(), cfg.OUTPUT_DIR)
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    # Log both the raw config file and the fully merged configuration.
    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
    logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args)

    if not args.skip_test:
        logger.info('Start evaluating...')
        torch.cuda.empty_cache()  # speed up evaluating after training finished
        do_evaluation(cfg, model, distributed=args.distributed)
# Run the training CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()