Skip to content

Commit

Permalink
loads paraphrases file from HuggingFace
Browse files Browse the repository at this point in the history
  • Loading branch information
mees committed May 8, 2024
1 parent 1613506 commit 35cefca
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
31 changes: 23 additions & 8 deletions octo/data/utils/task_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@

import pickle

from huggingface_hub import hf_hub_download
import tensorflow as tf

from octo.data.utils.data_utils import to_padding


def delete_and_rephrase(
    traj,
    paraphrases_repo: str,
    paraphrases_filename: str,
    rephrase_prob: float,
    keep_image_prob: float,
):
    """Apply instruction rephrasing followed by task-conditioning deletion.

    The trajectory is first passed through ``rephrase_instruction`` and the
    result is then passed through ``delete_task_conditioning``.

    Args:
        traj: Trajectory data, forwarded to both augmentation steps.
        paraphrases_repo: Name of the HF repo containing the paraphrases
            file; forwarded to ``rephrase_instruction``.
        paraphrases_filename: Name of the paraphrases file inside that repo;
            forwarded to ``rephrase_instruction``.
        rephrase_prob: Forwarded to ``rephrase_instruction``.
        keep_image_prob: Forwarded to ``delete_task_conditioning``.

    Returns:
        The trajectory returned by ``delete_task_conditioning``.
    """
    traj = rephrase_instruction(
        traj, paraphrases_repo, paraphrases_filename, rephrase_prob
    )
    traj = delete_task_conditioning(traj, keep_image_prob)
    return traj

Expand All @@ -28,25 +35,33 @@ def create_static_hash_table(self, dictionary):
hash_table = tf.lookup.StaticHashTable(initializer, default_value="")
return hash_table

def __init__(self, paraphrases_repo: str, paraphrases_filename: str):
    """Fetch the paraphrases pickle from the HF Hub and build the lookup table.

    Args:
        paraphrases_repo: Repo ID passed to ``hf_hub_download`` (queried with
            ``repo_type="dataset"``).
        paraphrases_filename: Name of the pickle file inside that repo.

    Note:
        When either argument is not a ``str`` nothing is loaded and
        ``self.rephrase_lookup`` is left unset.
    """
    if not (
        isinstance(paraphrases_repo, str) and isinstance(paraphrases_filename, str)
    ):
        # Same as before: silently skip loading for non-string arguments.
        return

    # Resolve the local file path first so the `with` below only guards the
    # file handle itself.
    local_path = hf_hub_download(
        repo_id=paraphrases_repo,
        filename=paraphrases_filename,
        repo_type="dataset",
    )

    # NOTE(security): pickle.load can execute arbitrary code. Only point
    # `paraphrases_repo` at a repository whose contents you trust.
    with open(local_path, "rb") as file:
        lang_paraphrases = pickle.load(file)

    # Create StaticHashTable
    self.rephrase_lookup = self.create_static_hash_table(lang_paraphrases)


def rephrase_instruction(
traj: dict, pickle_file_path: str, rephrase_prob: float
traj: dict, paraphrases_repo: str, paraphrases_filename: str, rephrase_prob: float
) -> dict:
"""Randomly rephrases language instructions with precomputed paraphrases
Args:
traj: A dictionary containing trajectory data. Should have a "task" key.
pickle_file_path: The path to the pickle file containing the paraphrases.
paraphrases_repo: The name of the HF repo containing the paraphrases file.
paraphrases_filename: The name of the file containing the paraphrases.
rephrase_prob: The probability of augmenting the language instruction. The probability of keeping the language
instruction is 1 - rephrase_prob.
"""
rephraser = Rephraser(pickle_file_path)
rephraser = Rephraser(paraphrases_repo, paraphrases_filename)

if "language_instruction" not in traj["task"]:
return traj
Expand Down
3 changes: 2 additions & 1 deletion scripts/configs/octo_pretrain_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def get_config(config_string=None):
future_action_window_size=3,
task_augment_strategy="delete_and_rephrase",
task_augment_kwargs=dict(
pickle_file_path="gs://rail-datasets-europe-west4/oxe/resize_256_256/paraphrases_oxe.pkl",
paraphrases_repo="rail-berkeley/OXE_paraphrases",
paraphrases_filename="paraphrases_oxe.pkl",
rephrase_prob=0.5,
),
),
Expand Down

0 comments on commit 35cefca

Please sign in to comment.