Skip to content

Commit

Permalink
Fix Poolfiz errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MadsSR committed May 13, 2024
1 parent 3c8e941 commit 4abfc75
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 16 deletions.
11 changes: 4 additions & 7 deletions src/fastfiz_env/envs/pockets_fastfiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]
shot_params = ff.ShotParams(*action)

if self._possible_shot(shot_params):
try:
self.table_state.executeShot(shot_params)
except Exception:
pass
self.table_state.executeShot(shot_params, verbose=True)

observation = self._get_observation()
reward = self.reward.get_reward(prev_table_state, self.table_state, action)
Expand All @@ -106,11 +103,11 @@ def compute_observation(cls, table_state: ff.TableState) -> np.ndarray:
if ball.isPocketed():
pocket = ball_state_to_pocket(ball.getState())
pocket_pos = get_pocket_center(pocket)
observation[i] = [*pocket_pos, 0]
observation[i] = [*pocket_pos, 1]
elif ball.isInPlay():
observation[i] = [*ball_pos, 1]
observation[i] = [*ball_pos, 0]
else:
observation[i] = [0, 0, 0]
observation[i] = [0, 0, 1]

return np.array(observation)

Expand Down
5 changes: 1 addition & 4 deletions src/fastfiz_env/envs/simple_fastfiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]
shot_params = ff.ShotParams(*action)

if self._possible_shot(shot_params):
try:
self.table_state.executeShot(shot_params)
except Exception:
pass
self.table_state.executeShot(shot_params)

observation = self._get_observation()
reward = self.reward.get_reward(prev_table_state, self.table_state, action)
Expand Down
4 changes: 2 additions & 2 deletions src/fastfiz_env/wrappers/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class FastFizActionWrapper(ActionWrapper):
MAX_PHI = 360 - 0.001
MIN_VELOCITY = 0
MAX_VELOCITY = 10 - 0.001
MIN_OFFSET = -28
MAX_OFFSET = 28
MIN_OFFSET = -28 + 0.001
MAX_OFFSET = 28 - 0.001
SPACES = {
"VECTOR_2D": spaces.Box(
low=np.array([-1, -1]),
Expand Down
8 changes: 5 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def train(
n_envs=n_envs,
)

# env = VecCheckNan(env, raise_exception=True)

hyperparams = params_to_kwargs(**params) if params else {}
print(hyperparams)
model_name = get_model_name(env_id, num_balls, action_space_id=action_space_id)
Expand Down Expand Up @@ -97,8 +99,8 @@ def train(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, required=True)
parser.add_argument("-b", "--num_balls", type=int, required=True)
parser.add_argument("--env", type=str, default="PocketsFastFiz-v0")
parser.add_argument("-b", "--num_balls", type=int, default=2)
parser.add_argument("-m", "--max_episode_steps", type=int, default=20)
parser.add_argument("-n", "--n_time_steps", type=int, default=1_000_000)
parser.add_argument(
Expand All @@ -125,7 +127,7 @@ def train(
)

parser.add_argument(
"-a", "--action_id", type=lambda a: ActionSpaces[a], choices=list(ActionSpaces), default=ActionSpaces.VECTOR_3D
"-a", "--action_id", type=lambda a: ActionSpaces[a], choices=list(ActionSpaces), default=ActionSpaces.OFFSET_NORM_5D
)

args = parser.parse_args()
Expand Down

0 comments on commit 4abfc75

Please sign in to comment.