diff --git a/octo/data/dataset.py b/octo/data/dataset.py index a549c2c0..28f89fcd 100644 --- a/octo/data/dataset.py +++ b/octo/data/dataset.py @@ -140,6 +140,8 @@ def apply_frame_transforms( image_augment_kwargs: Union[dict, Mapping[str, dict]] = {}, resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]] = {}, depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]] = {}, + image_dropout_prob: float = 0.0, + image_dropout_keep_key: Optional[str] = None, num_parallel_calls: int = tf.data.AUTOTUNE, ) -> dl.DLataset: """Applies common transforms that happen at a frame level. These transforms are usually more @@ -159,6 +161,10 @@ def apply_frame_transforms( keys (so pass an empty dict to skip resizing for all images). depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth images. + image_dropout_prob (float): Probability of dropping out images, applied to each image key + independently. At least one image will always be present. + image_dropout_keep_key (str, optional): Optionally provide a key to always keep during image dropout, + for example for image observations that are essential for action prediction. num_parallel_calls (int): number of parallel calls for frame_map operations. Default to AUTOTUNE. 
""" @@ -186,14 +192,22 @@ def apply_obs_transform(fn: Callable[[dict], dict], frame: dict) -> dict: if train: # augment all images with the same seed, skipping padding images - def aug(frame: dict): + def aug_and_dropout(frame: dict): seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) + dropout_fn = partial( + obs_transforms.image_dropout, + seed=seed, + dropout_prob=image_dropout_prob, + always_keep_key=image_dropout_keep_key, + ) aug_fn = partial( obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs ) - return apply_obs_transform(aug_fn, frame) + frame = apply_obs_transform(dropout_fn, frame) + frame = apply_obs_transform(aug_fn, frame) + return frame - dataset = dataset.frame_map(aug, num_parallel_calls) + dataset = dataset.frame_map(aug_and_dropout, num_parallel_calls) return dataset diff --git a/octo/data/obs_transforms.py b/octo/data/obs_transforms.py index efb49942..97a28b41 100644 --- a/octo/data/obs_transforms.py +++ b/octo/data/obs_transforms.py @@ -2,7 +2,7 @@ Contains observation-level transforms used in the octo data pipeline. These transforms operate on the "observation" dictionary, and are applied at a per-frame level. """ -from typing import Mapping, Tuple, Union +from typing import Mapping, Optional, Tuple, Union from absl import logging import dlimp as dl @@ -39,6 +39,61 @@ def augment( return obs +def image_dropout( + obs: dict, + seed: tf.Tensor, + dropout_prob: float, + always_keep_key: Optional[str] = None, +) -> dict: + """Independently drops out image keys, each with probability `dropout_prob`, but always keeps at least one + image present. 
+ """ + image_keys = [key for key in obs if key.startswith("image_")] + if not image_keys: + return obs + pad_mask = tf.stack([obs["pad_mask_dict"][key] for key in image_keys]) + # if any non-padding images exist, pick one of them to keep no matter what + shuffle_seed, seed = tf.unstack(tf.random.split(seed)) + + if always_keep_key: + assert ( + always_keep_key in image_keys + ), f"Specified always_keep_key {always_keep_key} not present in image_keys: {image_keys} during dropout." + always_keep_index = tf.constant( + image_keys.index(always_keep_key), dtype=tf.int64 + ) + else: + always_keep_index = tf.cond( + tf.reduce_any(pad_mask), + # pick a random index from the non-padding images + lambda: tf.random.experimental.stateless_shuffle( + tf.where(pad_mask)[:, 0], seed=shuffle_seed + )[0], + # all images are padding, so it doesn't matter + lambda: tf.constant(0, dtype=tf.int64), + ) + + # drop images independently, except for the one at always_keep_index + rands = tf.random.stateless_uniform([len(image_keys)], seed=seed) + pad_mask = tf.logical_and( + pad_mask, + tf.logical_or( + tf.range(len(image_keys), dtype=tf.int64) == always_keep_index, + rands > dropout_prob, + ), + ) + + # perform the dropout and update pad_mask_dict + for i, key in enumerate(image_keys): + obs["pad_mask_dict"][key] = pad_mask[i] + obs[key] = tf.cond( + pad_mask[i], + lambda: obs[key], + lambda: tf.zeros_like(obs[key]), + ) + return obs + + def decode_and_resize( obs: dict, resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], diff --git a/scripts/configs/config.py b/scripts/configs/config.py index 1d74b74c..8185956f 100644 --- a/scripts/configs/config.py +++ b/scripts/configs/config.py @@ -135,6 +135,7 @@ def get_dataset_config(window_size=1): ), "frame_transform_kwargs": dict( resize_size=(256, 256), + image_dropout_prob=0.0, image_augment_kwargs=dict( random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), random_brightness=[0.2], diff --git 
a/scripts/configs/octo_pretrain_config.py b/scripts/configs/octo_pretrain_config.py index 39e47495..19e3622e 100644 --- a/scripts/configs/octo_pretrain_config.py +++ b/scripts/configs/octo_pretrain_config.py @@ -123,6 +123,9 @@ def get_config(config_string=None): rephrase_prob=0.5, ), ), + frame_transform_kwargs=dict( + image_dropout_prob=0.5, + ), batch_size=128, shuffle_buffer_size=500000, balance_weights=True,