From c17989b05bd3fa996e51bf9b66ac8bc5856f2ad5 Mon Sep 17 00:00:00 2001 From: Mads Risager Date: Thu, 25 Apr 2024 11:55:40 +0200 Subject: [PATCH] Fix model name --- src/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index 121d5b5..aadc06f 100644 --- a/src/train.py +++ b/src/train.py @@ -27,8 +27,8 @@ def get_latest_run_id(log_path: str, name: str) -> int: return id -def get_model_name(env_name: str, balls: int, algo: str = "PPO") -> str: - return f"{env_name.split('FastFiz-v0')[0]}-{balls}_balls-{algo}".lower() +def get_model_name(env_name: str, balls: int, algo: str = "PPO", action_space_id=ActionSpaces.VECTOR_3D) -> str: + return f"{env_name.split('FastFiz-v0')[0]}-{balls}_balls-{action_space_id.name}-{algo}".lower() def train( @@ -52,7 +52,7 @@ def train( hyperparams = params_to_kwargs(**params) if params else {} print(hyperparams) - model_name = get_model_name(env_id, num_balls) + model_name = get_model_name(env_id, num_balls, action_space_id) if model_dir is None: model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logs_path, **hyperparams) @@ -156,7 +156,8 @@ def train( model_path: {model_path}\n\ logs_path: {logs_path}\n\ models_path: {models_path}\n\ - reward_function: {reward}\n" + reward_function: {reward}\n\ + action_space_id: {args.action_id}\n" ) train(