From 86e3de6880f8804b576b37c8c7d17930340f2ec9 Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Sun, 3 Dec 2023 14:30:27 -0800 Subject: [PATCH 1/4] implement sim eval loop in train + finetune, adds simple logging support for action chunks --- experiments/lucy/aloha_finetune_config.py | 26 ++ experiments/lucy/aloha_scratch_config.py | 30 +- experiments/lucy/aloha_wrapper.py | 105 +++++++ experiments/lucy/eval.py | 328 ++++++++++++++++++++++ experiments/lucy/eval.sh | 45 +++ finetune.py | 44 ++- orca/utils/gym_wrappers.py | 42 ++- orca/utils/visualization_lib.py | 155 +++++++++- train.py | 45 ++- 9 files changed, 786 insertions(+), 34 deletions(-) create mode 100644 experiments/lucy/aloha_wrapper.py create mode 100644 experiments/lucy/eval.py create mode 100755 experiments/lucy/eval.sh diff --git a/experiments/lucy/aloha_finetune_config.py b/experiments/lucy/aloha_finetune_config.py index cabd1b36..0efaaa08 100644 --- a/experiments/lucy/aloha_finetune_config.py +++ b/experiments/lucy/aloha_finetune_config.py @@ -5,6 +5,7 @@ from config import wrap from orca.data.utils.data_utils import ActionEncoding, StateEncoding +from experiments.lucy.aloha_wrapper import AlohaGymEnv def update_config(config, **kwargs): @@ -149,4 +150,29 @@ def get_config(mode="full"): }, ) ) + + config["rollout_envs"] = [ + ( + "aloha-sim-cube-v0", + dict( + max_episode_length=200, + action_chunk=50, + vis_render_size=(320, 240), + vis_fps=25, + video_subsample_rate=2, + norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + ) + ), + ( + "aloha-sim-cube-v0", + dict( + max_episode_length=200, + action_chunk=30, + vis_render_size=(320, 240), + vis_fps=25, + video_subsample_rate=2, + norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + ) + ) + ] return ConfigDict(config) diff --git a/experiments/lucy/aloha_scratch_config.py b/experiments/lucy/aloha_scratch_config.py index df71cdf3..4a34f2fa 100644 --- a/experiments/lucy/aloha_scratch_config.py +++ b/experiments/lucy/aloha_scratch_config.py @@ -1,7 +1,9 @@ from config import get_config as get_base_config from config import update_config, wrap +from functools import partial from orca.data.utils.data_utils import StateEncoding, ActionEncoding +from experiments.lucy.aloha_wrapper import AlohaGymEnv def get_config(config_string=None): @@ -13,6 +15,8 @@ def get_config(config_string=None): batch_size=128, eval_interval=500, save_interval=500, + trajs_for_rollouts=10, + shuffle_buffer_size=50000, model={ "observation_tokenizers": { "image": { @@ -68,7 +72,31 @@ def get_config(config_string=None): additional_action_window_size=49, action_encoding=ActionEncoding.JOINT_POS_BIMANUAL, ) - ) + ), + rollout_envs=[ + ( + "aloha-sim-cube-v0", + dict( + max_episode_length=200, + action_chunk=50, + vis_render_size=(320, 240), + vis_fps=25, + video_subsample_rate=2, + norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + ) + ), + ( + "aloha-sim-cube-v0", + dict( + max_episode_length=200, + action_chunk=30, + vis_render_size=(320, 240), + vis_fps=25, + video_subsample_rate=2, + 
norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + ) + ) + ], ) return config diff --git a/experiments/lucy/aloha_wrapper.py b/experiments/lucy/aloha_wrapper.py new file mode 100644 index 00000000..4b65209b --- /dev/null +++ b/experiments/lucy/aloha_wrapper.py @@ -0,0 +1,105 @@ +from typing import List +import copy +import cv2 +from einops import rearrange +import gym +import jax +import jax.numpy as jnp +import numpy as np +import dlimp as dl + +from experiments.lucy.aloha_pro.aloha_scripts.utils import crop_resize +from experiments.lucy.aloha_pro.aloha_scripts.sim_env import BOX_POSE, sample_box_pose +from experiments.lucy.aloha_pro.aloha_scripts.sim_env import make_sim_env + + +class AlohaGymEnv(gym.Env): + def __init__( + self, + env: gym.Env, + camera_names: List[str], + im_size: int = 256, + seed: int = 1234): + self._env = env + self.observation_space = gym.spaces.Dict( + { + **{ + f"image_{i}": gym.spaces.Box( + low=np.zeros((im_size, im_size, 3)), + high=255 * np.ones((im_size, im_size, 3)), + dtype=np.uint8, + ) for i in range(len(camera_names)) + }, + "proprio": gym.spaces.Box( + low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32 + ), + } + ) + self.action_space = gym.spaces.Box( + low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32 + ) + self.camera_names = camera_names + self._im_size = im_size + self._rng = np.random.default_rng(seed) + + def step(self, action): + ts = self._env.step(action) + obs, images = self.get_obs(ts) + reward = ts.reward + info = {"images": images} + + return obs, reward, False, False, info + + def reset(self, **kwargs): + x_range = [0.0, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + ranges = np.vstack([x_range, y_range, z_range]) + cube_position = self._rng.uniform(ranges[:, 0], ranges[:, 1]) + cube_quat = np.array([1, 0, 0, 0]) + BOX_POSE[0] = np.concatenate([cube_position, cube_quat]) + + ts = self._env.reset(**kwargs) + obs, images = self.get_obs(ts) + info = {"images": images} + self._goal_obs = obs # HACK + + return obs, info + + def get_obs(self, ts): + curr_obs = {} + vis_images = [] + + for i, cam_name in enumerate(self.camera_names): + curr_image = ts.observation['images'][cam_name] + + # Check for 'cam_high' and crop + if cam_name == 'cam_high': + curr_image = crop_resize(curr_image) + + curr_image = cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB) + vis_images.append(copy.deepcopy(curr_image)) + curr_image = jnp.array(curr_image) # XXX: / 255. ? + curr_obs[f"image_{i}"] = curr_image + curr_obs = dl.transforms.resize_images( + curr_obs, match=curr_obs.keys(), size=(self._im_size, self._im_size)) + + qpos_numpy = np.array(ts.observation['qpos']) + qpos = jnp.array(qpos_numpy) + curr_obs['proprio'] = qpos + + return curr_obs, np.concatenate(vis_images, axis=-2) + + def get_task(self): + assert self._goal_obs, "Need to run reset() before!" 
+ return { + "language_instruction": ["pick up the cube and hand it over".encode()], + **jax.tree_map(lambda x: x[None], self._goal_obs), + } + + +# register gym environments +gym.register( + 'aloha-sim-cube-v0', + entry_point=lambda: AlohaGymEnv(make_sim_env("sim_transfer_cube"), camera_names=["top"]), +) diff --git a/experiments/lucy/eval.py b/experiments/lucy/eval.py new file mode 100644 index 00000000..b9f7e3d2 --- /dev/null +++ b/experiments/lucy/eval.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 + +from datetime import datetime +from functools import partial +import json +import os +from pathlib import Path, PurePath +import time +import pickle + +from absl import app, flags, logging +import click +import cv2 +import flax +import imageio +from PIL import Image +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +from einops import rearrange +import wandb + +# aloha +try: + import sys; sys.path.append(os.path.join(os.getcwd(), 'aloha_pro/aloha_scripts/')) + from aloha_pro.aloha_scripts.real_env import make_real_env + from aloha_pro.aloha_scripts.robot_utils import move_grippers +except: + print("Skipping real env import...") +from aloha_pro.aloha_scripts.constants import DT, PUPPET_GRIPPER_JOINT_OPEN +from aloha_pro.aloha_scripts.visualize_episodes import save_videos +from aloha_wrapper import AlohaGymEnv +from aloha_pro.aloha_scripts.sim_env import make_sim_env, sample_box_pose, sample_insertion_pose, BOX_POSE + +from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio +from orca.utils.pretrained_utils import PretrainedModel + +np.set_printoptions(suppress=True) + +logging.set_verbosity(logging.WARNING) + +FLAGS = flags.FLAGS + +flags.DEFINE_multi_string( + "checkpoint_weights_path", None, "Path to checkpoint", required=True +) +flags.DEFINE_multi_integer("checkpoint_step", None, "Checkpoint step", required=True) +flags.DEFINE_bool("add_jaxrlm_baseline", False, "Also compare to jaxrl_m baseline") + + +flags.DEFINE_string( + "checkpoint_cache_dir", + "/tmp/", + "Where to cache checkpoints downloaded from GCS", +) +flags.DEFINE_string( + "modality", "", "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)" +) + +flags.DEFINE_integer("im_size", None, "Image size", required=True) +flags.DEFINE_string("video_save_path", None, "Path to save video") +flags.DEFINE_integer("num_timesteps", 500, "num timesteps") +flags.DEFINE_bool("blocking", False, "Use the blocking controller") +flags.DEFINE_spaceseplist("goal_eep", [0.3, 0.0, 0.15], "Goal position") +flags.DEFINE_spaceseplist("initial_eep", [0.3, 0.0, 0.15], "Initial position") +flags.DEFINE_integer("horizon", 1, "Observation history length") +flags.DEFINE_integer("pred_horizon", 1, "Length of action sequence from model") +flags.DEFINE_integer("exec_horizon", 1, "Length of action sequence to execute") +flags.DEFINE_bool("deterministic", False, "Whether to sample action deterministically") +flags.DEFINE_float("temperature", 1.0, "Temperature for sampling actions") +flags.DEFINE_string("ip", "localhost", "IP address of the robot") +flags.DEFINE_integer("port", 5556, "Port of the robot") + +# show image flag +flags.DEFINE_bool("show_image", False, "Show image") + +# sim flags +flags.DEFINE_bool("is_sim", False, "Is simulation env") +flags.DEFINE_string("task_name", "sim_transfer_cube_scripted", "Task name") +flags.DEFINE_string("wandb_name", None, "Wandb log name") + +############################################################################## + 
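A minimal sketch (editor's illustration, not part of the patch) of the receding-horizon loop that the --pred_horizon/--exec_horizon flags above configure and that the RHCWrapper added later in this patch implements: the policy predicts pred_horizon actions per query and only the first exec_horizon of them are executed before the policy is queried again. The names policy_fn and env here are assumptions, standing in for any policy that returns a (pred_horizon, action_dim) array and any Gym-style environment:

import numpy as np

def receding_horizon_rollout(env, policy_fn, num_timesteps, exec_horizon):
    # Query the policy once per chunk and execute only the first `exec_horizon` actions.
    obs, _ = env.reset()
    total_reward = 0.0
    for _ in range(max(num_timesteps // exec_horizon, 1)):
        actions = np.asarray(policy_fn(obs, tasks={}))  # (pred_horizon, action_dim)
        assert len(actions) >= exec_horizon
        for action in actions[:exec_horizon]:
            obs, reward, done, trunc, _ = env.step(action)
            total_reward += reward
            if done or trunc:
                return total_reward
    return total_reward
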
+STEP_DURATION_MESSAGE = """ +Bridge data was collected with non-blocking control and a step duration of 0.2s. +However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control. +We also use a step duration of 0.4s to reduce the jerkiness of the policy. +Be sure to change the step duration back to 0.2 if evaluating with non-blocking control. +""" +STEP_DURATION = 0.4 +STICKY_GRIPPER_NUM_STEPS = 1 +WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]] +CAMERA_TOPICS = [{"name": "/blue/image_raw"}] +ENV_PARAMS = { + "camera_topics": CAMERA_TOPICS, + "override_workspace_boundaries": WORKSPACE_BOUNDS, + "move_duration": STEP_DURATION, +} + +############################################################################## + + +def maybe_download_checkpoint_from_gcs(cloud_path, step, save_path): + if not cloud_path.startswith("gs://"): + return cloud_path, step # Actually on the local filesystem + + checkpoint_path = tf.io.gfile.join(cloud_path, f"{step}") + norm_path = tf.io.gfile.join(cloud_path, "dataset_statistics*") + config_path = tf.io.gfile.join(cloud_path, "config.json*") + example_batch_path = tf.io.gfile.join(cloud_path, "example_batch.msgpack*") + + run_name = Path(cloud_path).name + save_path = os.path.join(save_path, run_name) + + target_checkpoint_path = os.path.join(save_path, f"{step}") + if os.path.exists(target_checkpoint_path): + logging.warning( + "Checkpoint already exists at %s, skipping download", target_checkpoint_path + ) + return save_path, step + os.makedirs(save_path, exist_ok=True) + logging.warning("Downloading checkpoint and metadata to %s", save_path) + + os.system(f"sudo gsutil cp -r {checkpoint_path} {save_path}/") + os.system(f"sudo gsutil cp {norm_path} {save_path}/") + os.system(f"sudo gsutil cp {config_path} {save_path}/") + os.system(f"sudo gsutil cp {example_batch_path} {save_path}/") + + return save_path, step + + +def supply_rng(f, rng=jax.random.PRNGKey(0)): + def wrapped(*args, **kwargs): + nonlocal rng + rng, key = jax.random.split(rng) + return f(*args, rng=key, **kwargs) + + return wrapped + + +@partial(jax.jit, static_argnames="argmax") +def sample_actions( + pretrained_model: PretrainedModel, + observations, + tasks, + rng, + argmax=False, + temperature=1.0, +): + + # add batch dim to observations + observations = jax.tree_map(lambda x: x[None], observations) + logging.warning( + "observations: %s", flax.core.pretty_repr(jax.tree_map(jnp.shape, observations)) + ) + logging.warning("tasks: %s", flax.core.pretty_repr(jax.tree_map(jnp.shape, tasks))) + actions = pretrained_model.sample_actions( + observations, + tasks, + rng=rng, + argmax=argmax, + temperature=temperature, + ) + # remove batch dim + return actions[0] + + +def load_checkpoint(weights_path, step): + model = PretrainedModel.load_pretrained(weights_path, step=int(step)) + + policy_fn = supply_rng( + partial( + sample_actions, + model, + # argmax=FLAGS.deterministic, # Python version issue + argmax=True, + temperature=FLAGS.temperature, + ), + ) + return (policy_fn, model) + +def main(_): + assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_step) + # policies is a dict from run_name to policy function + policies = {} + for (checkpoint_weights_path, checkpoint_step,) in zip( + FLAGS.checkpoint_weights_path, + FLAGS.checkpoint_step, + ): + checkpoint_weights_path, checkpoint_step = maybe_download_checkpoint_from_gcs( + checkpoint_weights_path, + checkpoint_step, + 
FLAGS.checkpoint_cache_dir, + ) + assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path + run_name = checkpoint_weights_path.rpartition("/")[2] + policies[f"{run_name}-{checkpoint_step}"] = load_checkpoint( + checkpoint_weights_path, + checkpoint_step, + ) + + # ask for which policy to use + if len(policies) == 1: + policy_idx = 0 + print("Using default policy 0: ", list(policies.keys())[policy_idx]) + else: + print("policies:") + for i, name in enumerate(policies.keys()): + print(f"{i}) {name}") + policy_idx = click.prompt("Select policy", type=int) + + policy_name = list(policies.keys())[policy_idx] + policy_fn, model = policies[policy_name] + model: PretrainedModel # type hinting + + # set up environment + if FLAGS.is_sim: + env = make_sim_env(task_name=FLAGS.task_name) + camera_names = ['top'] + env_max_reward = env.task.max_reward + episode_returns = [] + highest_rewards = [] + else: + env = make_real_env(init_node=True) + from interbotix_xs_modules.arm import InterbotixManipulatorXS + master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_left', init_node=False) + master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_right', init_node=False) + + camera_names = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] + + # load normalization statistics + metadata_path = os.path.join( + checkpoint_weights_path, "dataset_statistics_aloha_sim_cube_scripted_dataset.json" + ) + with open(metadata_path, "r") as f: + norm_statistics = json.load(f) + + # wrap environment for history conditioning, action chunking and action/proprio norm/denorm + env = AlohaGymEnv(env, camera_names) + env = HistoryWrapper(env, FLAGS.horizon) + env = RHCWrapper(env, FLAGS.exec_horizon) + env = UnnormalizeActionProprio(env, norm_statistics, normalization_type="normal") + + query_frequency = FLAGS.exec_horizon # chunk size + max_timesteps = FLAGS.num_timesteps // query_frequency + num_rollouts = 50 + + wandb_id = "{name}_{task}_chunk{chunk}_{time}".format( + name=policy_name, + task=FLAGS.task_name, + chunk=FLAGS.exec_horizon, + time=datetime.now().strftime("%Y%m%d_%H%M%S"), + ) + wandb.init( + id=wandb_id, + name=FLAGS.wandb_name, + project="aloha_eval" + ) + + n_existing_rollouts = len([f for f in os.listdir(FLAGS.video_save_path) if f.startswith('video')]) + print(f'{n_existing_rollouts=}') + + for rollout_id in range(num_rollouts): + if FLAGS.is_sim: + ### set task + if 'sim_transfer_cube' in FLAGS.task_name: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif 'sim_insertion' in FLAGS.task_name: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + rewards = [] + + obs, info = env.reset() + image_list = [] # for visualization + + for t in range(max_timesteps): + if t > 0: + image_list.extend(info['images']) + + # query policy + actions = policy_fn(obs, tasks={}) + target_qpos = np.array(actions) + + obs, reward, done, trunc, info = env.step(target_qpos) + if FLAGS.is_sim: + rewards.extend(info["rewards"]) + + if FLAGS.is_sim: + rewards = np.array(rewards) + episode_return = np.sum(rewards[rewards!=None]) + episode_returns.append(episode_return) + episode_highest_reward = np.max(rewards) + highest_rewards.append(episode_highest_reward) + print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, ' + f'{env_max_reward=}, Success: {episode_highest_reward==env_max_reward}') + else: + 
move_grippers([env.puppet_bot_left, env.puppet_bot_right], + [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open + + print(f'Finished rollout {rollout_id}') + if rollout_id < 3: + # construct video, resize + imgs = [np.array(Image.fromarray(img).resize(int(320*len(camera_names)), 240)) for img in image_list] + video = np.stack(imgs) + wandb.log({ + f"{policy_name}/rollout_{rollout_id}": wandb.Video(video.transpose(0, 3, 1, 2)[::2], fps=25) + }) + + if FLAGS.is_sim: + success_rate = np.mean(np.array(highest_rewards) == env_max_reward) + avg_return = np.mean(episode_returns) + summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n' + for r in range(env_max_reward+1): + more_or_equal_r = (np.array(highest_rewards) >= r).sum() + more_or_equal_r_rate = more_or_equal_r / num_rollouts + summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n' + + print(summary_str) + + wandb.log({ + f"{policy_name}/success_rate": success_rate, + f"{policy_name}/average_return": avg_return, + }) + +if __name__ == "__main__": + app.run(main) diff --git a/experiments/lucy/eval.sh b/experiments/lucy/eval.sh new file mode 100755 index 00000000..a3395297 --- /dev/null +++ b/experiments/lucy/eval.sh @@ -0,0 +1,45 @@ +PATHS=( + # "gs://karl-central-2/orca/aloha_scratch_chunk50_vit_s_updated_20231120_213607" + #"gs://karl-central-2/orca/aloha_scratch_chunk50_vit_ti_updated_20231120_213331" + # "gs://karl-central-2/orca/aloha_scratch_chunk50_vit_ti_d01_20231119_000843" + # "gs://karl-central-2/orca_finetune/aloha_finetune_naiv_20231124_231333" + # "gs://karl-central-2/orca_finetune/aloha_finetune_frozen_20231124_231411" + "gs://karl-central-2/orca_finetune/aloha_sim_scratch_vit_s_20231130_080455" +) + +STEPS=( + "50000" + #"20000" +) + +CONDITIONING_MODE="" + +TIMESTEPS="400" + +TEMPERATURE="1.0" + +HORIZON="1" + +PRED_HORIZON="50" + +EXEC_HORIZON="50" + +CMD="python eval.py \ + --num_timesteps $TIMESTEPS \ + --video_save_path gs://karl-central-2/orca_sim_eval/videos \ + $(for i in "${!PATHS[@]}"; do echo "--checkpoint_weights_path ${PATHS[$i]} "; done) \ + $(for i in "${!PATHS[@]}"; do echo "--checkpoint_step ${STEPS[$i]} "; done) \ + --im_size 256 \ + --temperature $TEMPERATURE \ + --horizon $HORIZON \ + --pred_horizon $PRED_HORIZON \ + --exec_horizon $EXEC_HORIZON \ + --modality $CONDITIONING_MODE \ + --checkpoint_cache_dir /tmp/ \ + --is_sim \ + --task_name sim_transfer_cube_scripted +" + +echo $CMD + +$CMD diff --git a/finetune.py b/finetune.py index 657f7f3b..7cccf8c7 100644 --- a/finetune.py +++ b/finetune.py @@ -31,7 +31,7 @@ Timer, TrainState, ) -from orca.utils.visualization_lib import Visualizer +from orca.utils.visualization_lib import RolloutVisualizer, Visualizer try: from jax_smi import initialise_tracking # type: ignore @@ -183,6 +183,28 @@ def create_iterator(dataset): example_batch = next(train_data_iter) + ######### + # + # Optionally build visualizers for sim env evals + # + ######### + + if FLAGS.config.get("rollout_envs", None): + rollout_visualizers = [] + for env_name, visualizer_kwargs in FLAGS.config["rollout_envs"]: + input_kwargs = dict( + env_name=env_name, + history_length=FLAGS.config["data_transforms"]["window_size"], + action_chunk=config["model"]["heads"]["action"]["kwargs"].get( + "pred_horizon", 1 + ), + text_processor=text_processor, + ) + input_kwargs.update(visualizer_kwargs) + rollout_visualizers.append(RolloutVisualizer(**input_kwargs)) + else: + rollout_visualizers = None + ######### # # Load Pretrained Model @@ 
-364,8 +386,9 @@ def sample_actions(state, observations, tasks): sample_shape=(SAMPLES_FOR_VIZ,), rng=state.rng, ) - actions = actions[..., 0, :] # get prediction for current action - actions = jnp.moveaxis(actions, 0, 1) # (batch_size, n_samples, action_dim) + actions = jnp.moveaxis( + actions, 0, 1 + ) # (batch_size, n_samples, pred_horizon, action_dim) return actions ######### @@ -427,6 +450,21 @@ def wandb_log(info, step): step=i, ) + if rollout_visualizers: + with timer("rollout"): + for rollout_visualizer in rollout_visualizers: + logging.info("Running rollouts...") + rollout_infos = rollout_visualizer.run_rollouts( + policy_fn, n_rollouts=10 + ) + wandb_log( + { + f"rollouts_{rollout_visualizer.env_name}" + f"_chunk{rollout_visualizer.action_chunk}": rollout_infos, + }, + step=i, + ) + if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: logging.info("Saving checkpoint...") diff --git a/orca/utils/gym_wrappers.py b/orca/utils/gym_wrappers.py index 3e74583d..e55c43d6 100644 --- a/orca/utils/gym_wrappers.py +++ b/orca/utils/gym_wrappers.py @@ -43,6 +43,10 @@ def space_stack(space: gym.Space, repeat: int): raise ValueError(f"Space {space} is not supported by ORCA Gym wrappers.") +def listdict2dictlist(LD): + return {k: [dic[k] for dic in LD] for k in LD[0]} + + class HistoryWrapper(gym.Wrapper): """ Accumulates the observation history into `horizon` size chunks. If the length of the history @@ -58,7 +62,7 @@ def __init__(self, env: gym.Env, horizon: int): self.history = deque(maxlen=self.horizon) self.num_obs = 0 - self.observation_space = space_stack(self.env.observation_space, self.horizon) + # self.observation_space = space_stack(self.env.observation_space, self.horizon) def step(self, action): obs, reward, done, trunc, info = self.env.step(action) @@ -84,24 +88,30 @@ class RHCWrapper(gym.Wrapper): we execute `exec_horizon` of them. 
""" - def __init__(self, env: gym.Env, pred_horizon: int, exec_horizon: int): + def __init__(self, env: gym.Env, exec_horizon: int): super().__init__(env) - assert exec_horizon <= pred_horizon - - self.pred_horizon = pred_horizon self.exec_horizon = exec_horizon - self.action_space = space_stack(self.env.action_space, self.pred_horizon) - def step(self, actions): - assert len(actions) == self.pred_horizon + assert len(actions) >= self.exec_horizon + rewards = [] + observations = [] + infos = [] for i in range(self.exec_horizon): obs, reward, done, trunc, info = self.env.step(actions[i]) + observations.append(obs) + rewards.append(reward) + infos.append(info) + if done or trunc: break - return obs, reward, done, trunc, info + infos = listdict2dictlist(infos) + infos["rewards"] = rewards + infos["observations"] = observations + + return obs, np.sum(rewards), done, trunc, infos class TemporalEnsembleWrapper(gym.Wrapper): @@ -167,11 +177,23 @@ def unnormalize(self, data, metadata): f"Unknown action/proprio normalization type: {self.normalization_type}" ) + def normalize(self, data, metadata): + if self.normalization_type == "normal": + return (data / (metadata["std"] + 1e-8)) - metadata["mean"] + elif self.normalization_type == "bounds": + return ( + (data + 1) / (2 * (metadata["max"] - metadata["min"] + 1e-8)) + ) + metadata["min"] + else: + raise ValueError( + f"Unknown action/proprio normalization type: {self.normalization_type}" + ) + def action(self, action): return self.unnormalize(action, self.action_proprio_metadata["action"]) def observation(self, obs): - obs["proprio"] = self.unnormalize( + obs["proprio"] = self.normalize( obs["proprio"], self.action_proprio_metadata["proprio"] ) return obs diff --git a/orca/utils/visualization_lib.py b/orca/utils/visualization_lib.py index a3c68156..2c8bd373 100644 --- a/orca/utils/visualization_lib.py +++ b/orca/utils/visualization_lib.py @@ -2,20 +2,27 @@ matplotlib.use("Agg") from dataclasses import dataclass +from functools import reduce +import json +from typing import Any, Callable, Optional, Tuple import dlimp as dl import flax +import gym import jax import jax.numpy as jnp from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np +from PIL import Image import plotly.graph_objects as go +import tensorflow as tf import tqdm import wandb from orca.data.utils.data_utils import ActionEncoding +from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio BASE_METRIC_KEYS = { "mse": ("mse", tuple()), # What is the MSE @@ -89,7 +96,8 @@ def run_policy_on_trajectory(policy_fn, traj, *, text_processor=None): horizon = jax.tree_util.tree_leaves(traj["observation"])[0].shape[1] return { "n": np.array(len_traj), - "pred_actions": actions, + "pred_actions_chunk": actions, + "pred_actions": actions[:, :, 0], # only use first predicted action "actions": traj["action"][:, horizon - 1, :], "proprio": traj["observation"]["proprio"][:, horizon - 1], } @@ -254,6 +262,114 @@ def get_maybe_cached_iterator(dataset, n, cached_trajs, use_cache): return cached_trajs[:n] +@dataclass +class RolloutVisualizer: + """ + Runs policy rollouts on a given simulated environment. + + Args: + env_name (str): Gym.make environment creation string + history_length (int): Number of history steps policy gets conditioned on (window_size). + action_chunk (int): Number of future steps. + max_episode_length (int): Max number of steps per rollout episode. 
+ vis_fps (int): FPS of logged rollout video + video_subsample_rate (int): Subsampling rate for video logging (to reduce video size for high-frequency control) + norm_statistics_path (str, optional): Optional path to stats for de-normalizing policy actionsi & proprio. + """ + + env_name: str + history_length: int + action_chunk: int + max_episode_length: int + vis_fps: int = 10 + video_subsample_rate: int = 1 + norm_statistics_path: Optional[str] = None + text_processor: object = None + + def __post_init__(self): + self._env = gym.make(self.env_name) + self._env = HistoryWrapper(self._env, self.history_length) + self._env = RHCWrapper(self._env, self.action_chunk) + if self.norm_statistics_path: + with tf.io.gfile.GFile(self.norm_statistics_path, "r") as f: + norm_stats = json.load(f) + norm_stats = tree_map(np.array, norm_stats) + self._env = UnnormalizeActionProprio( + self._env, norm_stats, normalization_type="normal" + ) + + def run_rollouts(self, policy_fn, n_rollouts=10, n_vis_rollouts=3): + def extract_images(obs): + return jnp.concatenate([obs[k] for k in obs if "image_" in k], axis=-2) + + def listdict2dictlist(LD): + return {k: [dic[k] for dic in LD] for k in LD[0]} + + rollout_info = { + "episode_returns": [], + "episode_metrics": [], + } + for rollout_idx in tqdm.tqdm(range(n_rollouts)): + obs, info = self._env.reset() + task = self._env.get_task() + if "language_instruction" in task and self.text_processor: + task["language_instruction"] = self.text_processor.encode( + [s.decode("utf-8") for s in task["language_instruction"]] + ) + images = [extract_images(obs)] + episode_return = 0.0 + metrics = [] + for _ in range(self.max_episode_length // self.action_chunk): + # policy outputs are shape [batch, n_samples, pred_horizon, act_dim] + # we remove batch dimension & use first sampled action, ignoring other samples + actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)[0, 0] + obs, reward, done, trunc, info = self._env.step(actions) + images.extend([extract_images(o) for o in info["observations"]]) + episode_return += reward + if "metrics" in info: + metrics.extend(info["metrics"]) + if done or trunc: + break + + rollout_info["episode_returns"].append(episode_return) + if metrics: + # concatenate all chunks into one dict of lists, then average across episode + metrics = listdict2dictlist(metrics) + rollout_info["episode_metrics"].append( + jax.tree_map(lambda x: np.mean(x), metrics) + ) + if rollout_idx < n_vis_rollouts: + # save rollout video + assert ( + images[0].dtype == np.uint8 + ), f"Expect uint8, got {images[0].dtype}" + assert ( + images[0].shape[-1] == 3 + ), f"Expect [height, width, channels] format, got {images[0].shape}" + rollout_info[f"rollout_{rollout_idx}_vid"] = wandb.Video( + np.array(images).transpose(0, 3, 1, 2)[ + :: self.video_subsample_rate + ], + fps=self.vis_fps, + ) + rollout_info["avg_return"] = np.mean(rollout_info["episode_returns"]) + rollout_info["episode_returns"] = wandb.Histogram( + rollout_info["episode_returns"] + ) + metrics = rollout_info.pop(["episode_metrics"]) + for metric in metrics: + rollout_info[metric] = wandb.Histogram(metrics[metric]) + rollout_info[f"avg_{metric}"] = np.mean(metrics[metric]) + return rollout_info + + +def tree_map(fn: Callable, tree: dict) -> dict: + """Maps a function over a nested dictionary.""" + return { + k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items() + } + + def unnormalize(arr, mean, std, **kwargs): return arr * np.array(std) + np.array(mean) @@ -271,6 +387,9 @@ def 
add_unnormalized_info( "unnorm_pred_actions": unnormalize( info["pred_actions"], **normalization_stats["action"] ), + "unnorm_pred_actions_chunk": unnormalize( + info["pred_actions_chunk"], **normalization_stats["action"] + ), "unnorm_actions": unnormalize( info["actions"], **normalization_stats["action"] ), @@ -404,7 +523,7 @@ def __exit__(self, exc_type, exc_value, traceback): def plot_trajectory_overview_mpl( traj, - unnorm_pred_actions, + unnorm_pred_actions_chunk, unnorm_actions, unnorm_proprio, **info, @@ -420,18 +539,26 @@ def plot_trajectory_overview_mpl( for i in range(n_act_dims): ax = fig.add_subplot(gs[(i + 1) // grid_size, (i + 1) % grid_size]) ax.plot(unnorm_actions[:, i], label="action") - unnorm_pred_actions_i = unnorm_pred_actions[:, :, i] - x = np.tile( - np.arange(len(unnorm_pred_actions_i))[:, None], - (1, unnorm_pred_actions_i.shape[1]), - ) - ax.scatter( - x.flat[:], - unnorm_pred_actions_i.flat[:], - color="tab:red", - s=4, - alpha=0.5, - ) + # plot predicted action chunks, unnorm_pred_actions_chunk.shape = [time, n_samples, chunk, act_dim] + chunk_length = unnorm_pred_actions_chunk.shape[2] + for t in range(unnorm_pred_actions_chunk.shape[0]): + step_idx, chunk_idx = divmod(t, chunk_length) + unnorm_pred_actions_i = unnorm_pred_actions_chunk[ + int(step_idx * chunk_length), :, chunk_idx, i + ] + x = np.full((unnorm_pred_actions_i.shape[0],), t) + ax.scatter( + x.flat[:], + unnorm_pred_actions_i.flat[:], + color="tab:red", + s=4, + alpha=0.5, + ) + if ( + chunk_idx == 0 + and (unnorm_pred_actions_chunk.shape[0] // chunk_length) <= 20 + ): + ax.axvline(t, color="red", linestyle="--", alpha=0.2) ax.set_ylabel(f"dim {i}") fig.suptitle(traj["tasks"]["language_instruction"][0].decode("utf-8")) return wandb.Image(wandb_figure.image) diff --git a/train.py b/train.py index f60b0f5b..146aba8f 100644 --- a/train.py +++ b/train.py @@ -37,7 +37,7 @@ format_name_with_config, Timer, ) -from orca.utils.visualization_lib import Visualizer +from orca.utils.visualization_lib import RolloutVisualizer, Visualizer try: from jax_smi import initialise_tracking # type: ignore @@ -257,6 +257,25 @@ def process_text(batch): map(shard, map(process_text, val_data.iterator())) for val_data in val_datas ] + # optionally build visualizers for sim env evals + if FLAGS.config.get("rollout_envs", None): + rollout_visualizers = [] + for env_name, visualizer_kwargs in FLAGS.config["rollout_envs"]: + input_kwargs = dict( + env_name=env_name, + history_length=FLAGS.config["dataset_kwargs"]["transform_kwargs"][ + "window_size" + ], + action_chunk=FLAGS.config["model"]["heads"]["action"]["kwargs"].get( + "pred_horizon", 1 + ), + text_processor=text_processor, + ) + input_kwargs.update(visualizer_kwargs) + rollout_visualizers.append(RolloutVisualizer(**input_kwargs)) + else: + rollout_visualizers = None + example_batch = next(train_data_iter) logging.info(f"Batch size: {example_batch['action'].shape[0]}") logging.info(f"Number of devices: {jax.device_count()}") @@ -473,6 +492,8 @@ def get_actions(model, observations, tasks, train): ) return actions + # actions is (NUM_ACTIONS_FOR_VIS, batch_size, pred_horizon, action_dim) + # where actions[:, :, i] predicts the action at timestep "window_size + i" actions = state.apply_fn( {"params": state.params}, observations, @@ -482,11 +503,7 @@ def get_actions(model, observations, tasks, train): rngs={"dropout": state.rng}, ) # We could also have used run_head here, but this is easier to read - # actions is (NUM_ACTIONS_FOR_VIS, batch_size, pred_horizon, action_dim) - # 
where actions[:, :, i] predicts the action at timestep "window_size + i" - actions = actions[..., 0, :] - - # viz expects (batch_size, n_samples, action_dim) + # viz expects (batch_size, n_samples, pred_horizon, action_dim) actions = jnp.moveaxis(actions, 0, 1) return actions @@ -578,6 +595,22 @@ def wandb_log(info, step): }, step=i, ) + + if rollout_visualizers: + for rollout_visualizer in rollout_visualizers: + for mode, policy_fn in modal_policy_fns.items(): + logging.info("Running rollouts...") + rollout_infos = rollout_visualizer.run_rollouts( + policy_fn, n_rollouts=FLAGS.config.trajs_for_rollouts + ) + wandb_log( + { + f"rollouts_{rollout_visualizer.env_name}" + f"_chunk{rollout_visualizer.action_chunk}/{mode}": rollout_infos, + }, + step=i, + ) + timer.tock("visualize") if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: From 21f20fffc46e0390c24c00dacd32682c3c30c85f Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Sun, 3 Dec 2023 14:40:41 -0800 Subject: [PATCH 2/4] cleanup --- experiments/homer/bridge/eval.py | 2 +- experiments/lucy/aloha_finetune_config.py | 2 -- experiments/lucy/aloha_scratch_config.py | 2 -- orca/utils/gym_wrappers.py | 3 ++- orca/utils/visualization_lib.py | 6 ++++-- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/experiments/homer/bridge/eval.py b/experiments/homer/bridge/eval.py index 02581a88..f119fb93 100644 --- a/experiments/homer/bridge/eval.py +++ b/experiments/homer/bridge/eval.py @@ -251,7 +251,7 @@ def main(_): ) env = HistoryWrapper(env, FLAGS.horizon) # env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon) - env = RHCWrapper(env, FLAGS.pred_horizon, FLAGS.exec_horizon) + env = RHCWrapper(env, FLAGS.exec_horizon) goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) goal_instruction = "" diff --git a/experiments/lucy/aloha_finetune_config.py b/experiments/lucy/aloha_finetune_config.py index 0efaaa08..c4ec72eb 100644 --- a/experiments/lucy/aloha_finetune_config.py +++ b/experiments/lucy/aloha_finetune_config.py @@ -157,7 +157,6 @@ def get_config(mode="full"): dict( max_episode_length=200, action_chunk=50, - vis_render_size=(320, 240), vis_fps=25, video_subsample_rate=2, norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", @@ -168,7 +167,6 @@ def get_config(mode="full"): dict( max_episode_length=200, action_chunk=30, - vis_render_size=(320, 240), vis_fps=25, video_subsample_rate=2, norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", diff --git a/experiments/lucy/aloha_scratch_config.py b/experiments/lucy/aloha_scratch_config.py index 4a34f2fa..121e7ad5 100644 --- a/experiments/lucy/aloha_scratch_config.py +++ b/experiments/lucy/aloha_scratch_config.py @@ -79,7 +79,6 @@ def get_config(config_string=None): dict( max_episode_length=200, action_chunk=50, - vis_render_size=(320, 240), vis_fps=25, video_subsample_rate=2, norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", @@ -90,7 +89,6 @@ def get_config(config_string=None): dict( max_episode_length=200, action_chunk=30, - vis_render_size=(320, 240), vis_fps=25, video_subsample_rate=2, 
norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", diff --git a/orca/utils/gym_wrappers.py b/orca/utils/gym_wrappers.py index e55c43d6..dc93338c 100644 --- a/orca/utils/gym_wrappers.py +++ b/orca/utils/gym_wrappers.py @@ -62,7 +62,7 @@ def __init__(self, env: gym.Env, horizon: int): self.history = deque(maxlen=self.horizon) self.num_obs = 0 - # self.observation_space = space_stack(self.env.observation_space, self.horizon) + self.observation_space = space_stack(self.env.observation_space, self.horizon) def step(self, action): obs, reward, done, trunc, info = self.env.step(action) @@ -107,6 +107,7 @@ def step(self, actions): if done or trunc: break + # pass through all infos, also return full observation and reward sequence in infos infos = listdict2dictlist(infos) infos["rewards"] = rewards infos["observations"] = observations diff --git a/orca/utils/visualization_lib.py b/orca/utils/visualization_lib.py index 2c8bd373..fc3fbe3a 100644 --- a/orca/utils/visualization_lib.py +++ b/orca/utils/visualization_lib.py @@ -275,6 +275,7 @@ class RolloutVisualizer: vis_fps (int): FPS of logged rollout video video_subsample_rate (int): Subsampling rate for video logging (to reduce video size for high-frequency control) norm_statistics_path (str, optional): Optional path to stats for de-normalizing policy actionsi & proprio. + text_processor (object, optional): Used to encode language instruction in task if not None. """ env_name: str @@ -300,7 +301,8 @@ def __post_init__(self): def run_rollouts(self, policy_fn, n_rollouts=10, n_vis_rollouts=3): def extract_images(obs): - return jnp.concatenate([obs[k] for k in obs if "image_" in k], axis=-2) + # obs has [window_size, ...] 
shape, only use first time step + return jnp.concatenate([obs[k][0] for k in obs if "image_" in k], axis=-2) def listdict2dictlist(LD): return {k: [dic[k] for dic in LD] for k in LD[0]} @@ -356,7 +358,7 @@ def listdict2dictlist(LD): rollout_info["episode_returns"] = wandb.Histogram( rollout_info["episode_returns"] ) - metrics = rollout_info.pop(["episode_metrics"]) + metrics = rollout_info.pop("episode_metrics") for metric in metrics: rollout_info[metric] = wandb.Histogram(metrics[metric]) rollout_info[f"avg_{metric}"] = np.mean(metrics[metric]) From 4f0d8d666033d65d601118ebc05796e52ca5fc9c Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Mon, 4 Dec 2023 22:48:51 -0800 Subject: [PATCH 3/4] address comments, add episode metrics, temp ensembling --- experiments/homer/bridge/eval.py | 2 +- experiments/lucy/aloha_finetune_config.py | 78 +++++++++++++------ experiments/lucy/aloha_scratch_config.py | 4 +- experiments/lucy/aloha_wrapper.py | 16 +++- finetune.py | 55 ++++++++------ orca/utils/gym_wrappers.py | 7 +- orca/utils/visualization_lib.py | 92 ++++++++++++++--------- train.py | 4 +- 8 files changed, 165 insertions(+), 93 deletions(-) diff --git a/experiments/homer/bridge/eval.py b/experiments/homer/bridge/eval.py index f119fb93..02581a88 100644 --- a/experiments/homer/bridge/eval.py +++ b/experiments/homer/bridge/eval.py @@ -251,7 +251,7 @@ def main(_): ) env = HistoryWrapper(env, FLAGS.horizon) # env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon) - env = RHCWrapper(env, FLAGS.exec_horizon) + env = RHCWrapper(env, FLAGS.pred_horizon, FLAGS.exec_horizon) goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) goal_instruction = "" diff --git a/experiments/lucy/aloha_finetune_config.py b/experiments/lucy/aloha_finetune_config.py index c4ec72eb..ed604203 100644 --- a/experiments/lucy/aloha_finetune_config.py +++ b/experiments/lucy/aloha_finetune_config.py @@ -16,7 +16,7 @@ def update_config(config, **kwargs): @wrap -def get_config(mode="full"): +def get_config(mode="full", head="mse", augment="full", temp_ensembling="uniform", window_size=1, pred_horizon=50): # If starting with an ORCA-wrist model, there should be two image keys # first image key should be the third-person view # and second image key should be the wrist view @@ -51,11 +51,11 @@ def get_config(mode="full"): "heads_*.map_head.MultiHeadDotProductAttention_0.*", ) elif mode == "frozen_transformer": - frozen_keys = ("orca_transformer.BlockTransformer_0.*",) + frozen_keys = ("orca_transformer.BlockTransformer_0.*", "*hf_model*") else: raise ValueError("Invalid mode") - max_steps = FieldReference(20000) + max_steps = FieldReference(50000) config = dict( pretrained_path=placeholder(str), @@ -65,16 +65,18 @@ def get_config(mode="full"): num_val_batches=8, num_steps=max_steps, log_interval=100, - eval_interval=500, - save_interval=500, + eval_interval=1, #5000, + save_interval=1, #5000, save_dir="gs://karl-central-2", seed=42, + debug_sim=False, wandb=dict( project="orca_finetune", group=placeholder(str), entity=placeholder(str) ), finetuning_dataset=FINETUNING_KWARGS, modality=None, finetuning_mode=mode, + window_size=int(window_size), optimizer=dict( learning_rate=dict( init_value=0.0, @@ -89,9 +91,25 @@ def get_config(mode="full"): ), ) + goal_relabeling_strategy = "no_image_conditioning" + delete_key_groups_probs = [ + (["image_.*", "proprio"], 1.0), + ] + + if augment == "full": + augment_order = [ + "random_resized_crop", + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue" + ] 
+ elif augment == "none": + augment_order = [] + transform_kwargs = dict( - window_size=1, - additional_action_window_size=49, + window_size=int(window_size), + additional_action_window_size=int(pred_horizon) - 1, resize_size=(256, 256), image_augment_kwargs=dict( random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), @@ -99,18 +117,22 @@ def get_config(mode="full"): random_contrast=[0.8, 1.2], random_saturation=[0.8, 1.2], random_hue=[0.1], - augment_order=[ - "random_resized_crop", - "random_brightness", - "random_contrast", - "random_saturation", - "random_hue", - ], + augment_order=augment_order, + #"random_resized_crop", + #"random_brightness", + #"random_contrast", + #"random_saturation", + #"random_hue", + #], ), goal_relabeling_strategy="uniform", action_encoding=ActionEncoding.JOINT_POS_BIMANUAL, # If the default data loading speed is too slow, try these: num_parallel_calls=16, # for the most CPU-intensive ops (decoding, resizing, augmenting) + task_augmentation_strategy="delete_task_conditioning", + task_augmentation_kwargs=dict( + delete_key_groups_probs=delete_key_groups_probs, + ), ) config["data_transforms"] = transform_kwargs @@ -122,13 +144,18 @@ def get_config(mode="full"): ) ) + if head == "mse": + cls_name = "mse_action_head" + elif head == "L1": + cls_name = "l1_action_head" + config["update_config"] = dict( model=dict( heads=dict( action=dict( - cls_name="mse_action_head", + cls_name=cls_name, kwargs=dict( - pred_horizon=50, + pred_horizon=int(pred_horizon), action_dim=14, vocab_size=256, normalization_type="normal", @@ -151,25 +178,32 @@ def get_config(mode="full"): ) ) + if temp_ensembling == "uniform": + use_temp_averaging = True + elif temp_ensembling == "none": + use_temp_averaging = False + config["rollout_envs"] = [ ( "aloha-sim-cube-v0", dict( - max_episode_length=200, - action_chunk=50, + max_episode_length=400, + action_chunk=int(pred_horizon), vis_fps=25, video_subsample_rate=2, - norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + use_temp_averaging=use_temp_averaging, ) ), ( "aloha-sim-cube-v0", dict( - max_episode_length=200, - action_chunk=30, + max_episode_length=400, + action_chunk=int(int(pred_horizon)/2), vis_fps=25, video_subsample_rate=2, - norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + use_temp_averaging=use_temp_averaging, ) ) ] diff --git a/experiments/lucy/aloha_scratch_config.py b/experiments/lucy/aloha_scratch_config.py index 121e7ad5..332ef05e 100644 --- a/experiments/lucy/aloha_scratch_config.py +++ b/experiments/lucy/aloha_scratch_config.py @@ -81,7 +81,7 @@ def get_config(config_string=None): action_chunk=50, vis_fps=25, video_subsample_rate=2, - norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + 
norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", ) ), ( @@ -91,7 +91,7 @@ def get_config(config_string=None): action_chunk=30, vis_fps=25, video_subsample_rate=2, - norm_statistics_path="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", + norm_statistics="gs://rail-orca-central2/aloha_sim_cube_scripted_dataset/1.0.0/dataset_statistics_707801797899cdd91dcb18bd45463cf73ac935bfd6ac6b62456653e96f120a5f.json", ) ) ], diff --git a/experiments/lucy/aloha_wrapper.py b/experiments/lucy/aloha_wrapper.py index 4b65209b..c36d5d75 100644 --- a/experiments/lucy/aloha_wrapper.py +++ b/experiments/lucy/aloha_wrapper.py @@ -48,6 +48,9 @@ def step(self, action): reward = ts.reward info = {"images": images} + if reward == self._env.task.max_reward: + self._episode_is_success = 1 + return obs, reward, False, False, info def reset(self, **kwargs): @@ -62,7 +65,9 @@ def reset(self, **kwargs): ts = self._env.reset(**kwargs) obs, images = self.get_obs(ts) info = {"images": images} - self._goal_obs = obs # HACK + self._goal_obs = copy.deepcopy(obs) # HACK + + self._episode_is_success = 0 return obs, info @@ -77,7 +82,7 @@ def get_obs(self, ts): if cam_name == 'cam_high': curr_image = crop_resize(curr_image) - curr_image = cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB) + #curr_image = cv2.cvtColor(curr_image, cv2.COLOR_BGR2RGB) vis_images.append(copy.deepcopy(curr_image)) curr_image = jnp.array(curr_image) # XXX: / 255. ? curr_obs[f"image_{i}"] = curr_image @@ -94,7 +99,12 @@ def get_task(self): assert self._goal_obs, "Need to run reset() before!" return { "language_instruction": ["pick up the cube and hand it over".encode()], - **jax.tree_map(lambda x: x[None], self._goal_obs), + **jax.tree_map(lambda x: x[None] * 0, self._goal_obs), + } + + def get_episode_metrics(self): + return { + "success_rate": self._episode_is_success, } diff --git a/finetune.py b/finetune.py index 7cccf8c7..a079c2b4 100644 --- a/finetune.py +++ b/finetune.py @@ -194,7 +194,7 @@ def create_iterator(dataset): for env_name, visualizer_kwargs in FLAGS.config["rollout_envs"]: input_kwargs = dict( env_name=env_name, - history_length=FLAGS.config["data_transforms"]["window_size"], + history_length=FLAGS.config["window_size"], action_chunk=config["model"]["heads"]["action"]["kwargs"].get( "pred_horizon", 1 ), @@ -408,28 +408,32 @@ def wandb_log(info, step): ): timer.tick("total") - with timer("dataset"): - batch = next(train_data_iter) + if not FLAGS.config.debug_sim: + with timer("dataset"): + batch = next(train_data_iter) - with timer("train"): - train_state, update_info = train_step(train_state, batch) + with timer("train"): + train_state, update_info = train_step(train_state, batch) timer.tock("total") - if (i + 1) % FLAGS.config.log_interval == 0: + if not FLAGS.config.debug_sim and (i + 1) % FLAGS.config.log_interval == 0: update_info = jax.device_get(update_info) wandb_log( {"training": update_info, "timer": timer.get_average_times()}, step=i ) - if (i + 1) % FLAGS.config.eval_interval == 0: + if FLAGS.config.debug_sim or (i + 1) % FLAGS.config.eval_interval == 0: logging.info("Evaluating...") - with timer("val"): - metrics = [] - for _, batch in zip(range(FLAGS.config.num_val_batches), val_data_iter): - metrics.append(eval_step(train_state, batch)) - metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) - 
wandb_log({"validation": metrics}, step=i) + if not FLAGS.config.debug_sim: + with timer("val"): + metrics = [] + for _, batch in zip( + range(FLAGS.config.num_val_batches), val_data_iter + ): + metrics.append(eval_step(train_state, batch)) + metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) + wandb_log({"validation": metrics}, step=i) with timer("visualize"): policy_fn = batched_apply( @@ -439,23 +443,24 @@ def wandb_log(info, step): ), FLAGS.config.batch_size, ) - raw_infos = visualizer.raw_evaluations(policy_fn, max_trajs=100) - metrics = visualizer.metrics_for_wandb(raw_infos) - images = visualizer.visualize_for_wandb(policy_fn, max_trajs=8) - wandb_log( - { - "offline_metrics": metrics, - "visualizations": images, - }, - step=i, - ) + if not FLAGS.config.debug_sim: + raw_infos = visualizer.raw_evaluations(policy_fn, max_trajs=100) + metrics = visualizer.metrics_for_wandb(raw_infos) + images = visualizer.visualize_for_wandb(policy_fn, max_trajs=8) + wandb_log( + { + "offline_metrics": metrics, + "visualizations": images, + }, + step=i, + ) if rollout_visualizers: with timer("rollout"): for rollout_visualizer in rollout_visualizers: logging.info("Running rollouts...") rollout_infos = rollout_visualizer.run_rollouts( - policy_fn, n_rollouts=10 + policy_fn, n_rollouts=2 if FLAGS.config.debug_sim else 10 ) wandb_log( { @@ -465,7 +470,7 @@ def wandb_log(info, step): step=i, ) - if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: + if False: # (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: logging.info("Saving checkpoint...") params_checkpointer.save( diff --git a/orca/utils/gym_wrappers.py b/orca/utils/gym_wrappers.py index dc93338c..d13f9ae2 100644 --- a/orca/utils/gym_wrappers.py +++ b/orca/utils/gym_wrappers.py @@ -93,6 +93,8 @@ def __init__(self, env: gym.Env, exec_horizon: int): self.exec_horizon = exec_horizon def step(self, actions): + if self.exec_horizon == 1 and len(actions.shape) == 1: + actions = actions[None] assert len(actions) >= self.exec_horizon rewards = [] observations = [] @@ -107,7 +109,6 @@ def step(self, actions): if done or trunc: break - # pass through all infos, also return full observation and reward sequence in infos infos = listdict2dictlist(infos) infos["rewards"] = rewards infos["observations"] = observations @@ -132,9 +133,9 @@ def __init__(self, env: gym.Env, pred_horizon: int, exp_weight: int = 0): self.action_space = space_stack(self.env.action_space, self.pred_horizon) def step(self, actions): - assert len(actions) == self.pred_horizon + assert len(actions) >= self.pred_horizon - self.act_history.append(actions) + self.act_history.append(actions[: self.pred_horizon]) num_actions = len(self.act_history) # select the predicted action for the current step from the history of action chunk predictions diff --git a/orca/utils/visualization_lib.py b/orca/utils/visualization_lib.py index fc3fbe3a..a4ac3c97 100644 --- a/orca/utils/visualization_lib.py +++ b/orca/utils/visualization_lib.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from functools import reduce import json -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Union import dlimp as dl import flax @@ -22,7 +22,12 @@ import wandb from orca.data.utils.data_utils import ActionEncoding -from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio +from orca.utils.gym_wrappers import ( + HistoryWrapper, + RHCWrapper, + TemporalEnsembleWrapper, + UnnormalizeActionProprio, 
+) BASE_METRIC_KEYS = { "mse": ("mse", tuple()), # What is the MSE @@ -198,8 +203,13 @@ def visualize_for_wandb( plotly_fig = plot_trajectory_actions(**info) visualizations[f"traj_{n}"] = plotly_fig - mpl_fig = plot_trajectory_overview_mpl(traj, **info) - visualizations[f"traj_{n}_mpl"] = mpl_fig + # plot qualitative action trajectory per dimension w/ and w/o action chunk + visualizations[f"traj_{n}_mpl"] = plot_trajectory_overview_mpl( + traj, act=info["unnorm_pred_actions_chunk"][:, :, :1], **info + ) + visualizations[f"traj_{n}_mpl_chunk"] = plot_trajectory_overview_mpl( + traj, act=info["unnorm_pred_actions_chunk"], **info + ) if ( add_images or not self.cache_viz_trajectories @@ -274,8 +284,8 @@ class RolloutVisualizer: max_episode_length (int): Max number of steps per rollout episode. vis_fps (int): FPS of logged rollout video video_subsample_rate (int): Subsampling rate for video logging (to reduce video size for high-frequency control) - norm_statistics_path (str, optional): Optional path to stats for de-normalizing policy actionsi & proprio. - text_processor (object, optional): Used to encode language instruction in task if not None. + norm_statistics (Union[str, dict], optional): Stats for de-normalizing policy actions & proprio. + use_temporal_averaging (bool): If true, uses temporal averaging of action chunks during rollout. """ env_name: str @@ -284,17 +294,27 @@ class RolloutVisualizer: max_episode_length: int vis_fps: int = 10 video_subsample_rate: int = 1 - norm_statistics_path: Optional[str] = None + norm_statistics: Optional[Union[str, Dict[str, Any]]] = None text_processor: object = None + use_temp_averaging: bool = False def __post_init__(self): self._env = gym.make(self.env_name) self._env = HistoryWrapper(self._env, self.history_length) - self._env = RHCWrapper(self._env, self.action_chunk) - if self.norm_statistics_path: - with tf.io.gfile.GFile(self.norm_statistics_path, "r") as f: - norm_stats = json.load(f) - norm_stats = tree_map(np.array, norm_stats) + if self.use_temp_averaging: + self._env = RHCWrapper(self._env, 1) + self._env = TemporalEnsembleWrapper(self._env, self.action_chunk) + else: + self._env = RHCWrapper(self._env, self.action_chunk) + if self.norm_statistics: + if isinstance(self.norm_statistics, str): + with tf.io.gfile.GFile(self.norm_statistics, "r") as f: + norm_stats = json.load(f) + norm_stats = jax.tree_map( + lambda x: np.array(x), + norm_stats, + is_leaf=lambda x: not isinstance(x, dict), + ) self._env = UnnormalizeActionProprio( self._env, norm_stats, normalization_type="normal" ) @@ -314,14 +334,19 @@ def listdict2dictlist(LD): for rollout_idx in tqdm.tqdm(range(n_rollouts)): obs, info = self._env.reset() task = self._env.get_task() - if "language_instruction" in task and self.text_processor: - task["language_instruction"] = self.text_processor.encode( - [s.decode("utf-8") for s in task["language_instruction"]] - ) + if jax.tree_util.tree_leaves(task)[0].shape[0] != 1: + task = jax.tree_map(lambda x: x[None], task) + if "language_instruction" in task: + if self.text_processor: + task["language_instruction"] = self.text_processor.encode( + [s.decode("utf-8") for s in task["language_instruction"]] + ) + else: + task.pop("language_instruction") images = [extract_images(obs)] episode_return = 0.0 metrics = [] - for _ in range(self.max_episode_length // self.action_chunk): + while len(images) < self.max_episode_length: # policy outputs are shape [batch, n_samples, pred_horizon, act_dim] # we remove batch dimension & use first sampled action, 
ignoring other samples actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)[0, 0] @@ -340,6 +365,15 @@ def listdict2dictlist(LD): rollout_info["episode_metrics"].append( jax.tree_map(lambda x: np.mean(x), metrics) ) + if hasattr(self._env, "get_episode_metrics"): + if metrics: + rollout_info["episode_metrics"][-1].update( + self._env.get_episode_metrics() + ) + else: + rollout_info["episode_metrics"].append( + self._env.get_episode_metrics() + ) if rollout_idx < n_vis_rollouts: # save rollout video assert ( @@ -358,20 +392,13 @@ def listdict2dictlist(LD): rollout_info["episode_returns"] = wandb.Histogram( rollout_info["episode_returns"] ) - metrics = rollout_info.pop("episode_metrics") + metrics = listdict2dictlist(rollout_info.pop("episode_metrics")) for metric in metrics: rollout_info[metric] = wandb.Histogram(metrics[metric]) rollout_info[f"avg_{metric}"] = np.mean(metrics[metric]) return rollout_info -def tree_map(fn: Callable, tree: dict) -> dict: - """Maps a function over a nested dictionary.""" - return { - k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items() - } - - def unnormalize(arr, mean, std, **kwargs): return arr * np.array(std) + np.array(mean) @@ -525,7 +552,7 @@ def __exit__(self, exc_type, exc_value, traceback): def plot_trajectory_overview_mpl( traj, - unnorm_pred_actions_chunk, + act, unnorm_actions, unnorm_proprio, **info, @@ -541,11 +568,11 @@ def plot_trajectory_overview_mpl( for i in range(n_act_dims): ax = fig.add_subplot(gs[(i + 1) // grid_size, (i + 1) % grid_size]) ax.plot(unnorm_actions[:, i], label="action") - # plot predicted action chunks, unnorm_pred_actions_chunk.shape = [time, n_samples, chunk, act_dim] - chunk_length = unnorm_pred_actions_chunk.shape[2] - for t in range(unnorm_pred_actions_chunk.shape[0]): + # plot predicted action chunks, act.shape = [time, n_samples, chunk, act_dim] + chunk_length = act.shape[2] + for t in range(act.shape[0]): step_idx, chunk_idx = divmod(t, chunk_length) - unnorm_pred_actions_i = unnorm_pred_actions_chunk[ + unnorm_pred_actions_i = act[ int(step_idx * chunk_length), :, chunk_idx, i ] x = np.full((unnorm_pred_actions_i.shape[0],), t) @@ -556,10 +583,7 @@ def plot_trajectory_overview_mpl( s=4, alpha=0.5, ) - if ( - chunk_idx == 0 - and (unnorm_pred_actions_chunk.shape[0] // chunk_length) <= 20 - ): + if chunk_idx == 0 and (act.shape[0] // chunk_length) <= 20: ax.axvline(t, color="red", linestyle="--", alpha=0.2) ax.set_ylabel(f"dim {i}") fig.suptitle(traj["tasks"]["language_instruction"][0].decode("utf-8")) diff --git a/train.py b/train.py index 146aba8f..d90d9533 100644 --- a/train.py +++ b/train.py @@ -263,9 +263,7 @@ def process_text(batch): for env_name, visualizer_kwargs in FLAGS.config["rollout_envs"]: input_kwargs = dict( env_name=env_name, - history_length=FLAGS.config["dataset_kwargs"]["transform_kwargs"][ - "window_size" - ], + history_length=FLAGS.config["window_size"], action_chunk=FLAGS.config["model"]["heads"]["action"]["kwargs"].get( "pred_horizon", 1 ), From 86e22974e1719bf19a6a5ae563b01169afb3c1ed Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Mon, 4 Dec 2023 22:53:22 -0800 Subject: [PATCH 4/4] fix finetune.py --- finetune.py | 53 ++++++++++++++++++++++++----------------------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/finetune.py b/finetune.py index a079c2b4..b50b8c2a 100644 --- a/finetune.py +++ b/finetune.py @@ -408,32 +408,28 @@ def wandb_log(info, step): ): timer.tick("total") - if not FLAGS.config.debug_sim: - with timer("dataset"): 
- batch = next(train_data_iter) + with timer("dataset"): + batch = next(train_data_iter) - with timer("train"): - train_state, update_info = train_step(train_state, batch) + with timer("train"): + train_state, update_info = train_step(train_state, batch) timer.tock("total") - if not FLAGS.config.debug_sim and (i + 1) % FLAGS.config.log_interval == 0: + if (i + 1) % FLAGS.config.log_interval == 0: update_info = jax.device_get(update_info) wandb_log( {"training": update_info, "timer": timer.get_average_times()}, step=i ) - if FLAGS.config.debug_sim or (i + 1) % FLAGS.config.eval_interval == 0: + if (i + 1) % FLAGS.config.eval_interval == 0: logging.info("Evaluating...") - if not FLAGS.config.debug_sim: - with timer("val"): - metrics = [] - for _, batch in zip( - range(FLAGS.config.num_val_batches), val_data_iter - ): - metrics.append(eval_step(train_state, batch)) - metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) - wandb_log({"validation": metrics}, step=i) + with timer("val"): + metrics = [] + for _, batch in zip(range(FLAGS.config.num_val_batches), val_data_iter): + metrics.append(eval_step(train_state, batch)) + metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) + wandb_log({"validation": metrics}, step=i) with timer("visualize"): policy_fn = batched_apply( @@ -443,24 +439,23 @@ def wandb_log(info, step): ), FLAGS.config.batch_size, ) - if not FLAGS.config.debug_sim: - raw_infos = visualizer.raw_evaluations(policy_fn, max_trajs=100) - metrics = visualizer.metrics_for_wandb(raw_infos) - images = visualizer.visualize_for_wandb(policy_fn, max_trajs=8) - wandb_log( - { - "offline_metrics": metrics, - "visualizations": images, - }, - step=i, - ) + raw_infos = visualizer.raw_evaluations(policy_fn, max_trajs=100) + metrics = visualizer.metrics_for_wandb(raw_infos) + images = visualizer.visualize_for_wandb(policy_fn, max_trajs=8) + wandb_log( + { + "offline_metrics": metrics, + "visualizations": images, + }, + step=i, + ) if rollout_visualizers: with timer("rollout"): for rollout_visualizer in rollout_visualizers: logging.info("Running rollouts...") rollout_infos = rollout_visualizer.run_rollouts( - policy_fn, n_rollouts=2 if FLAGS.config.debug_sim else 10 + policy_fn, n_rollouts=10 ) wandb_log( { @@ -470,7 +465,7 @@ def wandb_log(info, step): step=i, ) - if False: # (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: + if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: logging.info("Saving checkpoint...") params_checkpointer.save(
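
For reference, a minimal sketch (editor's illustration, not taken verbatim from TemporalEnsembleWrapper, whose body is only partially shown in this patch) of the temporal ensembling that the new use_temp_averaging rollout option enables: the policy is queried every environment step, and the action actually executed averages all still-valid predictions for the current timestep, with exponential weights that reduce to a uniform average when exp_weight is 0 (the wrapper's default, matching the "uniform" temp_ensembling setting in the finetune config). The chunk bookkeeping convention below (newest chunk first) is an assumption chosen for clarity:

import numpy as np

def ensemble_current_action(chunks, exp_weight: float = 0.0) -> np.ndarray:
    # chunks[k] is the (pred_horizon, action_dim) chunk predicted k env-steps ago
    # (newest first), so its prediction for the current step sits at index k.
    preds = np.stack([chunk[k] for k, chunk in enumerate(chunks)])
    weights = np.exp(-exp_weight * np.arange(len(chunks)))  # exp_weight=0 -> uniform
    weights /= weights.sum()
    return (weights[:, None] * preds).sum(axis=0)

# Usage per step: prepend the newest prediction, keep at most pred_horizon chunks.
#   chunks = [latest_chunk] + chunks[: pred_horizon - 1]
#   action = ensemble_current_action(chunks)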