Merge pull request #22 from Kautenja/fix_ppu_bug
Wrappers
Showing 14 changed files with 517 additions and 15 deletions.
@@ -2,9 +2,6 @@
class NESEnv {

private:
    std::string path;

public:
    NESEnv(wchar_t* path);
    void reset();
@@ -0,0 +1,68 @@
"""Wrappers for altering the functionality of the game."""
import gym
from .binary_to_discrete_space_env import BinarySpaceToDiscreteSpaceEnv
from .clip_reward_env import ClipRewardEnv
from .downsample_env import DownsampleEnv
from .frame_stack_env import FrameStackEnv
from .normalize_reward_env import NormalizeRewardEnv
from .penalize_death_env import PenalizeDeathEnv
from .reward_cache_env import RewardCacheEnv


def wrap(env,
    cache_rewards=True,
    image_size=(84, 84),
    death_penalty=-15,
    clip_rewards=False,
    normalize_rewards=False,
    agent_history_length=4
) -> gym.Env:
    """
    Wrap an environment with standard wrappers.

    Args:
        env (gym.Env): the environment to wrap
        cache_rewards (bool): True to use a reward cache for raw rewards
        image_size (tuple): the size to down-sample images to
        death_penalty (float): the penalty for losing a life in a game
        clip_rewards (bool): whether to clip rewards into {-1, 0, +1}
        normalize_rewards (bool): whether to normalize rewards with the
            infinity norm
        agent_history_length (int): the size of the frame buffer for the agent

    Returns:
        a gym environment configured for this experiment

    """
    # wrap the environment with a reward cacher
    if cache_rewards:
        env = RewardCacheEnv(env)
    # apply a down-sampler for the given game
    if image_size is not None:
        env = DownsampleEnv(env, image_size)
    # apply the death penalty feature if enabled
    if death_penalty is not None:
        env = PenalizeDeathEnv(env, penalty=death_penalty)
    # normalize the rewards into [-1, 1] if the feature is enabled
    if normalize_rewards:
        env = NormalizeRewardEnv(env)
    # clip the rewards into {-1, 0, +1} if the feature is enabled
    if clip_rewards:
        env = ClipRewardEnv(env)
    # stack the back history of frames if the feature is enabled
    if agent_history_length is not None:
        env = FrameStackEnv(env, agent_history_length)

    return env


# explicitly define the outward facing API of this package
__all__ = [
    BinarySpaceToDiscreteSpaceEnv.__name__,
    ClipRewardEnv.__name__,
    DownsampleEnv.__name__,
    FrameStackEnv.__name__,
    NormalizeRewardEnv.__name__,
    PenalizeDeathEnv.__name__,
    RewardCacheEnv.__name__,
    wrap.__name__,
]
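The wrappers compose in a fixed order: the reward cache is applied first, so it records raw rewards before the death penalty, normalization, or clipping are layered on top, and frame stacking wraps everything last. A minimal usage sketch; the import path and environment ID here are assumptions for illustration, not part of this commit:

# hypothetical import path and env ID, shown only to illustrate the call
import gym
from wrappers import wrap

env = wrap(gym.make('SuperMarioBros-v0'), clip_rewards=True)
state = env.reset()
state, reward, done, info = env.step(env.action_space.sample())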
@@ -0,0 +1,8 @@
"""Utility modules for the parent package of wrappers."""
from .lazy_frames import LazyFrames


# explicitly define the outward facing API of this package
__all__ = [
    LazyFrames.__name__
]
@@ -0,0 +1,72 @@
"""A memory efficient buffer for frame tensors."""
import numpy as np


class LazyFrames(object):
    """A memory efficient buffer for frame tensors.

    Note:
        This object ensures that common frames between observations are only
        stored once. It exists purely to optimize memory usage, which can be
        huge for DQN's 1M-frame replay buffers. This object should only be
        converted to a NumPy array just before being passed to the model.
        You'd not believe how complex the previous solution was.

    """

    def __init__(self, frames):
        """
        Initialize a new lazy frames object.

        Args:
            frames (list): the list of frames to store lazily

        """
        self._frames = frames
        self._out = None

    def _force(self):
        """Force the internal buffer of frames into a NumPy array."""
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None

        return self._out

    def __array__(self, dtype=None):
        """
        Convert this lazy frame buffer to a NumPy array.

        Args:
            dtype (numpy.dtype): the type to cast the member values to

        Returns:
            (numpy.ndarray) a NumPy array from the frames in this lazy buffer

        """
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)

        return out

    def __len__(self):
        """Return the number of frames in this lazy frame buffer."""
        return len(self._force())

    def __getitem__(self, index):
        """
        Return the frame at the given index.

        Args:
            index (int): the index (or slice) of the frame to return

        Returns:
            (numpy.ndarray) the frame stored at the given index

        """
        return self._force()[index]


# explicitly define the outward facing API of this module
__all__ = [LazyFrames.__name__]
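To make the laziness concrete, a short sketch (the frame shape is illustrative; `_force` concatenates on axis 2, so four single-channel frames become one four-channel array):

import numpy as np

# four single-channel frames, as a frame-stack wrapper might collect them
frames = [np.zeros((84, 84, 1), dtype=np.uint8) for _ in range(4)]
obs = LazyFrames(frames)

# nothing is concatenated until an array is actually needed
batch = np.array(obs)  # triggers __array__ -> _force
assert batch.shape == (84, 84, 4)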
@@ -0,0 +1,70 @@
"""An environment wrapper to convert binary to discrete action space."""
import gym
import numpy as np


class BinarySpaceToDiscreteSpaceEnv(gym.Wrapper):
    """An environment wrapper to convert binary to discrete action space."""

    # a mapping of buttons to binary values
    _button_map = {
        'right':  0b10000000,
        'left':   0b01000000,
        'down':   0b00100000,
        'up':     0b00010000,
        'start':  0b00001000,
        'select': 0b00000100,
        'B':      0b00000010,
        'A':      0b00000001,
        'NOP':    0b00000000,
    }

    def __init__(self, env, actions):
        """
        Initialize a new binary to discrete action space wrapper.

        Args:
            env (gym.Env): the environment to wrap
            actions (list): an ordered list of actions (as lists of buttons).
                The index of each button list is its discrete coded value

        Returns:
            None

        """
        super(BinarySpaceToDiscreteSpaceEnv, self).__init__(env)
        # create the new action space
        self.action_space = gym.spaces.Discrete(len(actions))
        # create the action map from the list of discrete actions
        self._action_map = {}
        # iterate over all the actions (as button lists)
        for action, button_list in enumerate(actions):
            # the value of this action's bitmap
            byte_action = 0
            # OR the bit for each button in this list into the bitmap
            for button in button_list:
                byte_action |= self._button_map[button]
            # set this action map's value to the byte action value
            self._action_map[action] = byte_action

    def step(self, action):
        """
        Take a step using the given action.

        Args:
            action (int): the discrete action to perform

        Returns:
            a tuple of:
            - (numpy.ndarray) the state as a result of the action
            - (float) the reward achieved by taking the action
            - (bool) a flag denoting whether the episode has ended
            - (dict) a dictionary of extra information

        """
        # translate the discrete action to its binary code and take the step
        return self.env.step(self._action_map[action])


# explicitly define the outward facing API of this module
__all__ = [BinarySpaceToDiscreteSpaceEnv.__name__]
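Each bit in the byte corresponds to one controller button, so a button combination is just the bitwise OR of its buttons' values. A hypothetical action list (not part of this commit) to show the encoding:

# index in this list = discrete action; each entry = buttons held together
RIGHT_ONLY = [
    ['NOP'],              # 0b00000000
    ['right'],            # 0b10000000
    ['right', 'A'],       # 0b10000001
    ['right', 'A', 'B'],  # 0b10000011
]
env = BinarySpaceToDiscreteSpaceEnv(env, RIGHT_ONLY)
# env.action_space is now Discrete(4); env.step(2) presses right+A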
@@ -0,0 +1,28 @@
"""An environment wrapper to clip rewards."""
import gym
import numpy as np


class ClipRewardEnv(gym.RewardWrapper):
    """An environment that clips rewards into {-1, 0, +1}."""

    def __init__(self, env: gym.Env) -> None:
        """
        Initialize a new reward clipping environment.

        Args:
            env (gym.Env): the environment to wrap

        Returns:
            None

        """
        super().__init__(env)

    def reward(self, reward: float) -> float:
        """Bin the reward into {-1, 0, +1} using its sign."""
        return np.sign(reward)


# explicitly specify the external API of this module
__all__ = [ClipRewardEnv.__name__]
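Because the clipping is just np.sign, the magnitude of the raw reward is discarded and only its direction survives, which keeps reward scales comparable across games. A quick check (the values are illustrative):

import numpy as np

for raw, clipped in [(200.0, 1.0), (0.0, 0.0), (-15.0, -1.0)]:
    assert np.sign(raw) == clipped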
@@ -0,0 +1,52 @@
"""An environment wrapper for down-sampling frames from RGB to smaller B&W."""
import gym
import cv2
import numpy as np


class DownsampleEnv(gym.ObservationWrapper):
    """An environment that down-samples frames."""

    def __init__(self, env, image_size):
        """
        Create a new down-sampler.

        Args:
            env (gym.Env): the environment to wrap
            image_size (tuple): the size to output frames as (width x height)

        Returns:
            None

        """
        super(DownsampleEnv, self).__init__(env)
        self._image_size = image_size
        # set up a new observation space
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._image_size[1], self._image_size[0], 1),
            dtype=np.uint8
        )

    def observation(self, frame):
        """
        Down-sample an observation from RGB to grayscale and resize it.

        Args:
            frame (numpy.ndarray): the image to convert to grayscale and resize

        Returns:
            (numpy.ndarray) the frame in B&W, resized to self._image_size

        """
        # convert the frame from RGB to grayscale
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # resize the frame to the expected shape (bilinear interpolation)
        frame = cv2.resize(frame, self._image_size)

        return frame[:, :, np.newaxis]


# explicitly define the outward facing API of this module
__all__ = [DownsampleEnv.__name__]
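Note the axis swap: image_size is (width, height), as cv2.resize expects, while the Box shape is (height, width, 1), as NumPy expects. A standalone sketch of the transform (the 256x240 input matches the NES's frame size):

import cv2
import numpy as np

frame = np.zeros((240, 256, 3), dtype=np.uint8)  # raw RGB frame, (H, W, C)
gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)   # -> (240, 256)
small = cv2.resize(gray, (84, 84))               # cv2 takes (width, height)
assert small[:, :, np.newaxis].shape == (84, 84, 1)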