Skip to content

Commit

Permalink
Merge branch 'main' into sim_eval_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
kpertsch committed Dec 5, 2023
2 parents 86e2297 + f2923d0 commit 57910b1
Show file tree
Hide file tree
Showing 50 changed files with 140,630 additions and 690 deletions.
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ We tokenize **task definitions** (like language instructions or goals), **observ
and **actions**. Given the sequence of input tokens, the model is trained to predict the action tokens.

## Installation
```
```bash
conda create -n orca python=3.10
conda activate orca
pip install -e .
pip install -r requirements.txt
```
For GPU:
```
pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```bash
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

For TPU
```
pip install --upgrade "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax.

Test the installation by training on the debug dataset:
```
```bash
python train.py --config tests/debug_config.py --debug
```

Expand Down Expand Up @@ -68,6 +68,17 @@ python train.py --config config.py:vit_s --name=orca --config.dataset_kwargs.oxe
| Encoders | [tokenizers.py](orca/model/components/tokenizers.py) | Tokenizers that encode image / text inputs into tokens. |
| Model + Objective | [orca_policy.py](orca/model/orca_policy.py) | Sort tokens into sequence, run forward pass, compute loss. |
| Visualization | [visualization_lib.py](orca/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. |
| Sim Evaluation | [sim_eval.sh](orca/scripts/sim_eval.sh) | Script to run model evaluation. |

## Run Evaluation in Simulation

To run evaluation on a trained model, you can use the following command:
```bash
# requires pybullet and kinpy to be installed
bash scripts/sim_eval.sh
```

This will spawn a `pybullet` environment with a WidowX robot. Since this environment is not used in training, this is purely for testing the pipeline.

## Contributing
Experimental things and training/eval scripts should go in `experiments/<your_name>`. To make any changes to files outside of your experiments directory, please open a pull request.
Expand Down
33 changes: 19 additions & 14 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ def get_config(
clip_gradient=1.0,
frozen_keys=tuple(),
),
batch_size=1024,
eval_batch_size=128,
shuffle_buffer_size=100000,
prefetch_num_batches=0,
val_shuffle_buffer_size=1000,
num_val_batches=16,
start_step=placeholder(int),
log_interval=100,
eval_interval=5000,
save_interval=5000,
viz_interval=20000,
save_interval=10000,
trajs_for_metrics=100,
trajs_for_viz=8,
resume_path=placeholder(str),
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_dataset_config(modality="multimodal", window_size=1):
raise ValueError(f"Unknown modality {modality}")

return {
# oxe_kwargs will generate data_kwargs_list and sampling weights
# oxe_kwargs will generate dataset_kwargs_list and sampling weights
"oxe_kwargs": dict(
data_mix=placeholder(str),
# for v4 TPUs: "gs://rail-orca-central2/resize_336_336"
Expand All @@ -123,18 +123,19 @@ def get_dataset_config(modality="multimodal", window_size=1):
n_wrist_cameras=0,
load_depth=False,
),
# common_kwargs override specific kwargs from data_kwargs_list
"common_kwargs": dict(
ram_budget=1, # limit RAM per dataset
num_parallel_reads=8, # for reading from GCS
num_parallel_calls=16, # for the less CPU-intensive ops in initial dataset construction
# common_dataset_kwargs override specific kwargs from dataset_kwargs_list
"common_dataset_kwargs": dict(
action_proprio_normalization_type=normalization_type,
),
"transform_kwargs": dict(
resize_size=(256, 256),
num_parallel_calls=32, # for the most CPU-intensive ops (decoding, resizing, augmenting)
"traj_transform_kwargs": dict(
window_size=window_size,
additional_action_window_size=0,
goal_relabeling_strategy="uniform",
subsample_length=100,
**task_augmentation,
),
"frame_transform_kwargs": dict(
resize_size=(256, 256),
image_augment_kwargs=dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.2],
Expand All @@ -149,9 +150,13 @@ def get_dataset_config(modality="multimodal", window_size=1):
"random_hue",
],
),
goal_relabeling_strategy="uniform",
**task_augmentation,
),
"traj_transform_threads": 48, # shared between all datasets
"traj_read_threads": 48, # shared between all datasets
"frame_transform_threads": 200, # not shared between datasets
"shuffle_buffer_size": 100000, # shared between all datasets
"batch_size": 1024,
"balance_weights": True,
}


Expand Down
29 changes: 18 additions & 11 deletions experiments/dibya/finetune_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def get_config(
"action_encoding": ActionEncoding.EEF_POS,
"action_proprio_normalization_type": "normal",
# If the default data loading speed is too slow, try these:
# and "num_parallel_calls" in `transform_kwargs` below
# "num_parallel_reads": 8, # for reading from disk / GCS
# "num_parallel_calls": 16, # for initial dataset construction
}
Expand Down Expand Up @@ -70,7 +69,7 @@ def get_config(
wandb=dict(
project="orca_finetune", group=placeholder(str), entity=placeholder(str)
),
finetuning_dataset=FINETUNING_KWARGS,
dataset_kwargs=FINETUNING_KWARGS,
modality=task,
finetuning_mode=mode,
window_size=window_size,
Expand Down Expand Up @@ -107,9 +106,18 @@ def get_config(
else:
raise ValueError("Invalid modality")

transform_kwargs = dict(
traj_transform_kwargs = dict(
window_size=window_size,
additional_action_window_size=0,
goal_relabeling_strategy=goal_relabeling_strategy,
task_augmentation_strategy="delete_task_conditioning",
task_augmentation_kwargs=dict(
delete_key_groups_probs=delete_key_groups_probs,
),
# If the default data loading speed is too slow, try these:
# num_parallel_calls=16, # for less CPU-intensive ops
)
frame_transform_kwargs = dict(
resize_size=(256, 256),
image_augment_kwargs=dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
Expand All @@ -125,13 +133,12 @@ def get_config(
"random_hue",
],
),
goal_relabeling_strategy=goal_relabeling_strategy,
task_augmentation_strategy="delete_task_conditioning",
task_augmentation_kwargs=dict(
delete_key_groups_probs=delete_key_groups_probs,
),
# If the default data loading speed is too slow, try these:
# num_parallel_calls=16, # for the most CPU-intensive ops (decoding, resizing, augmenting)
)
config["data_transforms"] = transform_kwargs
# If the default data loading speed is too slow, try these:
config[
"frame_transform_threads"
] = 16 # for the most CPU-intensive ops (decoding, resizing, augmenting)

config["traj_transform_kwargs"] = traj_transform_kwargs
config["frame_transform_kwargs"] = frame_transform_kwargs
return ConfigDict(config)
Loading

0 comments on commit 57910b1

Please sign in to comment.