forked from natolambert/dynamicslearn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolicy_tsac.py
79 lines (70 loc) · 2.07 KB
/
policy_tsac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Run PyTorch Soft Actor Critic on HalfCheetahEnv with the "Twin" architecture
from TD3: https://arxiv.org/pdf/1802.09477.pdf
"""
import numpy as np
from gymenv_quad import QuadEnv
import rlkit.rlkit.torch.pytorch_util as ptu
from rlkit.rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.rlkit.launchers.launcher_util import setup_logger
from rlkit.rlkit.torch.sac.policies import TanhGaussianPolicy
from rlkit.rlkit.torch.sac.sac import SoftActorCritic
from rlkit.rlkit.torch.networks import FlattenMlp
from rlkit.rlkit.torch.sac.twin_sac import TwinSAC
def experiment(variant):
    """Build and train a Twin-SAC agent on the quadrotor environment.

    Args:
        variant: dict with keys
            'net_size' — hidden-layer width for all networks, and
            'algo_params' — kwargs forwarded verbatim to ``TwinSAC``.

    Side effects: trains the agent in-process via ``algorithm.train()``.
    """
    # Normalize the action space to [-1, 1] so the tanh policy maps onto it.
    env = NormalizedBoxEnv(QuadEnv())

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    net_size = variant['net_size']
    # Twin Q-functions (the TD3 trick): taking the min of two critics
    # reduces value overestimation.
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim + action_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim + action_dim,
        output_size=1,
    )
    # State-value network (classic SAC uses a separate V function).
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    algorithm = TwinSAC(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        **variant['algo_params']
    )
    # Move all networks to the device rlkit selected (CPU or GPU).
    algorithm.to(ptu.device)
    algorithm.train()
if __name__ == "__main__":
# noinspection PyTypeChecker
variant = dict(
algo_params=dict(
num_epochs=150,
num_steps_per_epoch=500,
num_steps_per_eval=500,
max_path_length=300,
batch_size=128,
discount=0.99,
soft_target_tau=0.001,
policy_lr=3E-4,
qf_lr=3E-4,
vf_lr=3E-4,
),
net_size=300,
)
setup_logger('tsac-cheetah', variant=variant)
experiment(variant)