Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure, cleanup and linting #9

Merged
merged 44 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
b12ea88
Save best trial kwargs for PPO model
MadsSR Apr 18, 2024
d8b72d7
Remove sb3 dependency from fastfiz-env package
MadsSR Apr 18, 2024
2f1600a
Add params arg to set hyperparams
MadsSR Apr 18, 2024
e38ada3
Add make_callable_wrapped_env to __all__
MadsSR Apr 18, 2024
5708677
Add params_to_kwargs() function
MadsSR Apr 18, 2024
dd680d2
Add params_to_kwargs() function
MadsSR Apr 18, 2024
c1658b6
Add pyproject.toml
MadsSR Apr 18, 2024
fe87f5a
Add options arg to envs
MadsSR Apr 18, 2024
e537e54
Add env options arg
MadsSR Apr 18, 2024
bb9a25c
Fix assert value
MadsSR Apr 18, 2024
368e7a0
Remove p5 import
MadsSR Apr 18, 2024
fb98c96
Add swig and gsl to dependencies
MadsSR Apr 18, 2024
bcf7a07
Format with Ruff
MadsSR Apr 18, 2024
1a5d2e0
Lint with Ruff
MadsSR Apr 18, 2024
109eee5
Fix ruff checks
MadsSR Apr 18, 2024
2d3f7fa
Fix coordinate calculations
MadsSR Apr 18, 2024
5744e74
Remove TestingFastFiz env
MadsSR Apr 18, 2024
dba9b07
Rename action space
MadsSR Apr 18, 2024
d766fba
Rename action space
MadsSR Apr 18, 2024
da7b094
Fix type hints
MadsSR Apr 18, 2024
d09391f
Fix type hints
MadsSR Apr 18, 2024
e423963
Check binary reward function instance
MadsSR Apr 18, 2024
8971d0a
Remove type hint
MadsSR Apr 18, 2024
d8d53a5
Fix type hints
MadsSR Apr 18, 2024
5a4d634
Fix type hints
MadsSR Apr 18, 2024
35e5598
Format with Ruff
MadsSR Apr 18, 2024
1558c51
Rename test
MadsSR Apr 18, 2024
6e627a2
Remove unsued utils
MadsSR Apr 18, 2024
f57da5a
Fix workflow
MadsSR Apr 18, 2024
b2d3ae3
Fix workflow
MadsSR Apr 18, 2024
59ebf1a
Remove requirements.txt
MadsSR Apr 18, 2024
72fdd2c
Remove .vscode
MadsSR Apr 18, 2024
f418e1e
Remove unused
MadsSR Apr 18, 2024
9bde5bd
Fix cart2sph
MadsSR Apr 23, 2024
1baf5eb
Add plot script
MadsSR Apr 23, 2024
da36491
Remove velocity from reward
MadsSR Apr 24, 2024
26b4cf2
Fix calc and conversion
MadsSR Apr 24, 2024
ee54e63
Add action_space_id option
MadsSR Apr 25, 2024
c17989b
Fix model name
MadsSR Apr 25, 2024
20b71da
Add script to log random policy evaluation metrics
MadsSR Apr 30, 2024
8dd0213
Reset with seed
MadsSR Apr 30, 2024
f9fd52f
Remove commit versions
MadsSR Apr 30, 2024
d29ed78
Setup Latex plots
MadsSR Apr 30, 2024
c012f8b
Fix already defined
MadsSR Apr 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Python package
on:
push:
pull_request:
branches: ['main']

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.10', '3.11']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
sudo apt update
sudo apt install python3-opengl swig libgsl-dev
python -m pip install --upgrade pip
pip install ".[test]"

- name: Lint with Ruff
run: |
ruff check src/fastfiz_env
ruff format src/fastfiz_env

- name: Run MyPy
run: |
mypy src/fastfiz_env

- name: Test with pytest
run: |
pytest
40 changes: 0 additions & 40 deletions .github/workflows/python-package.yml

This file was deleted.

4 changes: 0 additions & 4 deletions .vscode/settings.json

This file was deleted.

57 changes: 57 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[project]
name = "fastfiz-env"
description = "Gymnasium environments for FastFiz pool simulator."
readme = "README.md"
requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
"fastfiz @ git+https://github.com/P6-Pool/fastfiz.git",
"gymnasium",
"numpy",
"vectormath",
]

[project.optional-dependencies]
dev = [
"fastfiz_renderer @ git+https://github.com/P6-Pool/fastfiz-renderer.git",
"stable-baselines3",
"tqdm",
"rich",
"torch",
"tensorboard",
"optuna",
]
test = ["pytest", "mypy", "ruff"]
all = [
# dev
"fastfiz_renderer @ git+https://github.com/P6-Pool/fastfiz-renderer.git",
"stable-baselines3",
"tqdm",
"rich",
"torch",
"tensorboard",
"optuna",
# test
"pytest",
"mypy",
"ruff",
]

[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning:tensorboard",
"ignore::UserWarning:gym",
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true

[tool.ruff]
line-length = 127
9 changes: 0 additions & 9 deletions requirements.txt

This file was deleted.

30 changes: 19 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from setuptools import setup, find_packages
import re
from setuptools import setup

# with open("requirements.txt") as f:
# requirements = f.read().splitlines()


def get_version():
with open("src/fastfiz_env/__init__.py", "r") as f:
for line in f:
match = re.match(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", line)
if match:
return match.group(1)
raise RuntimeError("Version not found in __init__.py")

with open("requirements.txt") as f:
requirements = f.read().splitlines()

setup(
name="fastfiz-env",
description="Gymnasium environment for FastFiz pool simulator",
version="0.0.1",
license="MIT",
install_requires=requirements,
test_requires=["pytest"],
packages=find_packages(where="src"),
package_dir={"": "src"},
version=get_version(),
# install_requires=requirements,
# test_requires=["pytest"],
# packages=find_packages(where="src"),
# package_dir={"": "src"},
)
23 changes: 6 additions & 17 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@
import os
from fastfiz_renderer import GameHandler
import numpy as np
import fastfiz_env
from fastfiz_env.envs import FramesFastFiz, SimpleFastFiz, PocketsFastFiz
from fastfiz_env.reward_functions import reward_function
from fastfiz_env.reward_functions.default_reward import DefaultReward
from fastfiz_env.envs import FramesFastFiz, PocketsFastFiz
from fastfiz_env.utils.fastfiz import (
create_random_table_state,
get_ball_positions,
normalize_ball_positions,
)
from fastfiz_env.envs.utils import game_won, possible_shot
from stable_baselines3 import PPO
from typing import Optional, Callable
from typing import Optional
import argparse

from fastfiz_env.wrappers.action import ActionSpaces, FastFizActionWrapper
from fastfiz_env.wrappers.utils import spherical_coordinates


def get_play_config() -> dict:
Expand Down Expand Up @@ -85,13 +80,9 @@ def decide_shot(self, table_state: ff.TableState) -> Optional[ff.ShotParams]:
for _ in range(10):
if isinstance(self.env, FramesFastFiz):
if self.prev_ts is None:
obs = self.env.compute_observation(
table_state, table_state, self.shot
)
obs = self.env.compute_observation(table_state, table_state, self.shot)
else:
obs = self.env.compute_observation(
self.prev_ts, table_state, self.shot
)
obs = self.env.compute_observation(self.prev_ts, table_state, self.shot)
elif isinstance(self.env, PocketsFastFiz):
obs = self.env.compute_observation(table_state)
else:
Expand All @@ -114,15 +105,13 @@ def main() -> None:
parser.add_argument("-m", "--model", type=str, help="Path to the model file")
args = parser.parse_args()

assert args.model is not None and os.path.exists(
args.model
), f"Model file not found: {args.model}"
assert args.model is not None and os.path.exists(args.model), f"Model file not found: {args.model}"

model = PPO.load(args.model)

# env_vec = fastfiz_env.make("SimpleFastFiz-v0", reward_function=DefaultReward)
# env_vec = FastFizActionWrapper(env_vec, ActionSpaces.NO_OFFSET_3D)
env = FastFizActionWrapper(PocketsFastFiz, ActionSpaces.NO_OFFSET_3D)
env = FastFizActionWrapper(PocketsFastFiz, ActionSpaces.VECTOR_3D)
agent = Agent(model, env)
play(agent.decide_shot, balls=2, episodes=100)

Expand Down
13 changes: 4 additions & 9 deletions src/fastfiz_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

Avaliable environments:
- `SimpleFastFiz-v0`: Observes the position of the balls.
- `TestingFastFiz-v0`: Observes the position of the balls. Used for testing purposes with options e.g. seed, logging, action_space_id.
- `FramesFastFiz-v0`: Observes the position of the balls and the frames of the simulation.
- `PocketsFastFiz-v0`: Observes the position of the balls and in play state. Pocketed balls position always corresponds to given pocket center.

Expand All @@ -28,14 +27,16 @@

"""

from .make import make, make_wrapped_vec_env, make_wrapped_env
__version__ = "0.0.1"

from .make import make, make_wrapped_env, make_callable_wrapped_env
from .reward_functions import DefaultReward, RewardFunction, CombinedReward
from . import envs, utils, wrappers, reward_functions

__all__ = [
"make",
"make_wrapped_vec_env",
"make_wrapped_env",
"make_callable_wrapped_env",
"DefaultReward",
"RewardFunction",
"CombinedReward",
Expand All @@ -55,12 +56,6 @@
)


register(
id="TestingFastFiz-v0",
entry_point="fastfiz_env.envs:TestingFastFiz",
additional_wrappers=(wrappers.TimeLimitInjectionWrapper.wrapper_spec(),),
)

register(
id="FramesFastFiz-v0",
entry_point="fastfiz_env.envs:FramesFastFiz",
Expand Down
2 changes: 0 additions & 2 deletions src/fastfiz_env/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

from . import utils
from .simple_fastfiz import SimpleFastFiz
from .testing_fastfiz import TestingFastFiz
from .frames_fastfiz import FramesFastFiz
from .pockets_fastfiz import PocketsFastFiz

__all__ = [
"utils",
"SimpleFastFiz",
"TestingFastFiz",
"FramesFastFiz",
"PocketsFastFiz",
]
37 changes: 6 additions & 31 deletions src/fastfiz_env/envs/frames_fastfiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

from fastfiz_env.envs.utils import game_won, terminal_state
from ..utils.fastfiz import (
shot_params_from_action,
get_ball_positions,
create_random_table_state,
get_ball_velocity,
normalize_ball_positions,
normalize_ball_velocity,
is_pocketed_state,
Expand All @@ -28,9 +26,7 @@ class FramesFastFiz(gym.Env):
TOTAL_BALLS = 16 # Including the cue ball
num_balls = 2

def __init__(
self, reward_function: RewardFunction = DefaultReward, num_balls: int = 16
) -> None:
def __init__(self, reward_function: RewardFunction = DefaultReward, num_balls: int = 16) -> None:
super().__init__()
if num_balls < 2:
warnings.warn(
Expand All @@ -47,15 +43,11 @@ def __init__(

def _max_episode_steps(self):
if self.get_wrapper_attr("_time_limit_max_episode_steps") is not None:
self.max_episode_steps = self.get_wrapper_attr(
"_time_limit_max_episode_steps"
)
self.max_episode_steps = self.get_wrapper_attr("_time_limit_max_episode_steps")
print(f"Setting max episode steps to {self.max_episode_steps}")
self.reward.max_episode_steps = self.max_episode_steps

def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> tuple[np.ndarray, dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
"""
Reset the environment to its initial state.
"""
Expand Down Expand Up @@ -137,18 +129,7 @@ def _observation_space(self):

All values are in the range `[0, TABLE_WIDTH]` and `[0, TABLE_LENGTH]`.
"""
table = self.table_state.getTable()

lower = np.full((self.TOTAL_BALLS, 4), [-1, -1, -1, 0])
# upper = np.full(
# (self.TOTAL_BALLS, 4),
# [
# table.TABLE_WIDTH,
# table.TABLE_LENGTH,
# self.table_state.MAX_VELOCITY * 1.580,
# 1,
# ],
# )
upper = np.full(
(self.TOTAL_BALLS, 4),
[1, 1, 1, 1],
Expand Down Expand Up @@ -187,14 +168,9 @@ def _possible_shot(self, shot_params: ff.ShotParams) -> bool:
"""
Check if the shot is possible.
"""
return (
self.table_state.isPhysicallyPossible(shot_params)
== ff.TableState.OK_PRECONDITION
)
return self.table_state.isPhysicallyPossible(shot_params) == ff.TableState.OK_PRECONDITION

def _compute_observation(
self, prev_table_state: ff.TableState, shot: Optional[ff.Shot]
) -> np.ndarray:
def _compute_observation(self, prev_table_state: ff.TableState, shot: Optional[ff.Shot]) -> np.ndarray:
return self.compute_observation(prev_table_state, self.table_state, shot)

@classmethod
Expand Down Expand Up @@ -236,8 +212,7 @@ def compute_observation(
pocketed = is_pocketed_state(gb.state)
frames_seq[frame][gb.number] = [
*normalize_ball_positions((gb.position.x, gb.position.y)), # type: ignore
normalize_ball_velocity(np.hypot(gb.velocity.x, gb.velocity.y)) * 2
- 1,
normalize_ball_velocity(np.hypot(gb.velocity.x, gb.velocity.y)) * 2 - 1,
pocketed,
]
return frames_seq
Loading