Reward changes
MadsSR committed Apr 10, 2024
1 parent b9fcfeb commit 3940c6c
Showing 29 changed files with 669 additions and 189 deletions.
6 changes: 6 additions & 0 deletions src/fastfiz_env/__init__.py
@@ -74,9 +74,15 @@
register(
id="TestingFastFiz-v0",
entry_point="fastfiz_env.envs:TestingFastFiz",
additional_wrappers=(
utils.wrappers.MaxEpisodeStepsInjectionWrapper.wrapper_spec(),
),
)

register(
id="ActionFastFiz-v0",
entry_point="fastfiz_env.envs:ActionFastFiz",
additional_wrappers=(
utils.wrappers.MaxEpisodeStepsInjectionWrapper.wrapper_spec(),
),
)
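
A minimal usage sketch (not part of this commit) of the environments registered above, assuming gymnasium's standard registry API; the max_episode_steps and options values are illustrative only:

import gymnasium as gym
import fastfiz_env  # noqa: F401 - importing runs the register() calls above

# gym.make() applies a TimeLimit wrapper for max_episode_steps; the
# MaxEpisodeStepsInjectionWrapper declared in additional_wrappers can then
# expose that limit to the environment and its reward function.
env = gym.make("TestingFastFiz-v0", max_episode_steps=20, options={"seed": 123})
obs, info = env.reset()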
25 changes: 13 additions & 12 deletions src/fastfiz_env/envs/action_fastfiz.py
@@ -45,7 +45,7 @@ def __init__(
self.observation_space = self._observation_space()
self.action_space = FastFizActionWrapper.get_action_space(action_space_id)
self.reward = reward_function

self.max_episode_steps = None
# Logging
self.logger = logging.getLogger(__name__)
logs_dir = self.options.get("logs_dir", "")
@@ -70,21 +70,24 @@ def __init__(
self.observation_space,
)

def _max_episode_steps(self):
if self.get_wrapper_attr("_time_limit_max_episode_steps") is not None:
self.max_episode_steps = self.get_wrapper_attr(
"_time_limit_max_episode_steps"
)
self.reward.max_episode_steps = self.max_episode_steps

def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> tuple[np.ndarray, dict]:
super().reset(seed=seed)
self.logger.info("Reset(%s) - total n_steps: %s", self.n_episodes, self.n_step)
self.logger.info("Reset(%s) - table state seed: %s", self.n_episodes, seed)

if self.max_episode_steps is None:
self._max_episode_steps()

self.table_state = create_random_table_state(self.num_balls, seed=seed)
self.reward.reset(self.table_state)

self.logger.info(
"Reset(%s) - table state:\n%s",
self.n_episodes,
table_state_to_string(self.table_state),
)

observation = self._get_observation()
info = self._get_info()

@@ -122,9 +125,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]

observation = self._get_observation()

reward = self.reward.get_reward(
prev_table_state, self.table_state, impossible_shot
)
reward = self.reward.get_reward(prev_table_state, self.table_state, action)

terminated = self._is_terminal_state()
truncated = False
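
The MaxEpisodeStepsInjectionWrapper referenced by the registrations is not included in this diff. A plausible minimal sketch, assuming it only needs to make the TimeLimit value readable from inside the base environment (the attribute name matches the get_wrapper_attr call above; the rest is a guess):

import gymnasium as gym


class MaxEpisodeStepsInjectionWrapper(gym.Wrapper):
    def reset(self, **kwargs):
        # gym.make() records max_episode_steps on the env spec; copy it onto the
        # unwrapped env so get_wrapper_attr("_time_limit_max_episode_steps")
        # resolves there before the base reset() body runs.
        spec = self.spec
        self.env.unwrapped._time_limit_max_episode_steps = (
            spec.max_episode_steps if spec is not None else None
        )
        return self.env.reset(**kwargs)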
3 changes: 2 additions & 1 deletion src/fastfiz_env/envs/simple_fastfiz.py
@@ -63,7 +63,8 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]

prev_table_state = ff.TableState(self.table_state)

shot_params = shot_params_from_action(self.table_state, [0, 0, *action])
# shot_params = shot_params_from_action(self.table_state, [0, 0, *action])
shot_params = ff.ShotParams(*action)

impossible_shot = not self._possible_shot(shot_params)

28 changes: 18 additions & 10 deletions src/fastfiz_env/envs/testing_fastfiz.py
@@ -19,11 +19,6 @@
import logging
import time

DEFAULT_TEST_OPTIONS = {
"seed": 123,
"log_level": logging.INFO,
}


class TestingFastFiz(gym.Env):
"""FastFiz environment for testing."""
@@ -35,15 +30,16 @@ def __init__(
reward_function: RewardFunction = DefaultReward,
num_balls: int = 16,
*,
test_options: Optional[dict] = None,
options: Optional[dict] = None,
) -> None:
super().__init__()
self.options = test_options
self.options = options if options is not None else {}
self.num_balls = num_balls
self.table_state = create_random_table_state(self.num_balls)
self.observation_space = self._observation_space()
action_space_id = self.options.get("action_space_id", ActionSpaces.NO_OFFSET_3D)
self.action_space = FastFizActionWrapper.get_action_space(action_space_id)
self.max_episode_steps = None
self.reward = reward_function

# Logging
@@ -70,10 +66,21 @@ def __init__(
self.observation_space,
)

def _max_episode_steps(self):
if self.get_wrapper_attr("_time_limit_max_episode_steps") is not None:
self.max_episode_steps = self.get_wrapper_attr(
"_time_limit_max_episode_steps"
)
self.reward.max_episode_steps = self.max_episode_steps

def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> tuple[np.ndarray, dict]:
super().reset(seed=seed)

if self.max_episode_steps is None:
self._max_episode_steps()

seed = self.options.get("seed", None)
self.logger.info("Reset(%s) - total n_steps: %s", self.n_episodes, self.n_step)
self.logger.info("Reset(%s) - table state seed: %s", self.n_episodes, seed)
@@ -130,9 +137,7 @@ def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]

observation = self._get_observation()

reward = self.reward.get_reward(
prev_table_state, self.table_state, impossible_shot
)
reward = self.reward.get_reward(prev_table_state, self.table_state, action)

terminated = self._is_terminal_state()
truncated = False
@@ -172,6 +177,9 @@ def _is_terminal_state(self) -> bool:
return self._game_won()

def _game_won(self) -> bool:
if self.table_state.getBall(0).isPocketed():
return False

for i in range(1, self.num_balls):
if not self.table_state.getBall(i).isPocketed():
return False
2 changes: 1 addition & 1 deletion src/fastfiz_env/make.py
@@ -32,5 +32,5 @@ def make(
num_balls=num_balls,
max_episode_steps=max_episode_steps,
disable_env_checker=disable_env_checker,
**kwargs
**kwargs,
)
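
A usage sketch for make() (not from this commit). The full signature is cut off in this hunk, so the reward_function forwarding below is an assumption based on the **kwargs pass-through and the env constructors:

from fastfiz_env.make import make
from fastfiz_env.utils.reward_functions import DefaultReward

env = make(
    "ActionFastFiz-v0",
    reward_function=DefaultReward,  # assumed to reach the env via **kwargs
    num_balls=2,
    max_episode_steps=20,
)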
3 changes: 2 additions & 1 deletion src/fastfiz_env/utils/reward_functions/__init__.py
@@ -2,14 +2,15 @@
This module contains the reward functions used in the FastFiz environment.
"""

from .reward_function import RewardFunction
from .reward_function import RewardFunction, Weight
from .combined_reward import CombinedReward
from .binary_reward import BinaryReward
from .default_reward import DefaultReward
from . import common

__all__ = [
"RewardFunction",
"Weight",
"CombinedReward",
"BinaryReward",
"DefaultReward",
19 changes: 12 additions & 7 deletions src/fastfiz_env/utils/reward_functions/binary_reward.py
@@ -1,10 +1,18 @@
from abc import ABC, abstractmethod
from typing import Callable, Union
import fastfiz as ff
from .reward_function import RewardFunction
import numpy as np


class BinaryReward(RewardFunction, ABC):
def __init__(self, *, short_circuit: bool = True) -> None:
def __init__(
self,
weight: Union[float, Callable[[int, int, int], float]] = 1,
*,
max_episode_steps: int = None,
short_circuit: bool = True,
) -> None:
"""
Initializes a BinaryReward object.
@@ -13,18 +21,15 @@ def __init__(self, *, short_circuit: bool = True) -> None:
If set to True, the reward will be calculated based on the first condition that is met.
If set to False, all conditions will be evaluated. Defaults to True.
"""
super().__init__(weight=weight, max_episode_steps=max_episode_steps)
self.short_circuit = short_circuit

@abstractmethod
def reset(self, table_state: ff.TableState) -> None:
pass

@abstractmethod
def get_reward(
def reward(
self,
prev_table_state: ff.TableState,
table_state: ff.TableState,
possible_shot: bool,
action: np.ndarray,
) -> float:
"""
Calculates the reward for a given table state transition.
47 changes: 31 additions & 16 deletions src/fastfiz_env/utils/reward_functions/combined_reward.py
@@ -1,6 +1,7 @@
from .reward_function import RewardFunction
from .reward_function import RewardFunction, Weight
from .binary_reward import BinaryReward
import fastfiz as ff
import numpy as np


class CombinedReward(RewardFunction):
@@ -10,9 +11,10 @@ class CombinedReward(RewardFunction):

def __init__(
self,
reward_functions: list[RewardFunction],
weights: list[float | int],
weight: Weight = 1,
max_episode_steps: int = None,
*,
reward_functions: list[RewardFunction],
short_circuit: bool = False,
) -> None:
"""
@@ -26,23 +28,34 @@ def __init__(
Returns:
None
"""
assert len(reward_functions) == len(
weights
), "Reward functions and weights must have the same length."
self.reward_functions = reward_functions
self.weights = weights
super().__init__(weight, max_episode_steps=max_episode_steps)

# Set max_episode_steps for all reward functions

self.short_circuit = short_circuit
self.max_episode_steps = max_episode_steps

@property
def max_episode_steps(self) -> int:
return self._max_episode_steps

@max_episode_steps.setter
def max_episode_steps(self, value: int) -> None:
self._max_episode_steps = value
for reward in self.reward_functions:
reward.max_episode_steps = value

def reset(self, table_state: ff.TableState) -> None:
super().reset(table_state)
for reward in self.reward_functions:
reward.reset(table_state)

def get_reward(
def reward(
self,
prev_table_state: ff.TableState,
table_state: ff.TableState,
impossible_shot: bool,
action: np.ndarray,
) -> float:
"""
Calculates the combined reward based on the given table states and possible shot flag.
@@ -57,17 +70,19 @@ def get_reward(
"""
total_reward = 0
for i, reward_function in enumerate(self.reward_functions):
reward = reward_function.get_reward(
prev_table_state, table_state, impossible_shot
)
total_reward += reward * self.weights[i]
for reward_function in self.reward_functions:
reward = reward_function.get_reward(prev_table_state, table_state, action)
total_reward += reward

if issubclass(reward_function.__class__, BinaryReward):
if reward == 1 and self.short_circuit and reward_function.short_circuit:
if (
reward == 1 * reward_function.weight()
and self.short_circuit
and reward_function.short_circuit
):
return total_reward

return total_reward

def __str__(self) -> str:
return f"CombinedReward({[str(reward) for reward in self.reward_functions]}, {str(self.weights)}, short_circuit={self.short_circuit})"
return f"CombinedReward({[str(reward) for reward in self.reward_functions]}, {None}, short_circuit={self.short_circuit})"
32 changes: 32 additions & 0 deletions src/fastfiz_env/utils/reward_functions/common/__init__.py
@@ -10,8 +10,26 @@
from .game_won_reward import GameWonReward
from .impossible_shot_reward import ImpossibleShotReward
from .constant_reward import ConstantReward
from .balls_not_moved_reward import BallsNotMovedReward
from .velocity_reward import VelocityReward
from .exponential_velocity_reward import ExponentialVelocityReward

from .weights import (
ConstantWeight,
NegativeConstantWeight,
ConstantWeightMaxSteps,
NegativeConstantWeightMaxSteps,
ConstantWeightNumBalls,
NegativeConstantWeightNumBalls,
ConstantWeightBalls,
NegativeConstantWeightBalls,
ConstantWeightCurrentStep,
NegativeConstantWeightCurrentStep,
)


__all__ = [
# Reward functions
"StepPocketedReward",
"TotalDistanceReward",
"DeltaBestTotalDistanceReward",
@@ -20,4 +38,18 @@
"GameWonReward",
"ImpossibleShotReward",
"ConstantReward",
"BallsNotMovedReward",
"VelocityReward",
"ExponentialVelocityReward",
# Weights
"ConstantWeight",
"NegativeConstantWeight",
"ConstantWeightMaxSteps",
"NegativeConstantWeightMaxSteps",
"ConstantWeightNumBalls",
"NegativeConstantWeightNumBalls",
"ConstantWeightBalls",
"NegativeConstantWeightBalls",
"ConstantWeightCurrentStep",
"NegativeConstantWeightCurrentStep",
]
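
The weights module itself is not part of this diff. A rough sketch of what these callables could look like, assuming the Weight alias (a float or a Callable[[int, int, int], float]) seen in binary_reward.py; the argument names and scaling choices are guesses:

def ConstantWeight(num_balls: int, step: int, max_episode_steps: int) -> float:
    # Flat weight, independent of episode progress.
    return 1.0


def NegativeConstantWeight(num_balls: int, step: int, max_episode_steps: int) -> float:
    return -ConstantWeight(num_balls, step, max_episode_steps)


def ConstantWeightMaxSteps(num_balls: int, step: int, max_episode_steps: int) -> float:
    # Plausibly normalizes the constant weight by the episode length.
    return 1.0 / max_episode_steps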
19 changes: 19 additions & 0 deletions src/fastfiz_env/utils/reward_functions/common/balls_not_moved_reward.py
@@ -0,0 +1,19 @@
from .. import BinaryReward
from ...fastfiz import any_ball_has_moved, get_ball_positions


class BallsNotMovedReward(BinaryReward):
"""
Reward function that rewards the agent when no balls have moved.
"""

def reward(self, prev_table_state, table_state, action) -> float:
"""
Returns 1 if no balls have moved, 0 otherwise.
"""
prev_ball_positions = get_ball_positions(prev_table_state)[1:]
ball_positions = get_ball_positions(table_state)[1:]

not_moved = not any_ball_has_moved(prev_ball_positions, ball_positions)

return 1 if not_moved else 0
14 changes: 9 additions & 5 deletions src/fastfiz_env/utils/reward_functions/common/constant_reward.py
@@ -1,15 +1,19 @@
from ..reward_function import RewardFunction
from ..reward_function import RewardFunction, Weight
import fastfiz as ff
import numpy as np


class ConstantReward(RewardFunction):
"""
Reward function that always returns 1. Intended to be used in combination with other reward functions.
"""

def reset(self, table_state) -> None:
pass

def get_reward(self, prev_table_state, table_state, impossible_shot) -> float:
def reward(
self,
prev_table_state: ff.TableState,
table_state: ff.TableState,
action: np.ndarray,
) -> float:
"""
Reward function that always returns 1. Intended to be used in combination with other reward functions.
"""