Skip to content

Commit

Permalink
Remove broken envs
Browse files Browse the repository at this point in the history
  • Loading branch information
MadsSR committed Apr 11, 2024
1 parent c0ba829 commit f8cb888
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 258 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Pool Agent Reinforcement Learning

Gymnasium environments for 8-ball pool, using FastFiz to simulate the physics of the game.
Gymnasium environment for 8-ball pool, using FastFiz to simulate the physics of the game.

## Preqrequisites

The following packages are required to run the environment:
The package, `python3-opengl` is required to run the environment. Install it using the following command:

```
apt-get install python3-opengl
Expand All @@ -24,12 +24,11 @@ Use the environment for training a reinforcement learning agent:

```python
from stable_baselines3 import PPO
import fastfiz_env
from fastfiz_env.utils import DefaultReward
from fastfiz_env import DefaultReward, make

env = fastfiz_env.make("BaseRLFastFiz-v0", reward_function=DefaultReward, num_balls=2)
env = make("SimpleFastFiz-v0", reward_function=DefaultReward, num_balls=2)

model = PPO("MlpPolicy", env)

model.learn(total_timesteps=10_000)
model.learn(total_timesteps=100_000)
```
32 changes: 15 additions & 17 deletions src/fastfiz_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Gymnasium environments for pool, using FastFiz to simulate the physics of the game.
Avaliable environments:
- `BaseFastFiz-v0`: Base class for FastFiz.
- `BaseRLFastFiz-v0`: Base class for FastFiz with reinforcement learning, using initial random table state.
- `PocketRLFastFiz-v0`: Subclass of BaseRLFastFiz. Observes if a ball is pocketed.
- `SimpleFastFiz-v0`: Observes the position of the balls.
- `VelocityFastFiz-v0`: Observes the velocity of the balls.
- `TestingFastFiz-v0`: Observes the position of the balls. Used for testing purposes with options e.g. seed, logging, action_space_id.
### Example
Expand All @@ -17,7 +17,7 @@
from fastfiz_env.utils.reward_functions.common import StepPocketedReward
reward_function = StepPocketedReward()
env = fastfiz_env.make("BaseRLFastFiz-v0", reward_function=reward_function, num_balls=2)
env = fastfiz_env.make("SimpleFastFiz-v0", reward_function=reward_function, num_balls=2)
model = PPO("MlpPolicy", env, verbose=1)
Expand All @@ -28,24 +28,22 @@
"""

from .make import make
from .reward_functions import DefaultReward, RewardFunction, CombinedReward
from . import envs, utils, wrappers, reward_functions

__all__ = ["make", "envs", "utils", "wrappers", "reward_functions"]
__all__ = [
"make",
"DefaultReward",
"RewardFunction",
"CombinedReward",
"envs",
"utils",
"wrappers",
"reward_functions",
]

from gymnasium.envs.registration import register

register(
id="BaseFastFiz-v0",
entry_point="fastfiz_env.envs:BaseFastFiz",
additional_wrappers=(wrappers.MaxEpisodeStepsInjectionWrapper.wrapper_spec(),),
)

register(
id="BaseRLFastFiz-v0",
entry_point="fastfiz_env.envs:BaseRLFastFiz",
additional_wrappers=(wrappers.MaxEpisodeStepsInjectionWrapper.wrapper_spec(),),
)


register(
id="VelocityFastFiz-v0",
Expand Down
4 changes: 0 additions & 4 deletions src/fastfiz_env/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
"""

from . import utils
from .base_fastfiz import BaseFastFiz
from .base_rl_fastfiz import BaseRLFastFiz
from .velocity_fastfiz import VelocityFastFiz
from .simple_fastfiz import SimpleFastFiz
from .testing_fastfiz import TestingFastFiz

__all__ = [
"utils",
"BaseFastFiz",
"BaseRLFastFiz",
"VelocityFastFiz",
"SimpleFastFiz",
"TestingFastFiz",
Expand Down
118 changes: 0 additions & 118 deletions src/fastfiz_env/envs/base_fastfiz.py

This file was deleted.

100 changes: 0 additions & 100 deletions src/fastfiz_env/envs/base_rl_fastfiz.py

This file was deleted.

11 changes: 4 additions & 7 deletions src/fastfiz_env/envs/simple_fastfiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(
self.max_episode_steps = None

def _max_episode_steps(self):
if self.get_wrapper_attr("_time_limit_max_episode_steps") is not None:
if (
hasattr(SimpleFastFiz, "_time_limit_max_episode_steps")
and self.get_wrapper_attr("_time_limit_max_episode_steps") is not None
):
self.max_episode_steps = self.get_wrapper_attr(
"_time_limit_max_episode_steps"
)
Expand Down Expand Up @@ -74,12 +77,6 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]
truncated = False
info = self._get_info()

pocketed = num_balls_pocketed(self.table_state)
if pocketed > self._prev_pocketed:
self._prev_pocketed = pocketed
else:
reward = min(-1, reward - 0.3)

return observation, reward, terminated, truncated, info

def _get_observation(self):
Expand Down
12 changes: 6 additions & 6 deletions src/tests/envs/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import unittest
from fastfiz_env.envs import BaseRLFastFiz
from fastfiz_env.envs import SimpleFastFiz
from fastfiz_env.reward_functions.common import ConstantReward


class TestBaseRLFastFiz(unittest.TestCase):
class TestSimpleFastFiz(unittest.TestCase):
def test_init(self):
num_balls = 16
env = BaseRLFastFiz(num_balls=num_balls)
env = SimpleFastFiz(num_balls=num_balls)
self.assertEqual(env.observation_space.shape, (16, 2))
self.assertEqual(env.action_space.shape, (5,))
self.assertEqual(env.action_space.shape, (3,))

def test_reset(self):
num_balls = 16
env = BaseRLFastFiz(num_balls=num_balls)
env = SimpleFastFiz(num_balls=num_balls)
obs, info = env.reset()
self.assertEqual(obs.shape, (16, 2))
self.assertEqual(info, {"is_success": False})

def test_step(self):
num_balls = 16
env = BaseRLFastFiz(num_balls=num_balls, reward_function=ConstantReward())
env = SimpleFastFiz(num_balls=num_balls, reward_function=ConstantReward())
env.reset()
action = [0, 0, 60, 0, 0]
obs, reward, done, truncated, info = env.step(action)
Expand Down

0 comments on commit f8cb888

Please sign in to comment.