Skip to content

Commit

Permalink
Fix model name
Browse files Browse the repository at this point in the history
  • Loading branch information
MadsSR committed Apr 25, 2024
1 parent ee54e63 commit c17989b
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def get_latest_run_id(log_path: str, name: str) -> int:
return id


def get_model_name(env_name: str, balls: int, algo: str = "PPO") -> str:
return f"{env_name.split('FastFiz-v0')[0]}-{balls}_balls-{algo}".lower()
def get_model_name(env_name: str, balls: int, algo: str = "PPO", action_space_id=ActionSpaces.VECTOR_3D) -> str:
return f"{env_name.split('FastFiz-v0')[0]}-{balls}_balls-{action_space_id.name}-{algo}".lower()


def train(
Expand All @@ -52,7 +52,7 @@ def train(

hyperparams = params_to_kwargs(**params) if params else {}
print(hyperparams)
model_name = get_model_name(env_id, num_balls)
model_name = get_model_name(env_id, num_balls, action_space_id)

if model_dir is None:
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logs_path, **hyperparams)
Expand Down Expand Up @@ -156,7 +156,8 @@ def train(
model_path: {model_path}\n\
logs_path: {logs_path}\n\
models_path: {models_path}\n\
reward_function: {reward}\n"
reward_function: {reward}\n\
action_space_id: {args.action_id}\n"
)

train(
Expand Down

0 comments on commit c17989b

Please sign in to comment.