Merge pull request octo-models#122 from rail-berkeley/sim_eval_loop
Add sim rollout visualization option to train.py + finetune.py
kpertsch authored Dec 5, 2023
2 parents f2923d0 + 57910b1 commit 1b4dc3f
Showing 9 changed files with 871 additions and 53 deletions.
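The first file below turns the head type, augmentation, temporal ensembling, window size, and prediction horizon into `get_config` arguments. As a minimal sketch, assuming `get_config` can still be called directly with keyword arguments (the `@wrap` decorator's command-line parsing is not shown in this diff), selecting the L1 head with no augmentation might look like:

```python
# Hedged sketch: assumes get_config accepts direct keyword calls; the @wrap
# decorator's CLI string handling is not part of this diff.
from experiments.lucy.aloha_finetune_config import get_config

config = get_config(
    mode="full",                # "frozen_transformer" also appears in this diff
    head="L1",                  # "mse" or "L1"; other values leave cls_name undefined
    augment="none",             # "full" enables the five-op image augmentation list
    temp_ensembling="uniform",  # "uniform" -> use_temp_averaging=True, "none" -> False
    window_size=1,
    pred_horizon=50,            # additional_action_window_size becomes 50 - 1 = 49
)
assert config.window_size == 1
```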
experiments/lucy/aloha_finetune_config.py: 74 additions & 16 deletions
@@ -5,6 +5,7 @@
from config import wrap

from orca.data.utils.data_utils import ActionEncoding, StateEncoding
+from experiments.lucy.aloha_wrapper import AlohaGymEnv


def update_config(config, **kwargs):
@@ -15,7 +16,7 @@ def update_config(config, **kwargs):


@wrap
-def get_config(mode="full"):
+def get_config(mode="full", head="mse", augment="full", temp_ensembling="uniform", window_size=1, pred_horizon=50):
# If starting with an ORCA-wrist model, there should be two image keys
# first image key should be the third-person view
# and second image key should be the wrist view
@@ -50,11 +51,11 @@ def get_config(mode="full"):
"heads_*.map_head.MultiHeadDotProductAttention_0.*",
)
elif mode == "frozen_transformer":
-        frozen_keys = ("orca_transformer.BlockTransformer_0.*",)
+        frozen_keys = ("orca_transformer.BlockTransformer_0.*", "*hf_model*")
else:
raise ValueError("Invalid mode")

-    max_steps = FieldReference(20000)
+    max_steps = FieldReference(50000)

config = dict(
pretrained_path=placeholder(str),
@@ -64,16 +65,18 @@
num_val_batches=8,
num_steps=max_steps,
log_interval=100,
-        eval_interval=500,
-        save_interval=500,
+        eval_interval=1,  # 5000
+        save_interval=1,  # 5000
save_dir="gs://karl-central-2",
seed=42,
+        debug_sim=False,
wandb=dict(
project="orca_finetune", group=placeholder(str), entity=placeholder(str)
),
finetuning_dataset=FINETUNING_KWARGS,
modality=None,
finetuning_mode=mode,
+        window_size=int(window_size),
optimizer=dict(
learning_rate=dict(
init_value=0.0,
@@ -88,28 +91,48 @@
),
)

+    goal_relabeling_strategy = "no_image_conditioning"
+    delete_key_groups_probs = [
+        (["image_.*", "proprio"], 1.0),
+    ]

+    if augment == "full":
+        augment_order = [
+            "random_resized_crop",
+            "random_brightness",
+            "random_contrast",
+            "random_saturation",
+            "random_hue",
+        ]
+    elif augment == "none":
+        augment_order = []

transform_kwargs = dict(
-        window_size=1,
-        additional_action_window_size=49,
+        window_size=int(window_size),
+        additional_action_window_size=int(pred_horizon) - 1,
resize_size=(256, 256),
image_augment_kwargs=dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.2],
random_contrast=[0.8, 1.2],
random_saturation=[0.8, 1.2],
random_hue=[0.1],
-            augment_order=[
-                "random_resized_crop",
-                "random_brightness",
-                "random_contrast",
-                "random_saturation",
-                "random_hue",
-            ],
+            augment_order=augment_order,
+            # "random_resized_crop",
+            # "random_brightness",
+            # "random_contrast",
+            # "random_saturation",
+            # "random_hue",
+            # ],
),
goal_relabeling_strategy="uniform",
action_encoding=ActionEncoding.JOINT_POS_BIMANUAL,
# If the default data loading speed is too slow, try these:
num_parallel_calls=16, # for the most CPU-intensive ops (decoding, resizing, augmenting)
+        task_augmentation_strategy="delete_task_conditioning",
+        task_augmentation_kwargs=dict(
+            delete_key_groups_probs=delete_key_groups_probs,
+        ),
)
config["data_transforms"] = transform_kwargs

@@ -121,13 +144,18 @@
)
)

+    if head == "mse":
+        cls_name = "mse_action_head"
+    elif head == "L1":
+        cls_name = "l1_action_head"

config["update_config"] = dict(
model=dict(
heads=dict(
action=dict(
-                    cls_name="mse_action_head",
+                    cls_name=cls_name,
kwargs=dict(
-                        pred_horizon=50,
+                        pred_horizon=int(pred_horizon),
action_dim=14,
vocab_size=256,
normalization_type="normal",
@@ -149,4 +177,34 @@
},
)
)

+    if temp_ensembling == "uniform":
+        use_temp_averaging = True
+    elif temp_ensembling == "none":
+        use_temp_averaging = False
+
+    config["rollout_envs"] = [
+        (
+            "aloha-sim-cube-v0",
+            dict(
+                max_episode_length=400,
+                action_chunk=int(pred_horizon),
+                vis_fps=25,
+                video_subsample_rate=2,
+                norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json",
+                use_temp_averaging=use_temp_averaging,
+            ),
+        ),
+        (
+            "aloha-sim-cube-v0",
+            dict(
+                max_episode_length=400,
+                action_chunk=int(pred_horizon) // 2,
+                vis_fps=25,
+                video_subsample_rate=2,
+                norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json",
+                use_temp_averaging=use_temp_averaging,
+            ),
+        ),
+    ]
return ConfigDict(config)
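Each `rollout_envs` entry pairs a registered gym environment id with rollout settings; the two entries above differ only in `action_chunk` (the full `pred_horizon` versus half of it), so the sim evaluation compares two re-planning frequencies. An illustrative consumer loop follows; it is a sketch of how train.py/finetune.py might use these tuples, with `policy_fn` and the handling of `vis_fps`, `video_subsample_rate`, and `norm_statistics` left as assumptions:

```python
# Illustrative sketch only; the real rollout code lives in train.py/finetune.py,
# which this excerpt does not show. policy_fn is a hypothetical stand-in that
# maps (obs, task) to an action array of shape (pred_horizon, action_dim).
import gym
import numpy as np

def run_sim_rollouts(rollout_envs, policy_fn):
    all_metrics = []
    for env_name, kwargs in rollout_envs:
        env = gym.make(env_name)
        obs, info = env.reset()
        steps = 0
        while steps < kwargs["max_episode_length"]:
            chunk = policy_fn(obs, env.get_task())
            # execute only the first `action_chunk` actions, then re-plan
            for action in chunk[: kwargs["action_chunk"]]:
                obs, reward, terminated, truncated, info = env.step(np.asarray(action))
                steps += 1
        all_metrics.append(env.get_episode_metrics())
    return all_metrics
```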
experiments/lucy/aloha_scratch_config.py: 27 additions & 1 deletion
@@ -1,7 +1,9 @@
from config import get_config as get_base_config
from config import update_config, wrap
+from functools import partial

from orca.data.utils.data_utils import StateEncoding, ActionEncoding
+from experiments.lucy.aloha_wrapper import AlohaGymEnv


def get_config(config_string=None):
@@ -13,6 +15,8 @@ def get_config(config_string=None):
batch_size=128,
eval_interval=500,
save_interval=500,
+        trajs_for_rollouts=10,
+        shuffle_buffer_size=50000,
model={
"observation_tokenizers": {
"image": {
@@ -68,7 +72,29 @@ def get_config(config_string=None):
additional_action_window_size=49,
action_encoding=ActionEncoding.JOINT_POS_BIMANUAL,
)
-        )
+        ),
+        rollout_envs=[
+            (
+                "aloha-sim-cube-v0",
+                dict(
+                    max_episode_length=200,
+                    action_chunk=50,
+                    vis_fps=25,
+                    video_subsample_rate=2,
+                    norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json",
+                ),
+            ),
+            (
+                "aloha-sim-cube-v0",
+                dict(
+                    max_episode_length=200,
+                    action_chunk=30,
+                    vis_fps=25,
+                    video_subsample_rate=2,
+                    norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json",
+                ),
+            ),
+        ],
)

return config
experiments/lucy/aloha_wrapper.py: 115 additions & 0 deletions (new file)
@@ -0,0 +1,115 @@
from typing import List
import copy
import cv2
from einops import rearrange
import gym
import jax
import jax.numpy as jnp
import numpy as np
import dlimp as dl

from experiments.lucy.aloha_pro.aloha_scripts.utils import crop_resize
from experiments.lucy.aloha_pro.aloha_scripts.sim_env import BOX_POSE, sample_box_pose
from experiments.lucy.aloha_pro.aloha_scripts.sim_env import make_sim_env


class AlohaGymEnv(gym.Env):
def __init__(
self,
env: gym.Env,
camera_names: List[str],
im_size: int = 256,
seed: int = 1234):
self._env = env
self.observation_space = gym.spaces.Dict(
{
**{
f"image_{i}": gym.spaces.Box(
low=np.zeros((im_size, im_size, 3)),
high=255 * np.ones((im_size, im_size, 3)),
dtype=np.uint8,
) for i in range(len(camera_names))
},
"proprio": gym.spaces.Box(
low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32
),
}
)
self.action_space = gym.spaces.Box(
low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32
)
self.camera_names = camera_names
self._im_size = im_size
self._rng = np.random.default_rng(seed)

def step(self, action):
ts = self._env.step(action)
obs, images = self.get_obs(ts)
reward = ts.reward
info = {"images": images}

if reward == self._env.task.max_reward:
self._episode_is_success = 1

return obs, reward, False, False, info

def reset(self, **kwargs):
x_range = [0.0, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
cube_position = self._rng.uniform(ranges[:, 0], ranges[:, 1])
cube_quat = np.array([1, 0, 0, 0])
BOX_POSE[0] = np.concatenate([cube_position, cube_quat])

ts = self._env.reset(**kwargs)
obs, images = self.get_obs(ts)
info = {"images": images}
self._goal_obs = copy.deepcopy(obs) # HACK

self._episode_is_success = 0

return obs, info

def get_obs(self, ts):
curr_obs = {}
vis_images = []

for i, cam_name in enumerate(self.camera_names):
curr_image = ts.observation['images'][cam_name]

# Check for 'cam_high' and crop
if cam_name == 'cam_high':
curr_image = crop_resize(curr_image)

#curr_image = cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB)
vis_images.append(copy.deepcopy(curr_image))
curr_image = jnp.array(curr_image) # XXX: / 255. ?
curr_obs[f"image_{i}"] = curr_image
curr_obs = dl.transforms.resize_images(
curr_obs, match=curr_obs.keys(), size=(self._im_size, self._im_size))

qpos_numpy = np.array(ts.observation['qpos'])
qpos = jnp.array(qpos_numpy)
curr_obs['proprio'] = qpos

return curr_obs, np.concatenate(vis_images, axis=-2)

def get_task(self):
assert self._goal_obs, "Need to run reset() before!"
return {
"language_instruction": ["pick up the cube and hand it over".encode()],
**jax.tree_map(lambda x: x[None] * 0, self._goal_obs),
}

def get_episode_metrics(self):
return {
"success_rate": self._episode_is_success,
}


# register gym environments
gym.register(
'aloha-sim-cube-v0',
entry_point=lambda: AlohaGymEnv(make_sim_env("sim_transfer_cube"), camera_names=["top"]),
)
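`gym.register` runs at import time, so importing this module is enough to make `aloha-sim-cube-v0` available; that is presumably why the config files import `AlohaGymEnv` without otherwise referencing it. A smoke-test sketch, assuming the `aloha_pro` sim dependencies and MuJoCo assets are installed locally:

```python
# Smoke-test sketch; assumes experiments.lucy.aloha_pro and its sim assets
# are importable. Importing the wrapper module registers the environment.
import gym
import experiments.lucy.aloha_wrapper  # noqa: F401 (side effect: gym.register)

env = gym.make("aloha-sim-cube-v0")
obs, info = env.reset()      # samples a cube pose and writes it into BOX_POSE
task = env.get_task()        # zeroed goal observation plus a language instruction
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs["proprio"].shape)       # (14,)
print(env.get_episode_metrics())  # {"success_rate": 0 or 1}
```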