diff --git a/docs/assets/teaser.jpg b/docs/assets/teaser.jpg index a5d2056c..199cc12c 100644 Binary files a/docs/assets/teaser.jpg and b/docs/assets/teaser.jpg differ diff --git a/examples/01_inference_pretrained.ipynb b/examples/01_inference_pretrained.ipynb index 7c617dab..3786ff3f 100644 --- a/examples/01_inference_pretrained.ipynb +++ b/examples/01_inference_pretrained.ipynb @@ -8,11 +8,30 @@ "source": [ "# Step 1: Minimal Octo Inference Example\n", "\n", - "This Colab demonstrates how to load a pre-trained / finetuned Octo checkpoint, run inference on some images, and compare the outputs to the true actions.\n", + "This notebook demonstrates how to load a pre-trained / finetuned Octo checkpoint, run inference on some images, and compare the outputs to the true actions.\n", "\n", "First, let's start with a minimal example!" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "bae44461", + "metadata": {}, + "outputs": [], + "source": [ + "# run this block if you're using Colab\n", + "\n", + "# Download repo\n", + "!git clone https://github.com/octo-models/octo.git\n", + "%cd octo\n", + "# Install repo\n", + "!pip3 install -e .\n", + "!pip3 install -r requirements.txt\n", + "!pip3 install --upgrade \"jax[cuda11_pip]==0.4.20\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "!pip install numpy==1.21.1 # to fix colab AttributeError: module 'numpy' has no attribute '_no_nep50_warning', if the error still shows reload" + ] + }, { "cell_type": "code", "execution_count": null, @@ -31,7 +50,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "from octo.model.octo_model import OctoModel\n", "\n", "model = OctoModel.load_pretrained(\"hf://rail-berkeley/octo-small-1.5\")" @@ -51,7 +69,7 @@ "# download one example BridgeV2 image\n", "IMAGE_URL = \"https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg\"\n", "img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))\n", - "plt.imshow(img)\n" + "plt.imshow(img)" ] }, {