Skip to content

Commit

Permalink
Merge pull request #9 from P6-Pool/pyproject
Browse files Browse the repository at this point in the history
Restructure, cleanup and linting
  • Loading branch information
MadsSR authored Apr 30, 2024
2 parents e9d6ed1 + c012f8b commit b8728c8
Show file tree
Hide file tree
Showing 40 changed files with 851 additions and 936 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Python package
on:
push:
pull_request:
branches: ['main']

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.10', '3.11']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
sudo apt update
sudo apt install python3-opengl swig libgsl-dev
python -m pip install --upgrade pip
pip install ".[test]"
- name: Lint with Ruff
run: |
ruff check src/fastfiz_env
ruff format src/fastfiz_env
- name: Run MyPy
run: |
mypy src/fastfiz_env
- name: Test with pytest
run: |
pytest
40 changes: 0 additions & 40 deletions .github/workflows/python-package.yml

This file was deleted.

4 changes: 0 additions & 4 deletions .vscode/settings.json

This file was deleted.

57 changes: 57 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[project]
name = "fastfiz-env"
description = "Gymnasium environments for FastFiz pool simulator."
readme = "README.md"
requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
"fastfiz @ git+https://github.com/P6-Pool/fastfiz.git",
"gymnasium",
"numpy",
"vectormath",
]

[project.optional-dependencies]
dev = [
"fastfiz_renderer @ git+https://github.com/P6-Pool/fastfiz-renderer.git",
"stable-baselines3",
"tqdm",
"rich",
"torch",
"tensorboard",
"optuna",
]
test = ["pytest", "mypy", "ruff"]
all = [
# dev
"fastfiz_renderer @ git+https://github.com/P6-Pool/fastfiz-renderer.git",
"stable-baselines3",
"tqdm",
"rich",
"torch",
"tensorboard",
"optuna",
# test
"pytest",
"mypy",
"ruff",
]

[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning:tensorboard",
"ignore::UserWarning:gym",
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true

[tool.ruff]
line-length = 127
9 changes: 0 additions & 9 deletions requirements.txt

This file was deleted.

30 changes: 19 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from setuptools import setup, find_packages
import re
from setuptools import setup

# with open("requirements.txt") as f:
# requirements = f.read().splitlines()


def get_version():
with open("src/fastfiz_env/__init__.py", "r") as f:
for line in f:
match = re.match(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", line)
if match:
return match.group(1)
raise RuntimeError("Version not found in __init__.py")

with open("requirements.txt") as f:
requirements = f.read().splitlines()

setup(
name="fastfiz-env",
description="Gymnasium environment for FastFiz pool simulator",
version="0.0.1",
license="MIT",
install_requires=requirements,
test_requires=["pytest"],
packages=find_packages(where="src"),
package_dir={"": "src"},
version=get_version(),
# install_requires=requirements,
# test_requires=["pytest"],
# packages=find_packages(where="src"),
# package_dir={"": "src"},
)
23 changes: 6 additions & 17 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@
import os
from fastfiz_renderer import GameHandler
import numpy as np
import fastfiz_env
from fastfiz_env.envs import FramesFastFiz, SimpleFastFiz, PocketsFastFiz
from fastfiz_env.reward_functions import reward_function
from fastfiz_env.reward_functions.default_reward import DefaultReward
from fastfiz_env.envs import FramesFastFiz, PocketsFastFiz
from fastfiz_env.utils.fastfiz import (
create_random_table_state,
get_ball_positions,
normalize_ball_positions,
)
from fastfiz_env.envs.utils import game_won, possible_shot
from stable_baselines3 import PPO
from typing import Optional, Callable
from typing import Optional
import argparse

from fastfiz_env.wrappers.action import ActionSpaces, FastFizActionWrapper
from fastfiz_env.wrappers.utils import spherical_coordinates


def get_play_config() -> dict:
Expand Down Expand Up @@ -85,13 +80,9 @@ def decide_shot(self, table_state: ff.TableState) -> Optional[ff.ShotParams]:
for _ in range(10):
if isinstance(self.env, FramesFastFiz):
if self.prev_ts is None:
obs = self.env.compute_observation(
table_state, table_state, self.shot
)
obs = self.env.compute_observation(table_state, table_state, self.shot)
else:
obs = self.env.compute_observation(
self.prev_ts, table_state, self.shot
)
obs = self.env.compute_observation(self.prev_ts, table_state, self.shot)
elif isinstance(self.env, PocketsFastFiz):
obs = self.env.compute_observation(table_state)
else:
Expand All @@ -114,15 +105,13 @@ def main() -> None:
parser.add_argument("-m", "--model", type=str, help="Path to the model file")
args = parser.parse_args()

assert args.model is not None and os.path.exists(
args.model
), f"Model file not found: {args.model}"
assert args.model is not None and os.path.exists(args.model), f"Model file not found: {args.model}"

model = PPO.load(args.model)

# env_vec = fastfiz_env.make("SimpleFastFiz-v0", reward_function=DefaultReward)
# env_vec = FastFizActionWrapper(env_vec, ActionSpaces.NO_OFFSET_3D)
env = FastFizActionWrapper(PocketsFastFiz, ActionSpaces.NO_OFFSET_3D)
env = FastFizActionWrapper(PocketsFastFiz, ActionSpaces.VECTOR_3D)
agent = Agent(model, env)
play(agent.decide_shot, balls=2, episodes=100)

Expand Down
13 changes: 4 additions & 9 deletions src/fastfiz_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Avaliable environments:
- `SimpleFastFiz-v0`: Observes the position of the balls.
- `TestingFastFiz-v0`: Observes the position of the balls. Used for testing purposes with options e.g. seed, logging, action_space_id.
- `FramesFastFiz-v0`: Observes the position of the balls and the frames of the simulation.
- `PocketsFastFiz-v0`: Observes the position of the balls and in play state. Pocketed balls position always corresponds to given pocket center.
Expand All @@ -28,14 +27,16 @@
"""

from .make import make, make_wrapped_vec_env, make_wrapped_env
__version__ = "0.0.1"

from .make import make, make_wrapped_env, make_callable_wrapped_env
from .reward_functions import DefaultReward, RewardFunction, CombinedReward
from . import envs, utils, wrappers, reward_functions

__all__ = [
"make",
"make_wrapped_vec_env",
"make_wrapped_env",
"make_callable_wrapped_env",
"DefaultReward",
"RewardFunction",
"CombinedReward",
Expand All @@ -55,12 +56,6 @@
)


register(
id="TestingFastFiz-v0",
entry_point="fastfiz_env.envs:TestingFastFiz",
additional_wrappers=(wrappers.TimeLimitInjectionWrapper.wrapper_spec(),),
)

register(
id="FramesFastFiz-v0",
entry_point="fastfiz_env.envs:FramesFastFiz",
Expand Down
2 changes: 0 additions & 2 deletions src/fastfiz_env/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

from . import utils
from .simple_fastfiz import SimpleFastFiz
from .testing_fastfiz import TestingFastFiz
from .frames_fastfiz import FramesFastFiz
from .pockets_fastfiz import PocketsFastFiz

__all__ = [
"utils",
"SimpleFastFiz",
"TestingFastFiz",
"FramesFastFiz",
"PocketsFastFiz",
]
37 changes: 6 additions & 31 deletions src/fastfiz_env/envs/frames_fastfiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

from fastfiz_env.envs.utils import game_won, terminal_state
from ..utils.fastfiz import (
shot_params_from_action,
get_ball_positions,
create_random_table_state,
get_ball_velocity,
normalize_ball_positions,
normalize_ball_velocity,
is_pocketed_state,
Expand All @@ -28,9 +26,7 @@ class FramesFastFiz(gym.Env):
TOTAL_BALLS = 16 # Including the cue ball
num_balls = 2

def __init__(
self, reward_function: RewardFunction = DefaultReward, num_balls: int = 16
) -> None:
def __init__(self, reward_function: RewardFunction = DefaultReward, num_balls: int = 16) -> None:
super().__init__()
if num_balls < 2:
warnings.warn(
Expand All @@ -47,15 +43,11 @@ def __init__(

def _max_episode_steps(self):
if self.get_wrapper_attr("_time_limit_max_episode_steps") is not None:
self.max_episode_steps = self.get_wrapper_attr(
"_time_limit_max_episode_steps"
)
self.max_episode_steps = self.get_wrapper_attr("_time_limit_max_episode_steps")
print(f"Setting max episode steps to {self.max_episode_steps}")
self.reward.max_episode_steps = self.max_episode_steps

def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> tuple[np.ndarray, dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
"""
Reset the environment to its initial state.
"""
Expand Down Expand Up @@ -137,18 +129,7 @@ def _observation_space(self):
All values are in the range `[0, TABLE_WIDTH]` and `[0, TABLE_LENGTH]`.
"""
table = self.table_state.getTable()

lower = np.full((self.TOTAL_BALLS, 4), [-1, -1, -1, 0])
# upper = np.full(
# (self.TOTAL_BALLS, 4),
# [
# table.TABLE_WIDTH,
# table.TABLE_LENGTH,
# self.table_state.MAX_VELOCITY * 1.580,
# 1,
# ],
# )
upper = np.full(
(self.TOTAL_BALLS, 4),
[1, 1, 1, 1],
Expand Down Expand Up @@ -187,14 +168,9 @@ def _possible_shot(self, shot_params: ff.ShotParams) -> bool:
"""
Check if the shot is possible.
"""
return (
self.table_state.isPhysicallyPossible(shot_params)
== ff.TableState.OK_PRECONDITION
)
return self.table_state.isPhysicallyPossible(shot_params) == ff.TableState.OK_PRECONDITION

def _compute_observation(
self, prev_table_state: ff.TableState, shot: Optional[ff.Shot]
) -> np.ndarray:
def _compute_observation(self, prev_table_state: ff.TableState, shot: Optional[ff.Shot]) -> np.ndarray:
return self.compute_observation(prev_table_state, self.table_state, shot)

@classmethod
Expand Down Expand Up @@ -236,8 +212,7 @@ def compute_observation(
pocketed = is_pocketed_state(gb.state)
frames_seq[frame][gb.number] = [
*normalize_ball_positions((gb.position.x, gb.position.y)), # type: ignore
normalize_ball_velocity(np.hypot(gb.velocity.x, gb.velocity.y)) * 2
- 1,
normalize_ball_velocity(np.hypot(gb.velocity.x, gb.velocity.y)) * 2 - 1,
pocketed,
]
return frames_seq
Loading

0 comments on commit b8728c8

Please sign in to comment.