-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_autoencoder.py
56 lines (45 loc) · 1.52 KB
/
main_autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import importlib
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from utils import generate_path, load_yaml, save_yaml
def _instantiate(spec: dict) -> object:
    """Import ``spec["module"]`` and call its attribute ``spec["name"]``.

    The looked-up attribute (typically a class) is called with
    ``spec["kwargs"]`` as keyword arguments. A missing or empty ``kwargs``
    entry means "call with no arguments".

    Args:
        spec: Mapping with keys ``module`` (dotted import path), ``name``
            (attribute looked up on that module) and optionally ``kwargs``.

    Returns:
        Whatever the looked-up callable returns.

    Raises:
        ModuleNotFoundError: if ``spec["module"]`` cannot be imported.
        AttributeError: if the module has no attribute ``spec["name"]``.
        KeyError: if ``module`` or ``name`` is absent from ``spec``.
    """
    module = importlib.import_module(spec["module"])
    factory = getattr(module, spec["name"])
    # A blank ``kwargs:`` entry presumably loads as None (depends on
    # load_yaml); treat it, like a missing entry, as "no keyword arguments".
    return factory(**(spec.get("kwargs") or {}))


def main() -> None:
    """Train a model described by a YAML experiment configuration.

    Command-line arguments:
        -sf/--servers_file: YAML file mapping server names to their settings
            (default ``config/servers.yml``).
        -s/--server: Key into the servers file (default ``local``).
        -c/--config: Configuration file name, resolved relative to the
            selected server's ``config.location`` directory.

    The experiment config must provide ``model`` and ``data`` sections (each
    with ``module``, ``name`` and optional ``kwargs``) and an ``experiment``
    name used for the TensorBoard logger. It may provide a ``trainer``
    section whose entries are forwarded to ``pl.Trainer``. A copy of the
    config is written to ``config.yml`` in the logger's ``log_dir`` before
    training starts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-sf",
        "--servers_file",
        type=str,
        default="config/servers.yml",
        help="Servers file.",
    )
    parser.add_argument(
        "-s", "--server", type=str, default="local", help="Specify server."
    )
    parser.add_argument(
        "-c", "--config", type=str, required=True, help="Path to configuration file."
    )
    args = parser.parse_args()

    # Per-machine settings: where configs live and where logs are written.
    servers = load_yaml(args.servers_file)
    if args.server not in servers:
        # Fail with a readable CLI error instead of a bare KeyError traceback.
        parser.error(
            f"unknown server {args.server!r} in {args.servers_file}; "
            f"available: {', '.join(map(str, servers))}"
        )
    server = servers[args.server]

    # Experiment config is looked up relative to the server's config location.
    config_path = server["config"]["location"]
    configs = load_yaml(f"{config_path}/{args.config}")

    model = _instantiate(configs["model"])
    data_module = _instantiate(configs["data"])

    # Output: TensorBoard logs under the server's logging location, grouped
    # by experiment name.
    logger = TensorBoardLogger(
        save_dir=server["logging"]["location"], name=configs["experiment"]
    )

    # Keep a copy of the exact config next to the logs for reproducibility.
    # generate_path presumably creates the directory if missing — confirm in utils.
    generate_path(logger.log_dir)
    save_yaml(f"{logger.log_dir}/config.yml", configs)

    # A missing or blank ``trainer`` section means default Trainer settings.
    trainer = pl.Trainer(**(configs.get("trainer") or {}), logger=logger)
    trainer.fit(model, data_module)
    # Evaluation on the test split is currently disabled:
    # trainer.test(model, data_module)


if __name__ == "__main__":
    main()