Skip to content

Commit

Permalink
quick fixes to make it run
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Feb 9, 2025
1 parent 5f3cb01 commit 256c711
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 44 deletions.
45 changes: 45 additions & 0 deletions config/morph.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Training configuration for the `morph` environment.

# [base]: which package/environment to load and which policy class to use.
[base]
package = morph
env_name = morph
vec = native
policy_name = Policy
# Recurrent wrapper is left disabled; uncomment to select the `Recurrent` class.
# rnn_name = Recurrent

# [env]: these keys match the keyword arguments PHCPufferEnv is constructed
# with in environment.py (motion_file, has_self_collision, num_envs).
[env]
# NOTE(review): absolute, machine-specific dataset path (same value as the CLI
# default in environment.py); it must exist on the host running training.
motion_file = "/workspace/dataset/AMASS/amass_train_take6_upright.pkl"
has_self_collision = True
num_envs = 2048

# [train]: trainer / optimizer hyperparameters.
[train]
seed = 1
torch_deterministic = True
device = cuda

cpu_offload = False
compile = False
norm_adv = True
# NOTE(review): this is the literal text "None" unless the config loader
# converts it to a Python None -- confirm against the loader.
target_kl = None

total_timesteps = 10_000_000
eval_timesteps = 100_000

# NOTE(review): distinct from [env] num_envs (2048) above. Presumably this is
# the number of vectorized env instances, each simulating [env] num_envs
# humanoids -- confirm against the trainer.
num_envs = 1
# 65536 / 16384 = 4, i.e. presumably four minibatches per update epoch.
batch_size = 65536
minibatch_size = 16384

update_epochs = 5
bptt_horizon = 16
learning_rate = 0.00002
anneal_lr = False
gae_lambda = 0.99
gamma = 0.95

clip_coef = 0.2
clip_vloss = True
vf_coef = 5.0
vf_clip_coef = 0.2
max_grad_norm = 0.5
ent_coef = 0.0
# Presumably the weight of the discriminator loss term (the Policy defines a
# discriminator head) -- TODO confirm against the training loop.
disc_coef = 5.0

checkpoint_interval = 1000
9 changes: 5 additions & 4 deletions pufferlib/environments/morph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .environment import env_creator

try:
import torch
# NOTE: demo.py looks up the policy class from the `torch` module, so alias the policy module under that name
import pufferlib.environments.morph.policy as torch
except ImportError:
pass
else:
from .torch import Policy
from .policy import Policy
try:
from .torch import Recurrent
from .policy import Recurrent
except:
Recurrent = None
Recurrent = None
16 changes: 6 additions & 10 deletions pufferlib/environments/morph/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def mean_and_log(self):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--num_envs", type=int, default=32)
parser.add_argument("-m", "--motion_file", type=str, default="sample_data/amass_train_take6_upright.pkl")
parser.add_argument("-m", "--motion_file", type=str, default="/workspace/dataset/AMASS/amass_train_take6_upright.pkl")
parser.add_argument("--disable_self_collision", action="store_true")
args = parser.parse_args()

Expand All @@ -128,13 +128,9 @@ def test_perf(env, timeout=10):
sps = int(steps / (end - start))
print(f"Steps: {steps}, SPS: {sps}")

cfg = {
"env": {
"num_envs": args.num_envs,
"motion_file": args.motion_file,
},
"robot": {"has_self_collision": not args.disable_self_collision},
}

env = PHCPufferEnv(cfg)
env = PHCPufferEnv(
motion_file=args.motion_file,
has_self_collision=not args.disable_self_collision,
num_envs=args.num_envs,
)
test_perf(env)
40 changes: 19 additions & 21 deletions pufferlib/environments/morph/humanoid_phc.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import os
import sys
from enum import Enum
from types import SimpleNamespace

from isaacgym import gymapi
import gymtorch

from gym import spaces
import torch
import numpy as np
from easydict import EasyDict

from smpl_sim.smpllib.smpl_joint_names import SMPL_MUJOCO_NAMES

from phc import PHC_ROOT
from phc.pufferl.poselib_skeleton import SkeletonTree
from phc.pufferl.motion_lib import MotionLibSMPL, FixHeightMode
from phc.pufferl.torch_utils import (
from pufferlib.environments.morph.poselib_skeleton import SkeletonTree
from pufferlib.environments.morph.motion_lib import MotionLibSMPL, FixHeightMode
from pufferlib.environments.morph.torch_utils import (
to_torch,
torch_rand_float,
exp_map_to_quat,
Expand Down Expand Up @@ -380,7 +379,8 @@ def _config_robot(self):
self.humanoid_shapes = torch.tensor(np.array([self.gender_beta] * self.num_envs)).float().to(self.device)

# NOTE: The below SMPL assets must be present.
asset_file_real = str(PHC_ROOT / f"phc/data/assets/mjcf/smpl_{int(self.gender_beta[0])}_humanoid.xml")
# asset_file_real = str(PHC_ROOT / f"phc/data/assets/mjcf/smpl_{int(self.gender_beta[0])}_humanoid.xml")
asset_file_real = "resources/smpl_humanoid.xml"
assert os.path.exists(asset_file_real)

sk_tree = SkeletonTree.from_mjcf(asset_file_real)
Expand All @@ -391,7 +391,7 @@ def _config_robot(self):
asset_options.max_angular_velocity = 100.0
asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE

self.humanoid_asset = self.gym.load_asset(self.sim, "/", asset_file_real, asset_options)
self.humanoid_asset = self.gym.load_asset(self.sim, ".", asset_file_real, asset_options)
self.num_bodies = self.gym.get_asset_rigid_body_count(self.humanoid_asset)
self.num_dof = self.gym.get_asset_dof_count(self.humanoid_asset)

Expand Down Expand Up @@ -897,20 +897,18 @@ def _setup_env_buffers(self):
self.ref_dof_pos = torch.zeros_like(self._dof_pos)

def _load_motion(self, motion_train_file, motion_test_file=None):
motion_lib_cfg = EasyDict(
{
"motion_file": motion_train_file,
"device": self.device,
"fix_height": FixHeightMode.full_fix,
"min_length": self._min_motion_len,
"max_length": self._max_motion_len,
"im_eval": self.flag_im_eval,
"multi_thread": False, # CHECK ME: need to config?
"smpl_type": self.humanoid_type,
"randomrize_heading": True,
"step_dt": self.dt,
"is_deterministic": self.flag_debug,
}
motion_lib_cfg = SimpleNamespace(
motion_file=motion_train_file,
device=self.device,
fix_height=FixHeightMode.full_fix,
min_length=self._min_motion_len,
max_length=self._max_motion_len,
im_eval=self.flag_im_eval,
multi_thread=False, # CHECK ME: need to config?
smpl_type=self.humanoid_type,
randomrize_heading=True,
step_dt=self.dt,
is_deterministic=self.flag_debug,
)
self._motion_train_lib = MotionLibSMPL(motion_lib_cfg)

Expand Down
2 changes: 1 addition & 1 deletion pufferlib/environments/morph/motion_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __getattr__(self, string):
class MotionLibBase:
def __init__(self, motion_lib_cfg):
self.m_cfg = motion_lib_cfg
self._sim_fps = 1 / self.m_cfg.get("step_dt", 1 / 30) # CHECK ME: hardcoded
self._sim_fps = 1 / getattr(self.m_cfg, "step_dt", 1 / 30) # CHECK ME: hardcoded
print("SIM FPS (from MotionLibBase):", self._sim_fps)
self._device = self.m_cfg.device

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1):
super().__init__(env, policy, input_size, hidden_size, num_layers)

class Policy(nn.Module):
def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
def __init__(self, env, input_dim=934, action_dim=69, demo_dim=358, hidden=512):
super().__init__()
self.is_continuous = True

Expand Down Expand Up @@ -50,10 +50,10 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
nn.SiLU(),
layer_init(nn.Linear(1024, 512)),
nn.SiLU(),
layer_init(nn.Linear(512, action_dim)),
layer_init(nn.Linear(512, hidden)),
nn.SiLU(),
layer_init(nn.Linear(hidden, 1)),
)
self.value = nn.Linear(hidden, 1)

### Discriminator
self._disc_mlp = nn.Sequential(
Expand All @@ -66,7 +66,8 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):

def forward(self, observations):
hidden, lookup = self.encode_observations(observations)
actions, value = self.decode_actions(hidden, lookup)
actions, _ = self.decode_actions(hidden, lookup)
value = self.critic_mlp(observations)
return actions, value

def encode_observations(self, obs):
Expand All @@ -76,8 +77,9 @@ def decode_actions(self, hidden, lookup=None):
mu = self.mu(hidden)
std = torch.exp(self.sigma).expand_as(mu)
probs = torch.distributions.Normal(mu, std)
value = self.value(hidden)
return probs, value

# value = self.value(hidden)
return probs, 0 # NOTE: value comes from the separate critic network

def discriminate(self, amp_obs):
disc_mlp_out = self._disc_mlp(amp_obs)
Expand All @@ -95,5 +97,3 @@ def disc_weights(self):

weights.append(torch.flatten(self._disc_logits.weight))
return weights


0 comments on commit 256c711

Please sign in to comment.