
Commit

Add env options arg
MadsSR committed Apr 18, 2024
1 parent fe87f5a commit e537e54
Showing 2 changed files with 47 additions and 3 deletions.
src/fastfiz_env/make.py (13 changes: 11 additions & 2 deletions)
@@ -1,3 +1,4 @@
+from typing import Optional
 from gymnasium.envs.registration import EnvSpec
 import gymnasium as gym

@@ -39,7 +40,11 @@ def make(


 def make_wrapped_env(
-    env_id: str, num_balls: int, max_episode_steps: int, reward_function: RewardFunction
+    env_id: str,
+    num_balls: int,
+    max_episode_steps: int,
+    reward_function: RewardFunction,
+    **kwargs,
 ):
     """
     Create an instance of the specified environment with the FastFizActionWrapper.
@@ -50,6 +55,7 @@ def make_wrapped_env(
         num_balls=num_balls,
         max_episode_steps=max_episode_steps,
         disable_env_checker=False,
+        **kwargs,
     )
     env = FastFizActionWrapper(env, action_space_id=ActionSpaces.NO_OFFSET_3D)
     return env
@@ -60,13 +66,16 @@ def make_callable_wrapped_env(
     num_balls: int,
     max_episode_steps: int,
     reward_function: RewardFunction,
+    **kwargs,
 ):
     """
     Create a callable function that returns an instance of the specified environment with the FastFizActionWrapper.
     This is useful for creating environments in parallel or with stable-baselines `make_vec_env` function.
     """

     def _init() -> gym.Env:
-        return make_wrapped_env(env_id, num_balls, max_episode_steps, reward_function)
+        return make_wrapped_env(
+            env_id, num_balls, max_episode_steps, reward_function, **kwargs
+        )

     return _init
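
For context, here is how the new **kwargs path is meant to be used, for example with stable-baselines3's make_vec_env as mentioned in the docstring above. This is a sketch, not part of the commit: the env id, the import paths, and whether the underlying environment accepts an `options` keyword are assumptions.

# Sketch only (assumed imports and env id, hypothetical option values).
from stable_baselines3.common.env_util import make_vec_env

from fastfiz_env.make import make_callable_wrapped_env  # module shown in this diff
from fastfiz_env.reward_functions import DefaultReward  # import path assumed

env_fn = make_callable_wrapped_env(
    "SimpleFastFiz-v0",  # hypothetical env id
    num_balls=2,
    max_episode_steps=100,
    reward_function=DefaultReward,
    options={"seed": 42},  # forwarded to gym.make(...) via the new **kwargs
)
vec_env = make_vec_env(env_fn, n_envs=4)
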
src/optimize.py (37 changes: 36 additions & 1 deletion)
@@ -135,13 +135,14 @@ def objective(
     n_timesteps: int,
     start_time: str,
     no_logs: bool,
+    env_kwargs: dict,
 ) -> float:
     kwargs = sample_ppo_params(trial)
     N_ENVS = 4

     env = make_vec_env(
         make_callable_wrapped_env(
-            env_id, num_balls, max_episode_steps, reward_function
+            env_id, num_balls, max_episode_steps, reward_function, **env_kwargs
         ),
         n_envs=N_ENVS,
     )
@@ -206,6 +207,28 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:
    )

+
+class StoreDict(argparse.Action):
+    """
+    Custom argparse action for storing dict.
+    In: args1:0.0 args2:"dict(a=1)"
+    Out: {'args1': 0.0, arg2: dict(a=1)}
+    """
+
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        self._nargs = nargs
+        super().__init__(option_strings, dest, nargs=nargs, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        arg_dict = {}
+        for arguments in values:  # type: ignore
+            key = arguments.split(":")[0]
+            value = ":".join(arguments.split(":")[1:])
+            # Evaluate the string as python code
+            arg_dict[key] = eval(value)
+        setattr(namespace, self.dest, arg_dict)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Description of your program")
     parser.add_argument("--n_trials", type=int, default=20, help="Number of trials")
@@ -249,6 +272,15 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:
         "--no-logs", action="store_true", help="Disable Tensorboard logging"
     )

+    parser.add_argument(
+        "--env-options",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional keyword argument to pass to the env constructor",
+        default={},
+    )
+
     args = parser.parse_args()

     # Set pytorch num threads to 1 for faster training.
@@ -265,6 +297,8 @@ def save_trial(trial: optuna.trial.FrozenTrial, path: str) -> None:

     reward_function = DefaultReward if args.reward == "DefaultReward" else WinningReward

+    env_kwargs = {"options": args.env_options}
+
     def obj_fn(trial):
         return objective(
             trial,
@@ -277,6 +311,7 @@ def obj_fn(trial):
             args.n_timesteps,
             start_time,
             args.no_logs,
+            env_kwargs,
         )

     try:
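
Taken together, running the optimizer with something like `--env-options seed:42` (the option key here is hypothetical) makes StoreDict produce args.env_options == {"seed": 42}; the __main__ block wraps that as env_kwargs = {"options": {"seed": 42}}, objective() forwards it through make_callable_wrapped_env, and gym.make receives options={"seed": 42} for each of the four vectorized environments.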
