diff --git a/src/fastfiz_env/make.py b/src/fastfiz_env/make.py
index 9b13b4e..14d315b 100644
--- a/src/fastfiz_env/make.py
+++ b/src/fastfiz_env/make.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from gymnasium.envs.registration import EnvSpec
 import gymnasium as gym
 
@@ -39,7 +40,11 @@ def make(
 
 
 def make_wrapped_env(
-    env_id: str, num_balls: int, max_episode_steps: int, reward_function: RewardFunction
+    env_id: str,
+    num_balls: int,
+    max_episode_steps: int,
+    reward_function: RewardFunction,
+    **kwargs,
 ):
     """
     Create an instance of the specified environment with the FastFizActionWrapper.
@@ -50,6 +55,7 @@ def make_wrapped_env(
         num_balls=num_balls,
         max_episode_steps=max_episode_steps,
         disable_env_checker=False,
+        **kwargs,
     )
     env = FastFizActionWrapper(env, action_space_id=ActionSpaces.NO_OFFSET_3D)
     return env
@@ -60,6 +66,7 @@ def make_callable_wrapped_env(
     num_balls: int,
     max_episode_steps: int,
     reward_function: RewardFunction,
+    **kwargs,
 ):
     """
     Create a callable function that returns an instance of the specified environment with the FastFizActionWrapper.
@@ -67,6 +74,8 @@ def make_callable_wrapped_env(
     """
 
     def _init() -> gym.Env:
-        return make_wrapped_env(env_id, num_balls, max_episode_steps, reward_function)
+        return make_wrapped_env(
+            env_id, num_balls, max_episode_steps, reward_function, **kwargs
+        )
 
     return _init
diff --git a/src/optimize.py b/src/optimize.py
index 222728b..0de08dc 100644
--- a/src/optimize.py
+++ b/src/optimize.py
@@ -135,13 +135,14 @@ def objective(
     n_timesteps: int,
     start_time: str,
     no_logs: bool,
+    env_kwargs: dict,
 ) -> float:
     kwargs = sample_ppo_params(trial)
 
     N_ENVS = 4
     env = make_vec_env(
         make_callable_wrapped_env(
-            env_id, num_balls, max_episode_steps, reward_function
+            env_id, num_balls, max_episode_steps, reward_function, **env_kwargs
         ),
         n_envs=N_ENVS,
     )
@@ -206,6 +207,28 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:
     )
 
 
+class StoreDict(argparse.Action):
+    """
+    Custom argparse action for storing dict.
+
+    In: args1:0.0 args2:"dict(a=1)"
+    Out: {'args1': 0.0, 'args2': dict(a=1)}
+    """
+
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        self._nargs = nargs
+        super().__init__(option_strings, dest, nargs=nargs, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        arg_dict = {}
+        for arguments in values:  # type: ignore
+            key = arguments.split(":")[0]
+            value = ":".join(arguments.split(":")[1:])
+            # Evaluate the string as python code
+            arg_dict[key] = eval(value)
+        setattr(namespace, self.dest, arg_dict)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Description of your program")
     parser.add_argument("--n_trials", type=int, default=20, help="Number of trials")
@@ -249,6 +272,15 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:
         "--no-logs", action="store_true", help="Disable Tensorboard logging"
     )
 
+    parser.add_argument(
+        "--env-options",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional keyword argument to pass to the env constructor",
+        default={},
+    )
+
     args = parser.parse_args()
 
     # Set pytorch num threads to 1 for faster training.
@@ -265,6 +297,8 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:
 
     reward_function = DefaultReward if args.reward == "DefaultReward" else WinningReward
 
+    env_kwargs = {"options": args.env_options}
+
    def obj_fn(trial):
         return objective(
             trial,
@@ -277,6 +311,7 @@ def obj_fn(trial):
             args.n_timesteps,
             start_time,
             args.no_logs,
+            env_kwargs,
         )
 
     try:
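
For reference, a minimal sketch of how the new `--env-options` flag is consumed by the `StoreDict` action added above. The import path `optimize` and the option keys (`seed`, `render_mode`) are illustrative assumptions, not values taken from the project:

```python
import argparse

from optimize import StoreDict  # hypothetical import path for the class defined above

parser = argparse.ArgumentParser()
parser.add_argument(
    "--env-options",
    type=str,
    nargs="+",
    action=StoreDict,
    default={},
)

# Each token is split on the first ":"; the remainder is evaluated as Python.
args = parser.parse_args(["--env-options", "seed:42", "render_mode:'human'"])
print(args.env_options)  # {'seed': 42, 'render_mode': 'human'}
```

Because the value part goes through `eval`, non-string literals (ints, floats, dicts) can be passed directly on the command line, mirroring the `args2:"dict(a=1)"` example in the class docstring. The resulting dict is then forwarded to the env constructor via `env_kwargs = {"options": args.env_options}` and the `**kwargs` plumbing in `make_wrapped_env`.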