diff --git a/src/fastfiz_env/envs/pockets_fastfiz.py b/src/fastfiz_env/envs/pockets_fastfiz.py
index 6b4dcf9..8dfb519 100644
--- a/src/fastfiz_env/envs/pockets_fastfiz.py
+++ b/src/fastfiz_env/envs/pockets_fastfiz.py
@@ -80,10 +80,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]
         shot_params = ff.ShotParams(*action)
 
         if self._possible_shot(shot_params):
-            try:
-                self.table_state.executeShot(shot_params)
-            except Exception:
-                pass
+            self.table_state.executeShot(shot_params, verbose=True)
 
         observation = self._get_observation()
         reward = self.reward.get_reward(prev_table_state, self.table_state, action)
@@ -106,11 +103,11 @@ def compute_observation(cls, table_state: ff.TableState) -> np.ndarray:
             if ball.isPocketed():
                 pocket = ball_state_to_pocket(ball.getState())
                 pocket_pos = get_pocket_center(pocket)
-                observation[i] = [*pocket_pos, 0]
+                observation[i] = [*pocket_pos, 1]
             elif ball.isInPlay():
-                observation[i] = [*ball_pos, 1]
+                observation[i] = [*ball_pos, 0]
             else:
-                observation[i] = [0, 0, 0]
+                observation[i] = [0, 0, 1]
 
         return np.array(observation)
 
diff --git a/src/fastfiz_env/envs/simple_fastfiz.py b/src/fastfiz_env/envs/simple_fastfiz.py
index 5879766..bc628c4 100644
--- a/src/fastfiz_env/envs/simple_fastfiz.py
+++ b/src/fastfiz_env/envs/simple_fastfiz.py
@@ -76,10 +76,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]
         shot_params = ff.ShotParams(*action)
 
         if self._possible_shot(shot_params):
-            try:
-                self.table_state.executeShot(shot_params)
-            except Exception:
-                pass
+            self.table_state.executeShot(shot_params)
 
         observation = self._get_observation()
         reward = self.reward.get_reward(prev_table_state, self.table_state, action)
diff --git a/src/fastfiz_env/wrappers/action.py b/src/fastfiz_env/wrappers/action.py
index b1fe6d9..3ddb978 100644
--- a/src/fastfiz_env/wrappers/action.py
+++ b/src/fastfiz_env/wrappers/action.py
@@ -72,8 +72,8 @@ class FastFizActionWrapper(ActionWrapper):
     MAX_PHI = 360 - 0.001
     MIN_VELOCITY = 0
     MAX_VELOCITY = 10 - 0.001
-    MIN_OFFSET = -28
-    MAX_OFFSET = 28
+    MIN_OFFSET = -28 + 0.001
+    MAX_OFFSET = 28 - 0.001
     SPACES = {
         "VECTOR_2D": spaces.Box(
             low=np.array([-1, -1]),
diff --git a/src/train.py b/src/train.py
index 6405890..ddc352d 100644
--- a/src/train.py
+++ b/src/train.py
@@ -50,6 +50,8 @@ def train(
         n_envs=n_envs,
     )
 
+    # env = VecCheckNan(env, raise_exception=True)
+
     hyperparams = params_to_kwargs(**params) if params else {}
     print(hyperparams)
     model_name = get_model_name(env_id, num_balls, action_space_id=action_space_id)
@@ -97,8 +99,8 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("--env", type=str, required=True)
-    parser.add_argument("-b", "--num_balls", type=int, required=True)
+    parser.add_argument("--env", type=str, default="PocketsFastFiz-v0")
+    parser.add_argument("-b", "--num_balls", type=int, default=2)
     parser.add_argument("-m", "--max_episode_steps", type=int, default=20)
    parser.add_argument("-n", "--n_time_steps", type=int, default=1_000_000)
     parser.add_argument(
@@ -125,7 +127,7 @@
     )
 
     parser.add_argument(
-        "-a", "--action_id", type=lambda a: ActionSpaces[a], choices=list(ActionSpaces), default=ActionSpaces.VECTOR_3D
+        "-a", "--action_id", type=lambda a: ActionSpaces[a], choices=list(ActionSpaces), default=ActionSpaces.OFFSET_NORM_5D
     )
 
     args = parser.parse_args()