From 6fd0e11144a6abe27c47c2db492a695e99e942a5 Mon Sep 17 00:00:00 2001 From: AlexanderManich Date: Mon, 15 Apr 2024 09:49:46 +0200 Subject: [PATCH] added max_episode_steps to satisfy reward implementation --- src/fastfiz_env/envs/velocity_fastfiz.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/fastfiz_env/envs/velocity_fastfiz.py b/src/fastfiz_env/envs/velocity_fastfiz.py index d16ca64..fce1ca6 100644 --- a/src/fastfiz_env/envs/velocity_fastfiz.py +++ b/src/fastfiz_env/envs/velocity_fastfiz.py @@ -36,6 +36,17 @@ def __init__( self.table_state = create_table_state(self.num_balls) self.observation_space = self._observation_space() self.action_space = self._action_space() + self.max_episode_steps = None + + def _max_episode_steps(self): + if ( + #hasattr(SimpleFastFiz, "_time_limit_max_episode_steps") + self.get_wrapper_attr("_time_limit_max_episode_steps") is not None + ): + self.max_episode_steps = self.get_wrapper_attr( + "_time_limit_max_episode_steps" + ) + self.reward.max_episode_steps = self.max_episode_steps def reset( self, *, seed: Optional[int] = None, options: Optional[dict] = None @@ -45,6 +56,9 @@ def reset( """ super().reset(seed=seed) + if self.max_episode_steps is None: + self._max_episode_steps() + self.table_state = create_table_state(self.num_balls) self.reward.reset(self.table_state) observation = self._get_observation(self.table_state, [])