[WIP] ChainerX support by DQN #375

Open
wants to merge 9 commits into master
6 changes: 4 additions & 2 deletions chainerrl/action_value.py
@@ -68,7 +68,8 @@ def __init__(self, q_values, q_values_formatter=lambda x: x):
@cached_property
def greedy_actions(self):
return chainer.Variable(
self.q_values.array.argmax(axis=1).astype(np.int32))
self.q_values.array.argmax(axis=1).astype(np.int32),
requires_grad=False)

@cached_property
def max(self):
@@ -129,7 +130,8 @@ def __init__(self, q_dist, z_values, q_values_formatter=lambda x: x):
@cached_property
def greedy_actions(self):
return chainer.Variable(
self.q_values.array.argmax(axis=1).astype(np.int32))
self.q_values.array.argmax(axis=1).astype(np.int32),
requires_grad=False)

@cached_property
def max(self):
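A minimal sketch of the pattern introduced above, assuming a `(batch_size, n_actions)` Q-value array: the argmax result is wrapped in a `chainer.Variable` that is explicitly kept out of the computational graph via `requires_grad=False`.

```python
import numpy as np
import chainer

# Hypothetical Q-values for a batch of 4 states and 3 actions.
q_values = chainer.Variable(np.random.randn(4, 3).astype(np.float32))

# Greedy actions as an int32 Variable that never requires gradients,
# mirroring the change to greedy_actions above.
greedy_actions = chainer.Variable(
    q_values.array.argmax(axis=1).astype(np.int32),
    requires_grad=False)
print(greedy_actions.array)  # e.g. [2 0 1 2]
```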
12 changes: 5 additions & 7 deletions chainerrl/agents/dqn.py
@@ -90,7 +90,7 @@ class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
replay_buffer (ReplayBuffer): Replay buffer
gamma (float): Discount factor
explorer (Explorer): Explorer that specifies an exploration strategy.
gpu (int): GPU device id if not None nor negative.
device (object): Device specifier to place the model on.
replay_start_size (int): if the replay buffer's size is less than
replay_start_size, skip update
minibatch_size (int): Minibatch size
@@ -117,7 +117,7 @@ class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
saved_attributes = ('model', 'target_model', 'optimizer')

def __init__(self, q_function, optimizer, replay_buffer, gamma,
explorer, gpu=None, replay_start_size=50000,
explorer, device, replay_start_size=50000,
minibatch_size=32, update_interval=1,
target_update_interval=10000, clip_delta=True,
phi=lambda x: x,
@@ -129,19 +129,17 @@ def __init__(self, q_function, optimizer, replay_buffer, gamma,
episodic_update_len=None,
logger=getLogger(__name__),
batch_states=batch_states):

self.model = q_function
self.model.to_device(device)
self.model.device.use()
self.q_function = q_function # For backward compatibility

if gpu is not None and gpu >= 0:
cuda.get_device(gpu).use()
self.model.to_gpu(device=gpu)

self.xp = self.model.xp
self.replay_buffer = replay_buffer
self.optimizer = optimizer
self.gamma = gamma
self.explorer = explorer
self.gpu = gpu
self.target_update_interval = target_update_interval
self.clip_delta = clip_delta
self.phi = phi
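A hedged sketch of the new device handling in `DQN.__init__`, mirroring the calls added above; the `L.Linear` model and the `'@numpy'` device string are placeholders.

```python
import chainer
import chainer.links as L

model = L.Linear(4, 2)                  # stand-in for the Q-function
device = chainer.get_device('@numpy')   # '@cupy:0' for CuPy GPU, 'native:0' for ChainerX CPU

model.to_device(device)   # replaces the old model.to_gpu(device=gpu) path
model.device.use()        # replaces cuda.get_device(gpu).use()
xp = model.xp             # numpy, cupy or chainerx, depending on the device
```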
3 changes: 3 additions & 0 deletions chainerrl/misc/batch_states.py
@@ -15,6 +15,9 @@ def batch_states(states, xp, phi):
if chainer.cuda.available and xp is chainer.cuda.cupy:
# GPU
device = chainer.cuda.Device().id
elif hasattr(chainer, 'chainerx') and xp is chainer.chainerx:
# ChainerX (default device; may be CPU or GPU)
device = chainer.chainerx.get_default_device()
else:
# CPU
device = -1
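For context, the branch added above resolves a device identifier from the array module in use, which is presumably then handed to something like `chainer.dataset.concat_examples`. A rough illustration; the `_device_for` helper is hypothetical and only mirrors the selection logic in the diff.

```python
import numpy as np
import chainer

def _device_for(xp):
    # CUDA device id for cupy, the ChainerX default device for chainerx,
    # and -1 (CPU) for plain NumPy.
    if chainer.cuda.available and xp is chainer.cuda.cupy:
        return chainer.cuda.Device().id
    elif hasattr(chainer, 'chainerx') and xp is chainer.chainerx:
        return chainer.chainerx.get_default_device()
    return -1

states = [np.zeros(4, dtype=np.float32) for _ in range(3)]
batch = chainer.dataset.concat_examples(states, device=_device_for(np))
```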
9 changes: 5 additions & 4 deletions chainerrl/misc/random_seed.py
@@ -13,7 +13,7 @@
import numpy as np


def set_random_seed(seed, gpus=()):
def set_random_seed(seed, devices=()):
"""Set a given random seed to ChainerRL's random sources.

This function sets a given random seed to random sources that ChainerRL
@@ -33,9 +33,10 @@ def set_random_seed(seed, gpus=()):
# ChainerRL depends on numpy.random
np.random.seed(seed)
# ChainerRL depends on cupy.random for GPU computation
for gpu in gpus:
if gpu >= 0:
with chainer.cuda.get_device_from_id(gpu):
for device in devices:
device = chainer.get_device(device)
if device.xp is chainer.cuda.cupy:
with chainer.using_device(device):
chainer.cuda.cupy.random.seed(seed)
# chainer.functions.n_step_rnn directly depends on CHAINER_SEED
os.environ['CHAINER_SEED'] = str(seed)
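With the signature change above, callers pass Chainer device specifiers rather than GPU ids; the second call below is only meaningful on a machine with CuPy and a GPU.

```python
from chainerrl import misc

misc.set_random_seed(0, devices=('@numpy',))     # CPU-only run
# misc.set_random_seed(0, devices=('@cupy:0',))  # also seeds cupy on GPU 0
```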
4 changes: 2 additions & 2 deletions chainerrl/q_functions/state_q_functions.py
@@ -26,9 +26,9 @@
def scale_by_tanh(x, low, high):
xp = cuda.get_array_module(x.array)
scale = (high - low) / 2
scale = xp.expand_dims(xp.asarray(scale, dtype=np.float32), axis=0)
scale = xp.asarray(scale, dtype=np.float32)[None]
mean = (high + low) / 2
mean = xp.expand_dims(xp.asarray(mean, dtype=np.float32), axis=0)
mean = xp.asarray(mean, dtype=np.float32)[None]
return F.tanh(x) * scale + mean


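The `[None]` indexing above is equivalent to `expand_dims(..., axis=0)`; a quick NumPy check (the rewrite presumably just avoids an op that was not uniformly available across numpy, cupy and chainerx at the time).

```python
import numpy as np

low, high = -1.0, 1.0
scale_a = np.asarray((high - low) / 2, dtype=np.float32)[None]
scale_b = np.expand_dims(np.asarray((high - low) / 2, dtype=np.float32), axis=0)
assert scale_a.shape == (1,) and np.array_equal(scale_a, scale_b)
```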
12 changes: 7 additions & 5 deletions examples/ale/train_dqn_ale.py
@@ -85,8 +85,8 @@ def main():
' If it does not exist, it will be created.')
parser.add_argument('--seed', type=int, default=0,
help='Random seed [0, 2 ** 31)')
parser.add_argument('--gpu', type=int, default=0,
help='GPU to use, set to -1 if no GPU.')
parser.add_argument('--device', type=str, default='@numpy',
help='String representing a device.')
parser.add_argument('--demo', action='store_true', default=False)
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--final-exploration-frames',
@@ -141,7 +141,7 @@ def main():
logging.basicConfig(level=args.logging_level)

# Set a random seed used in ChainerRL.
misc.set_random_seed(args.seed, gpus=(args.gpu,))
misc.set_random_seed(args.seed, devices=(args.device,))

# Set different random seeds for train and test envs.
train_seed = args.seed
@@ -212,8 +212,10 @@ def phi(x):
return np.asarray(x, dtype=np.float32) / 255

Agent = parse_agent(args.agent)
agent = Agent(q_func, opt, rbuf, gpu=args.gpu, gamma=0.99,
explorer=explorer, replay_start_size=args.replay_start_size,
agent = Agent(q_func, opt, rbuf, gamma=0.99,
explorer=explorer,
device=args.device,
replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval,
clip_delta=args.clip_delta,
update_interval=args.update_interval,
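The new `--device` flag takes a Chainer device specifier string; which specifiers actually resolve depends on the local build (CuPy, ChainerX), so the probe below is hedged accordingly.

```python
import chainer

for spec in ['@numpy', '@cupy:0', 'native:0', 'cuda:0']:
    try:
        print(spec, '->', chainer.get_device(spec))
    except Exception as exc:  # e.g. CuPy or ChainerX not installed
        print(spec, 'unavailable:', exc)
```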
19 changes: 12 additions & 7 deletions examples/gym/train_dqn_gym.py
@@ -22,6 +22,7 @@
import os
import sys

import chainer
from chainer import optimizers
import gym
from gym import spaces
@@ -39,17 +40,14 @@


def main():
import logging
logging.basicConfig(level=logging.DEBUG)

parser = argparse.ArgumentParser()
parser.add_argument('--outdir', type=str, default='results',
help='Directory path to save output files.'
' If it does not exist, it will be created.')
parser.add_argument('--env', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0,
help='Random seed [0, 2 ** 32)')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--device', type=str, default='@numpy')
parser.add_argument('--final-exploration-steps',
type=int, default=10 ** 4)
parser.add_argument('--start-epsilon', type=float, default=1.0)
@@ -75,10 +73,15 @@ def main():
parser.add_argument('--render-eval', action='store_true')
parser.add_argument('--monitor', action='store_true')
parser.add_argument('--reward-scale-factor', type=float, default=1e-3)
parser.add_argument('--logging-level', type=int, default=20,
help='Logging level. 10:DEBUG, 20:INFO etc.')
args = parser.parse_args()

import logging
logging.basicConfig(level=args.logging_level)

# Set a random seed used in ChainerRL
misc.set_random_seed(args.seed, gpus=(args.gpu,))
misc.set_random_seed(args.seed, devices=(args.device,))

args.outdir = experiments.prepare_output_dir(
args, args.outdir, argv=sys.argv)
@@ -171,8 +174,10 @@ def make_env(test):
else:
rbuf = replay_buffer.ReplayBuffer(rbuf_capacity)

agent = DQN(q_func, opt, rbuf, gpu=args.gpu, gamma=args.gamma,
explorer=explorer, replay_start_size=args.replay_start_size,
agent = DQN(q_func, opt, rbuf, gamma=args.gamma,
explorer=explorer,
device=args.device,
replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval,
update_interval=args.update_interval,
minibatch_size=args.minibatch_size,
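An end-to-end sketch of the construction path under this PR, using ChainerRL's public helpers; observation/action sizes, buffer capacity, the trivial explorer and the `'@numpy'` device string are illustrative assumptions.

```python
from chainer import optimizers
import chainerrl
from chainerrl.agents.dqn import DQN

q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
    ndim_obs=4, n_actions=2, n_hidden_channels=50, n_hidden_layers=2)
opt = optimizers.Adam()
opt.setup(q_func)
rbuf = chainerrl.replay_buffer.ReplayBuffer(10 ** 5)
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.1, random_action_func=lambda: 0)

# device= replaces the old gpu= keyword throughout this PR.
agent = DQN(q_func, opt, rbuf, gamma=0.99, explorer=explorer,
            device='@numpy', replay_start_size=1000)
```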