Skip to content

Commit

Permalink
Isaac binding physhoi
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Suarez committed Jan 16, 2025
1 parent 40827e0 commit 5e7518b
Show file tree
Hide file tree
Showing 11 changed files with 376 additions and 10 deletions.
34 changes: 34 additions & 0 deletions config/ase.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[base]
package = ase
env_name = ase
vec = native
rnn_name = Recurrent

[env]
env_cfg_file ="ase/data/cfg/humanoid_ase_sword_shield_getup.yaml"
motion_file = "ase/data/motions/reallusion_sword_shield/RL_Avatar_Atk_Jump_Motion.npy"

[train]
total_timesteps = 200_000_000
num_envs = 2
num_workers = 2
env_batch_size = 1
batch_size = 131072
update_epochs = 1
minibatch_size = 32768
bptt_horizon = 16
anneal_lr = False
gae_lambda = 0.9776227170639571
gamma = 0.8567482546637853
clip_coef = 0.011102333784435113
vf_coef = 0.3403069830175013
vf_clip_coef = 0.26475190539131727
max_grad_norm = 0.8660179376602173
ent_coef = 0.01376980586465873
learning_rate = 0.002064722899262613
checkpoint_interval = 1000
device = cuda

[sweep.metric]
goal = maximize
name = environment/reward
39 changes: 39 additions & 0 deletions config/physhoi.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
[base]
package = physhoi
env_name = physhoi
vec = native
rnn_name = Recurrent

[env]
env_cfg_file ="PhysHOI/physhoi/data/cfg/physhoi.yaml"
motion_file = "PhysHOI/physhoi/data/motions/BallPlay/walkpick.pt"
physx_num_threads = 4
physx_num_subscenes = 0
physx_num_client_threads = 0
num_envs = 2048
use_gpu = True

[train]
total_timesteps = 200_000_000
num_envs = 1
num_workers = 1
env_batch_size = 1
batch_size = 65536
update_epochs = 6
minibatch_size = 16384
bptt_horizon = 16
anneal_lr = False
gae_lambda = 0.95
gamma = 0.99
clip_coef = 0.2
vf_coef = 5
vf_clip_coef = 0.5
max_grad_norm = 1.0
ent_coef = 0.0005
learning_rate = 0.0005
checkpoint_interval = 1000
device = cuda

[sweep.metric]
goal = maximize
name = environment/reward
4 changes: 4 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import ast
import os

import isaacgym # noqa
from isaacgym import gymapi
from isaacgym import gymutil

import pufferlib
import pufferlib.utils
import pufferlib.vector
Expand Down
12 changes: 12 additions & 0 deletions pufferlib/environments/ase/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .environment import env_creator, make

# The policy classes need torch; when torch is not installed the package
# still imports, it just exposes no Policy/Recurrent names.
try:
    import torch
except ImportError:
    pass
else:
    from .torch import Policy
    try:
        from .torch import Recurrent
    except ImportError:
        # The sibling torch module does not have to define a recurrent
        # policy; expose None so callers can test for its absence.
        Recurrent = None
87 changes: 87 additions & 0 deletions pufferlib/environments/ase/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from pdb import set_trace as T

import gymnasium
import functools
import yaml

import isaacgym # noqa
from isaacgym import gymapi
from isaacgym import gymutil

from ase.env.tasks.humanoid_amp_getup import HumanoidAMPGetup
import torch

import pufferlib.emulation
import pufferlib.environments
import pufferlib.postprocess


def env_creator(name='ase'):
    """Return a factory for this module's environment.

    The factory is ``make`` with ``name`` pre-bound as a keyword argument
    through ``functools.partial``; every other argument is supplied when
    the factory is called.
    """
    factory = functools.partial(make, name=name)
    return factory


def make(env_cfg_file, motion_file,
        physx_num_threads=1, physx_num_subscenes=1, physx_num_client_threads=1,
        sim_timestep=1.0 / 60.0, headless=False,
        device_id=0, use_gpu=True, num_envs=1, buf=None, name='ase'):
    """Build the HumanoidAMPGetup task and wrap it in an ASEPufferEnv.

    Args:
        env_cfg_file: Path to the YAML task config. It must contain the
            top-level keys ``env`` and ``sim``.
        motion_file: Stored into ``cfg["env"]["motion_file"]``.
        physx_num_threads: Assigned to ``sim_params.physx.num_threads``.
        physx_num_subscenes: Assigned to ``sim_params.physx.num_subscenes``.
        physx_num_client_threads: Assigned to
            ``sim_params.num_client_threads``.
        sim_timestep: Assigned to ``sim_params.dt``.
        headless: Forwarded to the task constructor.
        device_id: Used to build the ``cuda:<id>`` device string and
            forwarded to the task constructor.
        use_gpu: Enables the GPU pipeline and GPU PhysX; requires CUDA.
        num_envs: Stored into ``cfg["env"]["numEnvs"]``.
        buf: Forwarded to ``ASEPufferEnv``.
        name: Accepted so that the ``functools.partial(make, name=name)``
            built by ``env_creator`` can call this function. Not used
            otherwise.

    Returns:
        The constructed ``ASEPufferEnv``.

    Raises:
        AssertionError: If ``use_gpu`` is set but CUDA is unavailable, or
            if the config lacks an ``env`` or ``sim`` section.
    """
    sim_params = gymapi.SimParams()
    sim_params.dt = sim_timestep
    sim_params.use_gpu_pipeline = use_gpu
    sim_params.physx.use_gpu = use_gpu
    sim_params.physx.max_gpu_contact_pairs = 8 * 1024 * 1024
    sim_params.physx.num_threads = physx_num_threads
    sim_params.physx.num_subscenes = physx_num_subscenes
    sim_params.num_client_threads = physx_num_client_threads
    # NOTE: the config's "sim" section is required below but is not applied
    # to sim_params here (no gymutil.parse_sim_config call).

    rl_device = "cpu"
    if use_gpu:
        assert torch.cuda.is_available(), "CUDA is not available"
        rl_device = "cuda:" + str(device_id)

    with open(env_cfg_file, "r") as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    assert "env" in cfg, "env is not set in the config file"
    assert "sim" in cfg, "sim is not set in the config file"

    # Fill in the env config
    cfg["env"]["numEnvs"] = num_envs
    cfg["env"]["motion_file"] = motion_file

    # Use gpu and physx by default
    # NOTE: Start with training low-level controller, HumanoidAMPGetup
    task = HumanoidAMPGetup(
        cfg=cfg,
        sim_params=sim_params,
        physics_engine=gymapi.SIM_PHYSX,
        device_type=rl_device,
        device_id=device_id,
        headless=headless,
    )

    # Hand the wrapped environment back to the caller; without this return
    # the factory would yield None.
    return ASEPufferEnv(task, buf=buf)

class ASEPufferEnv(pufferlib.PufferEnv):
    """Adapter exposing an already-built ASE task as a ``PufferEnv``.

    Observations, rewards and terminals produced by the wrapped task are
    copied into the buffers this object holds (``self.observations``,
    ``self.rewards``, ``self.terminals``, ``self.truncations``).
    """

    def __init__(self, env, buf=None):
        """Record the task's spaces and agent count, then init the base.

        Args:
            env: The task; must provide ``observation_space``,
                ``action_space`` and ``num_agents``.
            buf: Forwarded to ``PufferEnv.__init__``.
        """
        self.env = env
        self.num_agents = env.num_agents
        self.single_action_space = env.action_space
        self.single_observation_space = env.observation_space
        # Presumably the base initializer reads the attributes above to
        # set up its buffers -- confirm against pufferlib.PufferEnv.
        super().__init__(buf)

    def reset(self, seed=None):
        """Reset the task; return the observation buffer and an empty dict.

        ``seed`` is accepted but not forwarded to the task.
        """
        first_obs, _ = self.env.reset()
        self.observations[:] = first_obs
        return self.observations, {}

    def step(self, actions):
        """Step the task and copy its outputs into the shared buffers.

        Returns:
            ``(observations, rewards, terminals, truncations, info)``,
            where ``info`` is whatever the task returned and truncations
            are always cleared to ``False``.
        """
        next_obs, step_reward, step_done, task_info = self.env.step(actions)
        self.observations[:] = next_obs
        self.rewards[:] = step_reward
        self.terminals[:] = step_done
        self.truncations[:] = False
        return (
            self.observations,
            self.rewards,
            self.terminals,
            self.truncations,
            task_info,
        )
1 change: 1 addition & 0 deletions pufferlib/environments/ase/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pufferlib.models import Default as Policy
12 changes: 12 additions & 0 deletions pufferlib/environments/physhoi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .environment import env_creator, make

# The policy classes need torch; when torch is not installed the package
# still imports, it just exposes no Policy/Recurrent names.
try:
    import torch
except ImportError:
    pass
else:
    from .torch import Policy
    try:
        from .torch import Recurrent
    except ImportError:
        # Fall back to None if the sibling torch module offers no
        # recurrent policy, so callers can test for its absence.
        Recurrent = None
180 changes: 180 additions & 0 deletions pufferlib/environments/physhoi/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from pdb import set_trace as T

import gymnasium as gym
import numpy as np
import functools
import yaml

import isaacgym # noqa
from isaacgym import gymapi
from isaacgym import gymutil

from physhoi.env.tasks.physhoi import PhysHOI_BallPlay
from physhoi.env.tasks.task_wrappers import VecTaskWrapper
import torch

import pufferlib.emulation
import pufferlib.environments
import pufferlib.postprocess


def env_creator(name='ase'):
    """Return a factory for this module's environment.

    The factory is ``make`` with ``name`` pre-bound as a keyword argument
    through ``functools.partial``; every other argument is supplied when
    the factory is called.

    NOTE(review): the default is 'ase' although this is the physhoi
    package -- looks like a copy-paste leftover; confirm before changing.
    """
    factory = functools.partial(make, name=name)
    return factory


def make(name, env_cfg_file, motion_file,
        physx_num_threads=1, physx_num_subscenes=1, physx_num_client_threads=1,
        sim_timestep=1.0 / 60.0, headless=True,
        device_id=0, use_gpu=True, num_envs=32, buf=None):
    """Build the PhysHOI BallPlay task and wrap it in a PhysHOIPufferEnv.

    Args:
        name: Environment name; not used inside this function.
        env_cfg_file: Path to the YAML task config. It must contain the
            top-level keys ``env`` and ``sim``.
        motion_file: Stored into ``cfg["env"]["motion_file"]``.
        physx_num_threads: Assigned to ``sim_params.physx.num_threads``.
        physx_num_subscenes: Assigned to ``sim_params.physx.num_subscenes``.
        physx_num_client_threads: Assigned to
            ``sim_params.num_client_threads``.
        sim_timestep: Assigned to ``sim_params.dt``.
        headless: Forwarded to the task constructor.
        device_id: Used to build the ``cuda:<id>`` device string and
            forwarded to the task constructor.
        use_gpu: Enables the GPU pipeline and GPU PhysX; requires CUDA.
        num_envs: Stored into ``cfg["env"]["numEnvs"]``.
        buf: Forwarded to ``PhysHOIPufferEnv``.

    Returns:
        The constructed ``PhysHOIPufferEnv``.

    Raises:
        AssertionError: If ``use_gpu`` is set but CUDA is unavailable, if
            the config lacks an ``env`` or ``sim`` section, or if the
            action space is not a ``pufferlib.spaces.Box``.
    """
    # Defaults from the arguments; the config's "sim" section is applied
    # on top of these further down.
    sim = gymapi.SimParams()
    sim.dt = sim_timestep
    sim.use_gpu_pipeline = use_gpu
    sim.physx.use_gpu = use_gpu
    sim.physx.max_gpu_contact_pairs = 8 * 1024 * 1024
    sim.physx.num_threads = physx_num_threads
    sim.physx.num_subscenes = physx_num_subscenes
    sim.num_client_threads = physx_num_client_threads

    if use_gpu:
        assert torch.cuda.is_available(), "CUDA is not available"
        rl_device = "cuda:" + str(device_id)
    else:
        rl_device = "cpu"

    with open(env_cfg_file, "r") as cfg_stream:
        config = yaml.safe_load(cfg_stream)

    assert "env" in config, "env is not set in the config file"
    assert "sim" in config, "sim is not set in the config file"

    if "sim" in config:
        gymutil.parse_sim_config(config["sim"], sim)

    # Fill in the env config
    env_section = config["env"]
    env_section["numEnvs"] = num_envs
    env_section["motion_file"] = motion_file

    # Patch paths: prefix the asset root with the PhysHOI directory.
    asset_section = env_section["asset"]
    asset_section["assetRoot"] = 'PhysHOI/' + asset_section["assetRoot"]

    task = PhysHOI_BallPlay(
        cfg=config,
        sim_params=sim,
        physics_engine=gymapi.SIM_PHYSX,
        device_type=rl_device,
        device_id=device_id,
        headless=headless,
    )

    vec_env = VecTaskWrapper(
        task, rl_device, clip_observations=np.inf, clip_actions=1.0
    )
    print("num_envs: {:d}".format(vec_env.num_envs))
    print("num_actions: {:d}".format(vec_env.num_actions))
    print("num_obs: {:d}".format(vec_env.num_obs))
    print("num_states: {:d}".format(vec_env.num_states))

    # Attach the attributes PhysHOIPufferEnv reads from the wrapped env.
    tracked_env = RecordEpisodeStatisticsTorch(vec_env, torch.device(rl_device))
    tracked_env.single_action_space = tracked_env.action_space
    tracked_env.single_observation_space = tracked_env.observation_space
    tracked_env.use_gpu = use_gpu
    assert isinstance(
        tracked_env.single_action_space, pufferlib.spaces.Box
    ), "only continuous action space is supported"

    return PhysHOIPufferEnv(tracked_env, buf=buf)

class RecordEpisodeStatisticsTorch(gym.Wrapper):
    """Wrapper that accumulates per-env episode returns and lengths.

    Running totals are kept in torch tensors on ``device``. When specific
    envs are reset, their totals are appended to ``self.infos`` and then
    zeroed. ``mean_and_log`` averages and clears those finished-episode
    statistics.
    """

    def __init__(self, env, device):
        """Wrap ``env``; the tensors are allocated on the first full reset."""
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.device = device
        self.episode_returns = None
        self.episode_lengths = None
        self.infos = {
            'episode_return': [],
            'episode_length': [],
        }

    def reset(self, env_ids=None):
        """Reset the wrapped env and update the statistics.

        With ``env_ids=None`` all tensors are (re)allocated as zeros.
        Otherwise the totals of the given envs are recorded in
        ``self.infos`` and cleared.

        Returns:
            Whatever the wrapped env's ``reset`` returned.
        """
        obs = self.env.reset(env_ids)

        if env_ids is not None:
            # Partial reset: bank the finished episodes, then clear them.
            self.infos['episode_return'].extend(
                self.episode_returns[env_ids].tolist()
            )
            self.infos['episode_length'].extend(
                self.episode_lengths[env_ids].tolist()
            )
            self.episode_returns[env_ids] = 0
            self.episode_lengths[env_ids] = 0
            return obs

        def fresh(dtype):
            return torch.zeros(self.num_envs, dtype=dtype, device=self.device)

        self.episode_returns = fresh(torch.float32)
        self.episode_lengths = fresh(torch.int32)
        # The returned_* tensors are allocated here but not read anywhere
        # else in this class.
        self.returned_episode_returns = fresh(torch.float32)
        self.returned_episode_lengths = fresh(torch.int32)
        return obs

    def step(self, action):
        """Step the wrapped env and add the rewards to the running totals.

        The info returned by the wrapped env is dropped; ``self.infos``
        is returned in its place.
        """
        observations, rewards, dones, _inner_infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1
        return observations, rewards, dones, self.infos

    def mean_and_log(self):
        """Return ``[{mean return, mean length}]`` and clear the records."""
        mean_return = np.mean(self.infos['episode_return'])
        mean_length = np.mean(self.infos['episode_length'])
        summary = {
            'episode_return': mean_return,
            'episode_length': mean_length,
        }
        self.infos = {
            'episode_return': [],
            'episode_length': [],
        }
        return [summary]

class PhysHOIPufferEnv(pufferlib.PufferEnv):
    """Adapter exposing the wrapped PhysHOI vector env as a ``PufferEnv``.

    After the base initializer runs, the observation/action/reward/
    terminal/truncation buffers are replaced by torch tensors on the
    selected device, so step outputs can be written without leaving it.
    """

    def __init__(self, env, log_interval=128, buf=None):
        """Record env metadata and move the buffers to the env's device.

        Args:
            env: Wrapped env; must provide ``single_observation_space``,
                ``single_action_space``, ``num_envs`` and ``use_gpu``.
            log_interval: Episode statistics are averaged and reported
                once more than this many finished episodes are collected.
            buf: Forwarded to ``PufferEnv.__init__``.
        """
        self.env = env
        self.single_observation_space = env.single_observation_space
        self.single_action_space = env.single_action_space
        self.num_agents = env.num_envs
        self.log_interval = log_interval
        super().__init__(buf)

        # WARNING: ONLY works with native vec. Will break in multiprocessing
        # NOTE(review): the device carries no index, so a non-default
        # device_id chosen by the caller is not reflected here -- confirm.
        self._device = torch.device("cuda" if env.use_gpu else "cpu")
        self.observations = torch.from_numpy(self.observations).to(self._device)
        self.actions = torch.from_numpy(self.actions).to(self._device)
        self.rewards = torch.from_numpy(self.rewards).to(self._device)
        self.terminals = torch.from_numpy(self.terminals).to(self._device)
        self.truncations = torch.from_numpy(self.truncations).to(self._device)

    def reset(self, seed=None):
        """Reset every env; return the observation buffer and an empty list.

        ``seed`` is accepted but not forwarded to the wrapped env.
        """
        obs = self.env.reset()
        self.observations[:] = obs
        return self.observations, []

    def step(self, actions):
        """Step every env, auto-reset finished ones, and report statistics.

        Args:
            actions: NumPy array of actions (converted with
                ``torch.from_numpy``).

        Returns:
            ``(observations, rewards, terminals, truncations, info)``.
            ``info`` is the averaged episode statistics once more than
            ``log_interval`` episodes have finished, otherwise an empty
            list. Truncations are always cleared to ``False``.
        """
        # Use the device chosen in __init__ rather than forcing CUDA, so
        # the CPU configuration (use_gpu=False) works as well.
        actions = torch.from_numpy(actions).to(self._device)
        obs, reward, done, info = self.env.step(actions)
        self.observations[:] = obs
        self.rewards[:] = reward
        self.terminals[:] = done
        self.truncations[:] = False

        # Reset finished envs and overwrite their observations with the
        # post-reset ones.
        done_indices = torch.nonzero(done).squeeze(-1)
        if len(done_indices) > 0:
            self.observations[done_indices] = self.env.reset(done_indices)[done_indices]

        if len(info['episode_return']) > self.log_interval:
            info = self.env.mean_and_log()
        else:
            info = []

        return self.observations, self.rewards, self.terminals, self.truncations, info
2 changes: 2 additions & 0 deletions pufferlib/environments/physhoi/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from pufferlib.models import Default as Policy
from pufferlib.models import LSTMWrapper as Recurrent
Loading

0 comments on commit 5e7518b

Please sign in to comment.