-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlaunch.py
51 lines (45 loc) · 1.69 KB
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import argparse
import json
import os
import sys
import warnings
from pointrix.utils.config import load_config
from pointrix.engine.default_trainer import DefaultTrainer
from pointrix.logger.writer import logproject, Logger
from model import FDS_camera, FDSGaussianModel
from hook import NormalLogHook
from dataset import MushRoomDataset
from exporter import DepthMetricExporter
def main(args, extras) -> None:
    """Load the experiment config, then train and/or evaluate with DefaultTrainer.

    Args:
        args: Known CLI arguments from ``parse_known_args``; ``args.config`` is
            the path of the config file to load.
        extras: Remaining (unrecognised) CLI tokens, forwarded to
            ``load_config`` as ``cli_args`` (presumably config overrides —
            confirm against pointrix's ``load_config``).

    When ``cfg.trainer.training`` is truthy the trainer runs its training loop,
    saves ``<exp_dir>/chkpnt<global_step>.pth`` and then runs ``test()``.
    Otherwise it only runs ``test(cfg.trainer.test_model_path)``.

    Any exception raised while building or running the trainer is printed via
    ``Logger.print_exception`` and reported to every trainer hook through
    ``hook.exception()``. It is not re-raised.
    """
    warnings.filterwarnings("ignore")
    cfg = load_config(args.config, cli_args=extras)

    # Record this project's 'py'/'yaml' files under <exp_dir>/project_file
    # (presumably a source snapshot for reproducibility — see pointrix logproject).
    project_path = os.path.dirname(os.path.abspath(__file__))
    logproject(project_path, os.path.join(cfg.exp_dir, 'project_file'), ['py', 'yaml'])

    # Bound before the try so the handler can tell whether construction itself
    # failed. Without this, an error inside DefaultTrainer(...) would surface as
    # a NameError on `gaussian_trainer` and hide the real failure.
    gaussian_trainer = None
    try:
        gaussian_trainer = DefaultTrainer(
            cfg.trainer,
            cfg.exp_dir,
            cfg.name
        )
        if cfg.trainer.training:
            gaussian_trainer.train_loop()
            # Checkpoint file is named after the final global step.
            model_path = os.path.join(
                cfg.exp_dir,
                f"chkpnt{gaussian_trainer.global_step}.pth"
            )
            gaussian_trainer.save_model(path=model_path)
            gaussian_trainer.test()
        else:
            # Evaluation-only run against a previously saved model.
            gaussian_trainer.test(cfg.trainer.test_model_path)
        Logger.print("\nTraining complete.")
    except BaseException:
        # As broad as the original bare `except:` on purpose, so hooks are also
        # notified on KeyboardInterrupt / SystemExit.
        Logger.print_exception(show_locals=False)
        # Hooks only exist if the trainer was constructed successfully.
        if gaussian_trainer is not None:
            for hook in gaussian_trainer.hooks:
                hook.exception()
        # NOTE(review): the exception is swallowed here, so the script exits
        # with status 0 even after a failure. Consider sys.exit(1) if CI or
        # pipelines need to detect failed runs.
if __name__ == "__main__":
    # Script entry point. Only --config (required) and --smc_file are declared.
    # Every other CLI token is left unparsed and handed to main() as `extras`.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--config", help="path to config file", required=True)
    cli_parser.add_argument("--smc_file", default=None, type=str)
    known_args, unknown_args = cli_parser.parse_known_args()
    main(known_args, unknown_args)