Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lang-visual cross attention #83

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions octo/model/octo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class OctoTransformer(nn.Module):
max_horizon (int): The maximum number of timesteps that the transformer can be run with. Note that while the
transformer can be run with any horizon <= max_horizon, the model will only generate sane outputs for
horizon lengths smaller or equal to the pre-training horizon.
repeat_task_tokens: If true, repeats the task tokens at each observation timesetep.
"""

observation_tokenizers: Dict[str, nn.Module]
Expand All @@ -81,6 +82,7 @@ class OctoTransformer(nn.Module):
transformer_kwargs: Dict
token_embedding_size: int
max_horizon: int
repeat_task_tokens: bool

@nn.compact
def __call__(
Expand Down Expand Up @@ -209,6 +211,27 @@ def __call__(
attention_rules=observation_attention_rules,
)
)
if self.repeat_task_tokens:
logging.info(
"repeating task tokens at each timestep to perform cross-modal attention"
)
# get task tokens
for tasks in all_prefix_groups:
# lang (batch, n_tokens, token_embedding_size)
task_tokens = tasks.tokens[:, jnp.newaxis, :, :]
ws = all_timestep_groups[0].tokens.shape[1]
task_tokens = jnp.tile(task_tokens, [1, ws, 1, 1])
task_pad_mask = tasks.mask[:, jnp.newaxis, :]
task_pad_mask = jnp.tile(task_pad_mask, [1, ws, 1])
group_name = f"obs_{tasks.name}"
all_timestep_groups.append(
TimestepGroup(
tokens=task_tokens,
mask=task_pad_mask,
name=group_name,
attention_rules=observation_attention_rules,
)
)
#
# Finally, add the readout tokens
#
Expand Down Expand Up @@ -345,6 +368,7 @@ def create(
transformer_kwargs: Dict,
token_embedding_size: int,
max_horizon: int,
repeat_task_tokens: bool = False,
) -> "OctoModule":
"""
Canonical way to create an OctoModule from configuration.
Expand All @@ -357,6 +381,7 @@ def create(
token_embedding_size (int): The latent dimension of the token embeddings
max_horizon (int): Sets the size of positional embeddings, and provides an upper limit on the
maximum horizon of the model
repeat_task_tokens (bool): If true, repeats the task tokens at each observation timestep.
transformer_kwargs: additional kwargs to forward to the transformer, which include:
num_layers (int): number of layers
mlp_dim (int): hidden dimension of the MLPs
Expand All @@ -381,6 +406,7 @@ def create(
readouts=readouts,
token_embedding_size=token_embedding_size,
max_horizon=max_horizon,
repeat_task_tokens=repeat_task_tokens,
transformer_kwargs=transformer_kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions scripts/configs/octo_pretrain_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_config(config_string=None):
finetune_encoder=False,
),
}
config["model"]["repeat_task_tokens"] = True
config["model"]["readouts"] = {"action": 1}
config["model"]["heads"]["action"] = ModuleSpec.create(
DiffusionActionHead,
Expand Down
Loading