
Update with many small changes/fixes #92

Merged 16 commits on May 23, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -140,8 +140,8 @@ To evaluate on your own environment, simply wrap it in a Gym interface and follow
| Visualization | [visualization_lib.py](octo/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. |

## FAQ
#### What is the `pad_mask` in the observation dictionary?
The `pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, pad_mask should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `pad_mask` key will be added to the observation dictionary for you.
#### What is the `timestep_pad_mask` in the observation dictionary?
The `timestep_pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `timestep_pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, `timestep_pad_mask` should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `timestep_pad_mask` key will be added to the observation dictionary for you.
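For illustration, here is what this can look like at the very first step of a rollout with a window size of 2, if you build the observation yourself rather than using the `HistoryWrapper` (the array shapes below are an assumption for the example, not prescribed above):

```python
import numpy as np

# t = 0: the "previous observation" slot of the window is padding.
observation = {
    "image_primary": np.zeros((1, 2, 256, 256, 3), dtype=np.uint8),  # [batch, window, H, W, C]
    "timestep_pad_mask": np.array([[False, True]]),  # ignore the padded slot, attend to the current one
}
```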
#### What is `pad_mask_dict` in the observation dictionary?
While `pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key.
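Similarly, a rough sketch of a per-key mask for a dataset without a wrist camera (the key names follow the paragraph above; shapes and image sizes are illustrative assumptions):

```python
import numpy as np

batch, window = 1, 2
observation = {
    "image_primary": np.zeros((batch, window, 256, 256, 3), dtype=np.uint8),
    "image_wrist": np.zeros((batch, window, 128, 128, 3), dtype=np.uint8),  # dummy frames
    "pad_mask_dict": {
        "image_primary": np.ones((batch, window), dtype=bool),  # real data: attend
        "image_wrist": np.zeros((batch, window), dtype=bool),   # no wrist camera: ignore
    },
}
```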
#### Does `model.sample_actions([...])` return the full trajectory to solve a task?
144 changes: 20 additions & 124 deletions examples/01_inference_pretrained.ipynb

Large diffs are not rendered by default.

16 changes: 7 additions & 9 deletions examples/02_finetune_new_observation_action.py
@@ -15,7 +15,6 @@
import wandb

from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
@@ -70,14 +69,12 @@ def main(_):
name="aloha_sim_cube_scripted_dataset",
data_dir=FLAGS.data_dir,
image_obs_keys={"primary": "top"},
state_obs_keys=["state"],
proprio_obs_key="state",
language_key="language_instruction",
action_proprio_normalization_type=NormalizationType.NORMAL,
absolute_action_mask=[True] * 14,
),
traj_transform_kwargs=dict(
window_size=1,
future_action_window_size=49, # so we get 50 actions for our action chunk
action_horizon=50,
),
frame_transform_kwargs=dict(
resize_size={"primary": (256, 256)},
@@ -116,10 +113,10 @@ def process_batch(batch):
high=2.0,
obs_keys=["proprio"],
)
# Fully override the old action head with a new one (for smaller changes, you can use update_module_config)
# Fully override the old action head with a new one (for smaller changes, you can use update_config)
config["model"]["heads"]["action"] = ModuleSpec.create(
L1ActionHead,
pred_horizon=50,
action_horizon=50,
action_dim=14,
readout_key="readout_action",
)
@@ -162,13 +159,14 @@ def loss_fn(params, batch, rng, train=True):
transformer_embeddings = bound_module.octo_transformer(
batch["observation"],
batch["task"],
batch["observation"]["pad_mask"],
batch["observation"]["timestep_pad_mask"],
train=train,
)
action_loss, action_metrics = bound_module.heads["action"].loss(
transformer_embeddings, # Action head knows to pull out the action readout_key
batch["action"],
pad_mask=batch["observation"]["pad_mask"],
batch["observation"]["timestep_pad_mask"],
batch["action_pad_mask"],
train=train,
)
return action_loss, action_metrics
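Side note on the two masks passed to the action head loss above: `timestep_pad_mask` marks which history frames in the window are real, while `action_pad_mask` marks which entries of the chunked action tensor should count toward the loss. A small sketch (the shapes here are an assumption for illustration, not taken from this diff):

```python
import numpy as np

B, W, H, D = 2, 1, 50, 14  # batch, window, action_horizon, action_dim for this finetuning example
timestep_pad_mask = np.ones((B, W), dtype=bool)      # all history frames are real with window_size=1
action_pad_mask = np.ones((B, W, H, D), dtype=bool)  # padded chunk entries would be marked False
```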
30 changes: 16 additions & 14 deletions examples/03_eval_finetuned.py
@@ -1,6 +1,6 @@
"""
This script demonstrates how to load and rollout a finetuned Octo model.
We use the Octo model finetuned on ALOHA sim data from the examples/finetune_new_observation_action.py script.
We use the Octo model finetuned on ALOHA sim data from the examples/02_finetune_new_observation_action.py script.

For installing the ALOHA sim environment, clone: https://github.com/tonyzhaozh/act
Then run:
@@ -15,6 +15,7 @@
cd examples
python3 03_eval_finetuned.py --finetuned_path=<path_to_finetuned_aloha_checkpoint>
"""
from functools import partial
import sys

from absl import app, flags, logging
@@ -23,12 +24,15 @@
import numpy as np
import wandb

sys.path.append("path/to/your/act")
# sys.path.append("path/to/your/act")
sys.path.append("/nfs/nfs2/users/homer/act")

from envs.aloha_sim_env import AlohaGymEnv # keep this to register ALOHA sim env
# keep this to register ALOHA sim env
from envs.aloha_sim_env import AlohaGymEnv # noqa

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper
from octo.utils.train_callbacks import supply_rng

FLAGS = flags.FLAGS

@@ -49,15 +53,13 @@ def main(_):
##################################################################################################################
# environment needs to implement standard gym interface + return observations of the following form:
# obs = {
# "image_0": ...
# "image_1": ...
# "image_primary": ...
# }
# it should also implement an env.get_task() function that returns a task dict with goal and/or language instruct.
# task = {
# "language_instruction": "some string"
# "goal": {
# "image_0": ...
# "image_1": ...
# "image_primary": ...
# }
# }
##################################################################################################################
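For reference, a bare-bones environment satisfying the interface described in the comment above might look like the sketch below (a hand-written illustration, not part of this PR; the class name, image size, and gymnasium-style 5-tuple `step` return are assumptions):

```python
import gymnasium as gym
import numpy as np


class MinimalRobotEnv(gym.Env):
    """Toy env returning observations and tasks in the format the Octo wrappers expect."""

    def _obs(self):
        return {"image_primary": np.zeros((256, 256, 3), dtype=np.uint8)}

    def reset(self, *, seed=None, options=None):
        return self._obs(), {}

    def step(self, action):
        obs, reward, terminated, truncated, info = self._obs(), 0.0, False, False, {}
        return obs, reward, terminated, truncated, info

    def get_task(self):
        return {"language_instruction": "put the cube in the bowl"}
```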
@@ -67,9 +69,11 @@
env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)

# wrap env to handle action/proprio normalization -- match normalization type to the one used during finetuning
env = UnnormalizeActionProprio(
env, model.dataset_statistics, normalization_type="normal"
policy_fn = supply_rng(
partial(
model.sample_actions,
unnormalization_statistics=model.dataset_statistics["action"],
),
)
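For context, `supply_rng` (now imported from `octo.utils.train_callbacks`) simply threads a fresh PRNG key into every call of the wrapped sampling function, so repeated rollout steps do not reuse the same key. The inline helper that this PR removes from `04_eval_finetuned_on_robot.py` is essentially:

```python
import jax


def supply_rng(f, rng=jax.random.PRNGKey(0)):
    def wrapped(*args, **kwargs):
        nonlocal rng
        rng, key = jax.random.split(rng)
        return f(*args, rng=key, **kwargs)

    return wrapped
```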

# running rollouts
Expand All @@ -85,9 +89,7 @@ def main(_):
episode_return = 0.0
while len(images) < 400:
# model returns actions of shape [batch, pred_horizon, action_dim] -- remove batch
actions = model.sample_actions(
jax.tree_map(lambda x: x[None], obs), task, rng=jax.random.PRNGKey(0)
)
actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
actions = actions[0]

# step env -- info contains full "chunk" of observations for logging
48 changes: 19 additions & 29 deletions examples/04_eval_finetuned_on_robot.py
@@ -22,11 +22,8 @@
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import (
HistoryWrapper,
TemporalEnsembleWrapper,
UnnormalizeActionProprio,
)
from octo.utils.gym_wrappers import HistoryWrapper, TemporalEnsembleWrapper
from octo.utils.train_callbacks import supply_rng

np.set_printoptions(suppress=True)

@@ -50,9 +47,10 @@
flags.DEFINE_integer("im_size", None, "Image size", required=True)
flags.DEFINE_string("video_save_path", None, "Path to save video")
flags.DEFINE_integer("num_timesteps", 120, "num timesteps")
flags.DEFINE_integer("horizon", 1, "Observation history length")
flags.DEFINE_integer("pred_horizon", 1, "Length of action sequence from model")
flags.DEFINE_integer("exec_horizon", 1, "Length of action sequence to execute")
flags.DEFINE_integer("window_size", 2, "Observation history length")
flags.DEFINE_integer(
"action_horizon", 4, "Length of action sequence to execute/ensemble"
)


# show image flag
@@ -64,10 +62,9 @@
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with
blocking control and we evaluate with blocking control.
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
Be sure to use a step duration of 0.2 if evaluating with non-blocking control.
"""
STEP_DURATION = 0.4
STEP_DURATION = 0.2
STICKY_GRIPPER_NUM_STEPS = 1
WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
CAMERA_TOPICS = [{"name": "/blue/image_raw"}]
@@ -107,16 +104,12 @@ def main(_):
)

# wrap the robot environment
env = UnnormalizeActionProprio(
env, model.dataset_statistics["bridge_dataset"], normalization_type="normal"
)
env = HistoryWrapper(env, FLAGS.horizon)
env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
env = HistoryWrapper(env, FLAGS.window_size)
env = TemporalEnsembleWrapper(env, FLAGS.action_horizon)
# switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
# env = RHCWrapper(env, FLAGS.exec_horizon)
# env = RHCWrapper(env, FLAGS.action_horizon)

# create policy function
@jax.jit
# create policy functions
def sample_actions(
pretrained_model: OctoModel,
observations,
@@ -129,22 +122,19 @@ def sample_actions(
observations,
tasks,
rng=rng,
unnormalization_statistics=pretrained_model.dataset_statistics[
"bridge_dataset"
]["action"],
)
# remove batch dim
return actions[0]

def supply_rng(f, rng=jax.random.PRNGKey(0)):
def wrapped(*args, **kwargs):
nonlocal rng
rng, key = jax.random.split(rng)
return f(*args, rng=key, **kwargs)

return wrapped

policy_fn = supply_rng(
partial(
policy_fn = partial(
supply_rng(
sample_actions,
model,
argmax=FLAGS.deterministic,
temperature=FLAGS.temperature,
)
)
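Downstream, this `policy_fn` is used the same way as in `03_eval_finetuned.py` (usage sketch; the rollout loop itself is outside this hunk):

```python
# inside the rollout loop: add a batch dim, sample an action chunk, then strip the batch dim
actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
actions = actions[0]
```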

271 changes: 18 additions & 253 deletions examples/05_dataloading.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/06_pytorch_oxe_dataloader.py
@@ -64,7 +64,7 @@ def __len__(self):
traj_transform_kwargs=dict(
goal_relabeling_strategy="uniform",
window_size=2,
future_action_window_size=3,
action_horizon=4,
subsample_length=100,
),
frame_transform_kwargs=dict(
2 changes: 0 additions & 2 deletions examples/envs/widowx_env.py
@@ -42,14 +42,12 @@ def convert_obs(obs, im_size):
# NOTE: assume image_1 is not available
return {
"image_primary": image_obs,
"proprio": proprio,
}


def null_obs(img_size):
return {
"image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8),
"proprio": np.zeros((8,), dtype=np.float64),
}

