Skip to content

Commit

Permalink
Merge pull request #20 from Kautenja/max_steps
Browse files Browse the repository at this point in the history
Max steps
  • Loading branch information
Kautenja authored Jul 20, 2018
2 parents dd92121 + 6f843d9 commit 763a8ee
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
19 changes: 18 additions & 1 deletion nes_py/nes_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A CTypes interface to the C++ NES environment."""
import os
import sys
import math
import ctypes
import itertools
from glob import glob
Expand Down Expand Up @@ -91,13 +92,14 @@ class NESEnv(gym.Env):
# action space is a bitmap of button press values for the 8 NES buttons
action_space = Bitmap(NUM_BUTTONS)

def __init__(self, rom_path, frameskip=1):
def __init__(self, rom_path, frameskip=1, max_episode_steps=math.inf):
"""
Create a new NES environment.
Args:
path (str): the path to the ROM for the environment
frameskip (int): the number of frames to skip between steps
max_episode_steps (int): number of steps before an episode ends
Returns:
None
Expand Down Expand Up @@ -126,6 +128,14 @@ def __init__(self, rom_path, frameskip=1):
# adjust the FPS of the environment by the given frameskip value
self.metadata['video.frames_per_second'] /= frameskip

# check the max episode steps
if not isinstance(max_episode_steps, (int, float)):
raise TypeError('max_episode_steps must be of type: int, float')
if not max_episode_steps > 0:
raise ValueError('max_episode_steps must be > 0')
self._max_episode_steps = max_episode_steps
self._steps = 0

# initialize the C++ object for running the environment
self._env = _LIB.NESEnv_init(self._rom_path)
# setup a boolean for whether to flip from BGR to RGB based on machine
Expand Down Expand Up @@ -203,6 +213,8 @@ def reset(self):
state (np.ndarray): next frame as a result of the given action
"""
# reset the steps counter
self._steps = 0
# call the before reset callback
self._will_reset()
# reset the emulator
Expand Down Expand Up @@ -251,6 +263,11 @@ def step(self, action):
self._did_step()
# copy the screen from the emulator
self._copy_screen()
# increment the steps counter
self._steps += 1
# set the done flag to true if the steps are past the max
if self._steps >= self._max_episode_steps:
done = True
# return the screen from the emulator and other relevant data
return self.screen, reward, done, info

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def README():

setup(
name='nes_py',
version='0.5.3',
version='0.6.0',
description='An NES Emulator and OpenAI Gym interface',
long_description=README(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 763a8ee

Please sign in to comment.