Learning policy
Joseph Suarez committed Feb 11, 2025
1 parent df7d9e8 commit 0556518
Showing 6 changed files with 110 additions and 55 deletions.
84 changes: 66 additions & 18 deletions clean_pufferl.py
@@ -119,7 +119,10 @@ def evaluate(data):
actions = actions.cpu().numpy()
mask = torch.as_tensor(mask)# * policy.mask)
o = o if config.cpu_offload else o_device
experience.store(o, value, actions, logprob, r, d, env_id, mask)

state = data.vecenv.state
demo = data.vecenv.demo
experience.store(o, state, demo, value, actions, logprob, r, d, env_id, mask)

for i in info:
for k, v in pufferlib.utils.unroll_nested_dict(i):
@@ -161,10 +164,37 @@ def train(data):
dones_np = experience.dones_np[idxs]
values_np = experience.values_np[idxs]
rewards_np = experience.rewards_np[idxs]
# TODO: bootstrap between segment bounds
advantages_np = compute_gae(dones_np, values_np,
rewards_np, config.gamma, config.gae_lambda)
experience.flatten_batch(advantages_np)
experience.flatten_batch()

# Compute adversarial reward. Note: discriminator doesn't get
# updated as often this way, but GAE is more accurate
state = experience.state.view(experience.num_minibatches,
config.minibatch_size, experience.state.shape[-1])
adversarial_reward = torch.zeros(
experience.num_minibatches, config.minibatch_size).to(config.device)

'''
with torch.no_grad():
for mb in range(experience.num_minibatches):
disc_logits = data.policy.policy.discriminate(state[mb]).squeeze()
prob = 1 / (1 + torch.exp(-disc_logits))
adversarial_reward[mb] = -torch.log(torch.maximum(
1 - prob, torch.tensor(0.0001, device=config.device)))
'''

# TODO: Nans in adversarial reward and gae
adversarial_reward_np = adversarial_reward.cpu().numpy().ravel()
advantages_np = compute_gae(dones_np, values_np,
rewards_np + adversarial_reward_np, config.gamma, config.gae_lambda)
advantages = torch.as_tensor(advantages_np).to(config.device)
experience.b_advantages = advantages.reshape(experience.minibatch_rows,
experience.num_minibatches, experience.bptt_horizon).transpose(0, 1).reshape(
experience.num_minibatches, experience.minibatch_size)
experience.returns_np = advantages_np + experience.values_np
experience.b_returns = experience.b_advantages + experience.b_values

# DO NOT CLAMP ACTIONS HERE. Crashes learning.
#experience.b_actions = torch.clamp(experience.b_actions, -1, 1)

# Optimizing the policy and value network
total_minibatches = experience.num_minibatches * config.update_epochs
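
The commented-out block above is an AMP-style adversarial reward: the policy is paid when the discriminator mistakes its states for reference motion, r = -log(max(1 - sigmoid(D(s)), eps)), with the small floor guarding the log against saturation (and the NaNs the TODO mentions). A minimal standalone sketch of that formula, assuming discriminate() returns one logit per state row:

import torch

def adversarial_reward(disc_logits, eps=1e-4):
    # sigmoid(logit) = discriminator's probability that the state is demo-like
    prob = torch.sigmoid(disc_logits)
    # Reward grows as the discriminator is fooled; the floor avoids log(0)
    return -torch.log(torch.clamp(1.0 - prob, min=eps))

logits = torch.randn(16)                  # toy batch of discriminator outputs
print(adversarial_reward(logits).shape)   # torch.Size([16])
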
Expand All @@ -174,8 +204,9 @@ def train(data):
lstm_state = None
for mb in range(experience.num_minibatches):
with profile.train_misc:
obs = experience.b_obs[mb]
obs = obs.to(config.device)
obs = experience.b_obs[mb].to(config.device)
state = experience.b_state[mb].to(config.device)
demo = experience.b_demo[mb].to(config.device)
atn = experience.b_actions[mb]
log_probs = experience.b_logprobs[mb]
val = experience.b_values[mb]
@@ -235,8 +266,15 @@ def train(data):
else:
v_loss = 0.5 * ((newvalue - ret) ** 2).mean()

# Discriminator loss
#disc_state = data.policy.policy.discriminate(state)
#disc_demo = data.policy.policy.discriminate(demo)
#disc_loss_agent = torch.nn.BCEWithLogitsLoss()(disc_state, torch.zeros_like(disc_state))
#disc_loss_demo = torch.nn.BCEWithLogitsLoss()(disc_demo, torch.ones_like(disc_demo))
#disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)

entropy_loss = entropy.mean()
loss = pg_loss - config.ent_coef * entropy_loss + v_loss * config.vf_coef
loss = pg_loss - config.ent_coef * entropy_loss + v_loss * config.vf_coef #+ disc_loss * config.disc_coef

with profile.learn:
data.optimizer.zero_grad()
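
The commented discriminator loss above is the matching update: policy-generated states are labeled 0, demo states 1, and both pass through a binary cross-entropy on the raw logits. A self-contained sketch under those assumptions (the Linear head stands in for the policy's discriminate module; 358 matches the stored state width):

import torch

def discriminator_loss(discriminate, agent_state, demo_state):
    bce = torch.nn.BCEWithLogitsLoss()
    logits_agent = discriminate(agent_state)    # rollout states -> label 0
    logits_demo = discriminate(demo_state)      # reference motion -> label 1
    loss_agent = bce(logits_agent, torch.zeros_like(logits_agent))
    loss_demo = bce(logits_demo, torch.ones_like(logits_demo))
    return 0.5 * (loss_agent + loss_demo)

disc = torch.nn.Linear(358, 1)                  # stand-in discriminator head
loss = discriminator_loss(disc, torch.randn(32, 358), torch.randn(32, 358))
loss.backward()
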
@@ -249,6 +287,7 @@ def train(data):
with profile.train_misc:
losses.policy_loss += pg_loss.item() / total_minibatches
losses.value_loss += v_loss.item() / total_minibatches
#losses.discriminator += disc_loss.item() / total_minibatches
losses.entropy += entropy_loss.item() / total_minibatches
losses.old_approx_kl += old_approx_kl.item() / total_minibatches
losses.approx_kl += approx_kl.item() / total_minibatches
@@ -387,6 +426,7 @@ def make_losses():
return pufferlib.namespace(
policy_loss=0,
value_loss=0,
discriminator_loss=0,
entropy=0,
old_approx_kl=0,
approx_kl=0,
@@ -407,6 +447,10 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
obs_device = device if not pin else 'cpu'
self.obs=torch.zeros(batch_size, *obs_shape, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.demo=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.state=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.actions=torch.zeros(batch_size, *atn_shape, dtype=atn_dtype, pin_memory=pin)
self.logprobs=torch.zeros(batch_size, pin_memory=pin)
self.rewards=torch.zeros(batch_size, pin_memory=pin)
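
The new state and demo buffers follow the same pattern as obs: when pinning is enabled they are allocated in host memory so later .to(device, non_blocking=True) copies can overlap compute. A small sketch of that allocation with the hard-coded 358 lifted into a variable (hypothetical demo_dim; it matches demo_dim in config/morph.ini):

import torch

batch_size, demo_dim = 65536, 358
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pin = device == 'cuda'                    # pinned pages only make sense with a GPU
buf_device = 'cpu' if pin else device     # pinned tensors must live in host memory
state = torch.zeros(batch_size, demo_dim, pin_memory=pin, device=buf_device)
demo = torch.zeros(batch_size, demo_dim, pin_memory=pin, device=buf_device)
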
@@ -451,13 +495,15 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
def full(self):
return self.ptr >= self.batch_size

def store(self, obs, value, action, logprob, reward, done, env_id, mask):
def store(self, obs, state, demo, value, action, logprob, reward, done, env_id, mask):
# Mask learner and Ensure indices do not exceed batch size
ptr = self.ptr
indices = torch.where(mask)[0].numpy()[:self.batch_size - ptr]
end = ptr + len(indices)

self.obs[ptr:end] = obs.to(self.obs.device)[indices]
self.state[ptr:end] = state.to(self.state.device)[indices]
self.demo[ptr:end] = demo.to(self.demo.device)[indices]
self.values_np[ptr:end] = value.cpu().numpy()[indices]
self.actions_np[ptr:end] = action[indices]
self.logprobs_np[ptr:end] = logprob.cpu().numpy()[indices]
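
store() only writes the masked-in rows and never past the end of the fixed-size batch; the index slice above enforces both. A toy worked example of the truncation:

import torch

batch_size, ptr = 8, 6
mask = torch.tensor([True, False, True, True, False, True])
indices = torch.where(mask)[0].numpy()[:batch_size - ptr]   # room for 2 more rows
end = ptr + len(indices)
print(indices, end)                                         # [0 2] 8
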
@@ -473,29 +519,31 @@ def sort_training_data(self):
self.b_idxs_obs = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.obs.device).long()
self.b_idxs_state = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.state.device).long()
self.b_idxs_demo = torch.as_tensor(idxs.reshape(
self.minibatch_rows, self.num_minibatches, self.bptt_horizon
).transpose(1,0,-1)).to(self.demo.device).long()
self.b_idxs = self.b_idxs_obs.to(self.device)
self.b_idxs_flat = self.b_idxs.reshape(
self.num_minibatches, self.minibatch_size)
self.sort_keys = []
return idxs

def flatten_batch(self, advantages_np):
advantages = torch.as_tensor(advantages_np).to(self.device)
def flatten_batch(self):
b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
self.b_actions = self.actions.to(self.device, non_blocking=True)
self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
self.b_dones = self.dones.to(self.device, non_blocking=True)
self.b_values = self.values.to(self.device, non_blocking=True)
self.b_advantages = advantages.reshape(self.minibatch_rows,
self.num_minibatches, self.bptt_horizon).transpose(0, 1).reshape(
self.num_minibatches, self.minibatch_size)
self.returns_np = advantages_np + self.values_np
self.b_obs = self.obs[self.b_idxs_obs]
self.b_state = self.state[self.b_idxs_state]
self.b_demo = self.demo[self.b_idxs_demo]
self.b_actions = self.b_actions[b_idxs].contiguous()
self.b_logprobs = self.b_logprobs[b_idxs]
self.b_dones = self.b_dones[b_idxs]
self.b_values = self.b_values[b_flat]
self.b_returns = self.b_advantages + self.b_values
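
Both the b_idxs reshuffle here and the advantage reshape in train() map the flat buffer of batch_size steps onto (num_minibatches, minibatch_size) while keeping each bptt_horizon segment contiguous. A quick shape check with the config/morph.ini values (batch_size 65536, minibatch_size 16384, bptt_horizon 8, hence minibatch_rows 2048 and num_minibatches 4):

import torch

minibatch_rows, num_minibatches, bptt_horizon = 2048, 4, 8
minibatch_size = minibatch_rows * bptt_horizon                    # 16384
adv = torch.randn(minibatch_rows * num_minibatches * bptt_horizon)
b_adv = adv.reshape(minibatch_rows, num_minibatches, bptt_horizon
    ).transpose(0, 1).reshape(num_minibatches, minibatch_size)
print(b_adv.shape)                                                # torch.Size([4, 16384])
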

class Utilization(Thread):
def __init__(self, delay=1, maxlen=20):
@@ -591,7 +639,7 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k

frames = []
tick = 0
while tick <= 2000:
while tick <= 200000:
if tick % 1 == 0:
render = driver.render()
if driver.render_mode == 'ansi':
@@ -771,4 +819,4 @@ def print_dashboard(env_name, utilization, global_step, epoch,
with console.capture() as capture:
console.print(dashboard)

print('\033[0;0H' + capture.get())
print('\033[0;0H' + capture.get())
21 changes: 11 additions & 10 deletions config/morph.ini
@@ -1,19 +1,20 @@
[base]
package = morph
env_name = morph morph-render
env_name = morph
vec = native
policy_name = Policy
# rnn_name = Recurrent

[policy]
; input_dim = 934
; action_dim = 69
; demo_dim = 358
hidden_size = 2048
input_dim = 934
action_dim = 69
demo_dim = 358
hidden = 2048

[env]
motion_file = "resources/morph/totalcapture_acting_poses.pkl"
has_self_collision = True
#has_self_collision = True
has_self_collision = False
num_envs = 2048
#num_envs = 32
#headless = False
@@ -35,16 +36,16 @@ num_workers = 1
num_envs = 1
batch_size = 65536
minibatch_size = 16384
; batch_size = 1024
; minibatch_size = 256
#batch_size = 1024
#minibatch_size = 256

disc_coef = 5.0

update_epochs = 4
bptt_horizon = 8
anneal_lr = False
gae_lambda = 0.99
gamma = 0.95
gae_lambda = 0.95
gamma = 0.99
clip_coef = 0.2
clip_vloss = True
vf_coef = 2.0
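
The gamma and gae_lambda values were swapped back to the conventional pairing: gamma 0.99 discounts future rewards, gae_lambda 0.95 controls the bias/variance trade-off of the advantage estimator. A minimal sketch of how the two enter GAE, under one common convention (pufferlib's compute_gae may bootstrap differently at segment bounds, as the TODO in clean_pufferl.py notes):

import numpy as np

def gae(dones, values, rewards, gamma=0.99, gae_lambda=0.95):
    adv = np.zeros_like(rewards)
    last = 0.0
    for t in reversed(range(len(rewards) - 1)):
        nonterminal = 1.0 - dones[t + 1]
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        last = delta + gamma * gae_lambda * nonterminal * last
        adv[t] = last
    return adv

adv = gae(np.zeros(128, dtype=np.float32),
          np.random.randn(128).astype(np.float32),
          np.random.randn(128).astype(np.float32))
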
24 changes: 13 additions & 11 deletions pufferlib/environments/morph/__init__.py
@@ -1,16 +1,17 @@
from .environment import env_creator

# try:
# import torch
# except ImportError:
# pass
# else:
# from .torch import Policy
# try:
# from .torch import Recurrent
# except:
# Recurrent = None
try:
import torch
except ImportError:
pass
else:
from .torch import Policy
try:
from .torch import Recurrent
except:
Recurrent = None

'''
try:
import pufferlib.environments.morph.policy as torch
except ImportError:
@@ -20,4 +21,5 @@
try:
from .policy import Recurrent
except:
Recurrent = None
Recurrent = None
'''
14 changes: 6 additions & 8 deletions pufferlib/environments/morph/environment.py
@@ -29,12 +29,10 @@ def __init__(self, name, motion_file, has_self_collision, num_envs=32, device_ty
},
'exp_name': exp_name,
}
if name == 'morph':
if headless:
self.env = HumanoidPHC(cfg, device_type=device_type, device_id=device_id, headless=headless)
elif name == 'morph-render':
self.env = HumanoidRenderEnv(cfg, device_type=device_type, device_id=device_id, headless=headless)
else:
raise ValueError(f'Unknown environment {name}')
self.env = HumanoidRenderEnv(cfg, device_type=device_type, device_id=device_id, headless=headless)

self.single_observation_space = self.env.single_observation_space
self.single_action_space = self.env.single_action_space
@@ -66,8 +64,8 @@ def __init__(self, name, motion_file, has_self_collision, num_envs=32, device_ty

def reset(self, seed=None):
self.env.reset()
# self.demo = self.env.demo
# self.state = self.env.state
self.demo = self.env.demo
self.state = self.env.state
self.tick = 0
return self.observations, []

@@ -78,8 +76,8 @@ def step(self, actions_np):

# obs, reward, done are put into the buffers
self.env.step(self.actions)
# self.demo = self.env.demo
# self.state = self.env.state
self.demo = self.env.demo
self.state = self.env.state

self.terminals[:] = self.env.reset_buf
done_indices = torch.nonzero(self.terminals).squeeze(-1)
15 changes: 11 additions & 4 deletions pufferlib/environments/morph/humanoid_phc.py
@@ -257,6 +257,11 @@ def step(self, actions):
self.progress_buf += 1

self._refresh_sim_tensors()

#body_pos = self._rigid_body_pos
#self.rew_buf[:] = body_pos[:, 0, 2]
#self.reset_buf[:] = body_pos[:, 0, 2] < 0.25

self._compute_reward()

# NOTE: Which envs must be reset is computed here, but the envs get reset outside the env
@@ -559,6 +564,8 @@ def _create_envs(self):
dof_prop["driveMode"] = gymapi.DOF_MODE_POS
dof_prop["stiffness"] *= self._kp_scale
dof_prop["damping"] *= self._kd_scale
dof_prop["stiffness"] = 1000
dof_prop["damping"] = 200

# NOTE: (from Joseph) You get a small perf boost (~4%) by putting all the actors in the same env
for i in range(self.num_envs):
@@ -908,7 +915,7 @@ def _load_motion(self, motion_train_file, motion_test_file=None):
# TODO: find a way to evaluate full motion, probably not during training
max_length=self.max_episode_length,
im_eval=self.flag_im_eval,
multi_thread=False, # CHECK ME: need to config?
multi_thread=False,
smpl_type=self.humanoid_type,
randomrize_heading=True,
step_dt=self.dt,
@@ -1217,7 +1224,7 @@ def _compute_observations(self, env_ids=None):

# This is the normalized vector with position, rotation, velocity, and
# angular velocity for the simulated humanoid and the demo data
# self.state, self.demo = self._compute_state_obs(env_ids)
self.state, self.demo = self._compute_state_obs(env_ids)

if self.add_obs_noise and not self.flag_test:
obs = obs + torch.randn_like(obs) * 0.1
@@ -1498,7 +1505,7 @@ def _compute_reward(self):
body_rot = self._rigid_body_rot
body_vel = self._rigid_body_vel
body_ang_vel = self._rigid_body_ang_vel

motion_times = (
self.progress_buf * self.dt + self._motion_start_times + self._motion_start_times_offset
) # reward is computed after physics step, and progress_buf is already updated for next time step.
@@ -1766,7 +1773,7 @@ def remove_base_rot(quat):
return quat_mul(quat, base_rot.repeat(shape, 1))


@torch.jit.script
#@torch.jit.script
def compute_humanoid_observations_smpl_max(
body_pos,
body_rot,
7 changes: 3 additions & 4 deletions pufferlib/environments/morph/render_env.py
@@ -76,7 +76,7 @@ def render(self):
sys.exit()

if self.viewer or self.flag_server_mode:
self._update_camera()
#self._update_camera()
self._update_marker()

# check for keyboard events
@@ -233,10 +233,9 @@ def _create_viewer(self):

def _init_camera(self):
self.gym.refresh_actor_root_state_tensor(self.sim)
self._cam_prev_char_pos = self._humanoid_root_states[0, 0:3].cpu().numpy()
cam_pos = gymapi.Vec3(20.0, 25.0, 3.0)
cam_target = gymapi.Vec3(10.0, 15.0, 0.0)

cam_pos = gymapi.Vec3(self._cam_prev_char_pos[0], self._cam_prev_char_pos[1] - 3.0, 1.0)
cam_target = gymapi.Vec3(self._cam_prev_char_pos[0], self._cam_prev_char_pos[1], 1.0)
if self.viewer:
self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

