-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun.py
36 lines (25 loc) · 949 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import os
# Set before jax is imported so the setting is in place when JAX's GPU backend
# initializes: disables XLA's up-front preallocation of most GPU memory, so memory
# is allocated on demand instead (see JAX "GPU memory allocation" docs).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from jax import config
# Enable 64-bit (double precision) arrays in JAX; by default JAX uses 32-bit types.
config.update("jax_enable_x64", True)
# Debugging toggle: uncomment to run without jit compilation (slower, but easier to step through).
# config.update("jax_disable_jit", True)
import yaml
from adept import ergoExo
from adept.utils import export_run
def resolve_cfg_path(cfg):
    """Return the path of the YAML config file named by ``--cfg``.

    The path is interpreted relative to the current working directory (an
    absolute ``cfg`` is returned unchanged by ``os.path.join``). A ``.yaml``
    suffix is appended only when ``cfg`` does not already end in ``.yaml`` or
    ``.yml``, so both ``--cfg configs/run`` and ``--cfg configs/run.yaml`` work.

    Args:
        cfg: Config path as given on the command line, with or without suffix.

    Returns:
        The resolved path string to open.
    """
    path = os.path.join(os.getcwd(), cfg)
    if not path.endswith((".yaml", ".yml")):
        # Historical behaviour: the CLI accepts the config name without its suffix.
        path = f"{path}.yaml"
    return path


# CLI entry point: either start a new run from a YAML config (--cfg) or
# continue an existing run (--run_id). If MLFLOW_EXPORT is set in the
# environment, the run is exported afterwards via export_run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Automatic Differentiation Enabled Plasma Transport")
    parser.add_argument("--cfg", help="enter path to cfg")
    parser.add_argument("--run_id", help="enter run_id to continue")
    args = parser.parse_args()

    # Fail fast with a usage message instead of a TypeError from os.path.join(None),
    # and before constructing ergoExo (which may have its own side effects).
    if args.run_id is None and args.cfg is None:
        parser.error("one of --cfg or --run_id is required")

    exo = ergoExo()
    if args.run_id is None:
        # New run: load the config and let ergoExo set up and execute the modules.
        with open(resolve_cfg_path(args.cfg), "r", encoding="utf-8") as fi:
            cfg = yaml.safe_load(fi)
        modules = exo.setup(cfg=cfg)
        _sol, _post_out, run_id = exo(modules)
    else:
        # Continue an existing run; --run_id takes precedence over --cfg if both are given.
        exo.run_job(args.run_id, nested=None)
        run_id = args.run_id

    if "MLFLOW_EXPORT" in os.environ:
        export_run(run_id)