diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py index 98520882..ee5c67e9 100644 --- a/x_transformers/x_transformers.py +++ b/x_transformers/x_transformers.py @@ -1484,6 +1484,7 @@ def forward( q_input = self.to_latent_q(q_input) if is_multi_latent_attn: + assert not exists(rotary_pos_emb), 'rotary positions not supported yet' assert not qkv_receive_diff_residuals latent_kv_input = self.to_latent_kv(k_input)