Commit

Integrate Kyoung's fixes
Joseph Suarez committed Feb 10, 2025
1 parent 91e709a commit bd8120a
Showing 4 changed files with 551 additions and 13 deletions.
15 changes: 8 additions & 7 deletions clean_pufferl.py
@@ -163,7 +163,6 @@ def train(data):
         rewards_np = experience.rewards_np[idxs]
         experience.flatten_batch()
 
-
     # Optimizing the policy and value network
     total_minibatches = experience.num_minibatches * config.update_epochs
     mean_pg_loss, mean_v_loss, mean_entropy_loss = 0, 0, 0
@@ -175,12 +174,14 @@ def train(data):
         config.minibatch_size, experience.state.shape[-1])
     adversarial_reward = torch.zeros(
         experience.num_minibatches, config.minibatch_size).to(config.device)
+    '''
     with torch.no_grad():
         for mb in range(experience.num_minibatches):
             disc_logits = data.policy.policy.discriminate(state[mb]).squeeze()
             prob = 1 / (1 + torch.exp(-disc_logits))
             adversarial_reward[mb] = -torch.log(torch.maximum(
                 1 - prob, torch.tensor(0.0001, device=config.device)))
+    '''
 
     # TODO: Nans in adversarial reward and gae
     adversarial_reward_np = adversarial_reward.cpu().numpy().ravel()
@@ -222,10 +223,10 @@ def train(data):
     # Discriminator loss
     # BUG: Data shape is wrong for morph. State should have same shape as demo
     disc_state = data.policy.policy.discriminate(state)
-    disc_demo = data.policy.policy.discriminate(demo)
-    disc_loss_agent = torch.nn.BCEWithLogitsLoss()(disc_state, torch.zeros_like(disc_state))
-    disc_loss_demo = torch.nn.BCEWithLogitsLoss()(disc_demo, torch.ones_like(disc_demo))
-    disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)
+    #disc_demo = data.policy.policy.discriminate(demo)
+    #disc_loss_agent = torch.nn.BCEWithLogitsLoss()(disc_state, torch.zeros_like(disc_state))
+    #disc_loss_demo = torch.nn.BCEWithLogitsLoss()(disc_demo, torch.ones_like(disc_demo))
+    #disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)
 
     if config.device == 'cuda':
         torch.cuda.synchronize()
@@ -267,7 +268,7 @@ def train(data):
                 v_loss = 0.5 * ((newvalue - ret) ** 2).mean()
 
                 entropy_loss = entropy.mean()
-                loss = pg_loss - config.ent_coef*entropy_loss + config.vf_coef*v_loss + config.disc_coef*disc_loss
+                loss = pg_loss - config.ent_coef*entropy_loss + config.vf_coef*v_loss# + config.disc_coef*disc_loss
 
             with profile.learn:
                 data.optimizer.zero_grad()
@@ -284,7 +285,7 @@ def train(data):
                 losses.old_approx_kl += old_approx_kl.item() / total_minibatches
                 losses.approx_kl += approx_kl.item() / total_minibatches
                 losses.clipfrac += clipfrac.item() / total_minibatches
-                losses.discriminator += disc_loss.item() / total_minibatches
+                #losses.discriminator += disc_loss.item() / total_minibatches
 
         if config.target_kl is not None:
             if approx_kl > config.target_kl:
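The code commented out above is an AMP-style objective: a no-grad adversarial reward r = -log(max(1 - sigmoid(D(s)), 1e-4)) per minibatch, plus a discriminator trained with binary cross-entropy to push agent states toward label 0 and demo states toward label 1. Below is a minimal standalone sketch of those two pieces, using a stand-in discriminator network and arbitrary shapes rather than the repository's policy class:

```python
import torch
import torch.nn as nn

# Stand-in discriminator: any module mapping state features to one logit per state.
# The feature size (358) and the batch shapes below are arbitrary, for illustration only.
discriminate = nn.Sequential(nn.Linear(358, 256), nn.ReLU(), nn.Linear(256, 1))

num_minibatches, minibatch_size, feat = 4, 32, 358
state = torch.randn(num_minibatches, minibatch_size, feat)   # agent rollout states
demo = torch.randn(num_minibatches * minibatch_size, feat)   # reference-motion states

# Adversarial reward: large when the discriminator believes an agent state is a demo state.
adversarial_reward = torch.zeros(num_minibatches, minibatch_size)
with torch.no_grad():
    for mb in range(num_minibatches):
        disc_logits = discriminate(state[mb]).squeeze(-1)
        prob = torch.sigmoid(disc_logits)                    # 1 / (1 + exp(-logits))
        adversarial_reward[mb] = -torch.log(torch.clamp(1 - prob, min=1e-4))

# Discriminator loss: agent states labeled 0, demo states labeled 1.
disc_state = discriminate(state.reshape(-1, feat))
disc_demo = discriminate(demo)
bce = nn.BCEWithLogitsLoss()
disc_loss_agent = bce(disc_state, torch.zeros_like(disc_state))
disc_loss_demo = bce(disc_demo, torch.ones_like(disc_demo))
disc_loss = 0.5 * (disc_loss_agent + disc_loss_demo)
```

torch.clamp(1 - prob, min=1e-4) plays the same role as the torch.maximum call in the diff: it keeps the log finite when the discriminator saturates.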
4 changes: 3 additions & 1 deletion config/morph.ini
@@ -1,6 +1,6 @@
 [base]
 package = morph
-env_name = morph
+env_name = morph morph-render
 vec = native
 policy_name = Policy
 # rnn_name = Recurrent
@@ -15,6 +15,8 @@ hidden = 512
 motion_file = "resources/morph/amass_train_take6_upright.pkl"
 has_self_collision = True
 num_envs = 2048
+#num_envs = 32
+#headless = False
 
 [train]
 seed = 1
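The config change registers a second environment name, morph-render, next to morph, and the commented num_envs / headless lines are a convenient local-debugging override. A rough sketch of how a section like this can be read into keyword arguments; the [env] section name and the literal_eval-based type conversion are assumptions for illustration, not the repository's actual loader:

```python
import ast
import configparser

parser = configparser.ConfigParser()
parser.read('config/morph.ini')

def parse(value):
    # '2048' -> 2048, 'True' -> True, '"resources/..."' -> str; otherwise keep the raw string.
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        return value

env_names = parser['base']['env_name'].split()               # ['morph', 'morph-render']
env_kwargs = {key: parse(val) for key, val in parser['env'].items()}
print(env_names, env_kwargs)
```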
27 changes: 22 additions & 5 deletions pufferlib/environments/morph/environment.py
@@ -3,6 +3,7 @@
 import functools
 
 from pufferlib.environments.morph.humanoid_phc import HumanoidPHC
+from pufferlib.environments.morph.render_env import HumanoidRenderEnv
 
 import torch
 import numpy as np
@@ -13,11 +14,12 @@ def env_creator(name='morph'):
     return functools.partial(make, name)
 
 def make(name, **kwargs):
-    return PHCPufferEnv(**kwargs)
+    return PHCPufferEnv(name, **kwargs)
 
 class PHCPufferEnv(pufferlib.PufferEnv):
-    def __init__(self, motion_file, has_self_collision, num_envs=32, device_type="cuda",
-            device_id=0, headless=True, log_interval=32):
+    def __init__(self, name, motion_file, has_self_collision, num_envs=32, device_type="cuda",
+            exp_name='morph', clip_actions=True, device_id=0, headless=True, log_interval=32):
+        self.render_mode = 'native'
         cfg = {
             'env': {
                 'num_envs': num_envs,
@@ -26,18 +28,26 @@ def __init__(self, motion_file, has_self_collision, num_envs=32, device_type="cu
             'robot': {
                 'has_self_collision': has_self_collision,
             },
+            'exp_name': exp_name,
         }
-        self.env = HumanoidPHC(cfg, device_type=device_type, device_id=device_id, headless=headless)
+        if name == 'morph':
+            self.env = HumanoidPHC(cfg, device_type=device_type, device_id=device_id, headless=headless)
+        elif name == 'morph-render':
+            self.env = HumanoidRenderEnv(cfg, device_type=device_type, device_id=device_id, headless=headless)
+        else:
+            raise ValueError(f'Unknown environment {name}')
 
         self.single_observation_space = self.env.single_observation_space
         self.single_action_space = self.env.single_action_space
         self.num_agents = self.num_envs = self.env.num_envs
+        self.clip_actions = clip_actions
         self.device = self.env.device
 
+        # Check the buffer data types, match them to puffer
         buffers = pufferlib.namespace(
             observations=self.env.obs_buf,
             rewards=self.env.rew_buf,
-            terminals=self.env.reset_buf,
+            terminals=torch.zeros(self.num_agents, dtype=torch.bool, device=self.device),
             truncations=torch.zeros_like(self.env.reset_buf),
             masks=torch.ones_like(self.env.reset_buf),
             actions=torch.zeros(
@@ -63,13 +73,17 @@ def reset(self, seed=None):
         return self.observations, []
 
     def step(self, actions_np):
+        if self.clip_actions:
+            actions_np = np.clip(actions_np, -1, 1)
+
         self.actions[:] = torch.from_numpy(actions_np)
 
         # obs, reward, done are put into the buffers
         self.env.step(self.actions)
         self.demo = self.env.demo
         self.state = self.env.state
 
+        self.terminals[:] = self.env.reset_buf
         done_indices = torch.nonzero(self.terminals).squeeze(-1)
         if len(done_indices) > 0:
             self.observations[done_indices] = self.env.reset(done_indices)[done_indices]
@@ -89,6 +103,9 @@ def step(self, actions_np):
 
         return self.observations, self.rewards, self.terminals, self.truncations, info
 
+    def render(self):
+        return self.env.render()
+
     def close(self):
         self.env.close()
 
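A rough usage sketch of the updated wrapper. The make() signature, action clipping, render(), and close() calls mirror the diff above; the Box-style action space shape, the random-action loop, and the assumption that the motion file from config/morph.ini exists locally are illustrative only:

```python
import numpy as np
from pufferlib.environments.morph.environment import make

# 'morph' is the headless training path; 'morph-render' selects the new HumanoidRenderEnv.
env = make(
    'morph-render',
    motion_file='resources/morph/amass_train_take6_upright.pkl',
    has_self_collision=True,
    num_envs=32,
    headless=False,
)

obs, _ = env.reset()
for _ in range(10):
    # Random actions for illustration; the wrapper clips to [-1, 1] since clip_actions defaults to True.
    actions = np.random.randn(env.num_agents, env.single_action_space.shape[0]).astype(np.float32)
    obs, rewards, terminals, truncations, info = env.step(actions)
    env.render()
env.close()
```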