diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index a0480864..fe752543 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1199,7 +1199,6 @@ def __init__(
         kv_heads = None,
         value_dim_head = None,
         dim_out = None,
-        tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotate_num_heads = None,
         data_dependent_alibi = False,
@@ -1213,13 +1212,14 @@ def __init__(
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         learned_value_residual_mix = False,
-        laser = False, # https://arxiv.org/abs/2411.03493v1
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
         use_latent_q = False,
         dim_latent_q = None,
         use_latent_kv = False,
         dim_latent_kv = None,
+        latent_rope_subheads = None,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1256,6 +1256,7 @@ def __init__(
 
         self.to_latent_q = None
         self.to_latent_kv = None
+        self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
 
         dim_q_input = dim
         dim_kv_input = dim_kv
@@ -1270,6 +1271,15 @@ def __init__(
             self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
             dim_kv_input = dim_latent_kv
 
+        if exists(latent_rope_subheads):
+            assert not exists(rotate_num_heads)
+            rotate_num_heads = latent_rope_subheads
+
+            k_dim = dim_head * (kv_heads - latent_rope_subheads)
+
+            self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+            self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)
+
         self.use_latent_q = use_latent_q
         self.use_latent_kv = use_latent_kv
 
@@ -1279,6 +1289,14 @@ def __init__(
         self.to_k = LinearNoBias(dim_kv_input, k_dim)
         self.to_v = LinearNoBias(dim_kv_input, v_dim)
 
+        # split and merge of attention heads
+
+        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+        self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
         # whether qkv receives different residual stream combinations from hyper connections
 
         self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
@@ -1429,7 +1447,10 @@ def __init__(
 
         # the number of attention heads to rotate, for decoupled rope in multi-latent attention
 
         rotate_num_heads = default(rotate_num_heads, heads)
+        assert 0 < rotate_num_heads <= heads
+        is_partial_rotate_heads = rotate_num_heads < heads
+        assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'
 
         self.rotate_num_heads = rotate_num_heads
 
@@ -1506,9 +1527,17 @@ def forward(
         k = self.to_k(k_input)
         v = self.to_v(v_input)
 
-        q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+        q = self.split_q_heads(q)
+        k = self.split_k_heads(k)
+        v = self.split_v_heads(v)
+
+        # take care of decoupled rope from multi-latent attention
+
+        if exists(self.to_rotateable_k):
+            rotate_k = self.to_rotateable_k(k_input)
+            rotate_k = self.split_rotateable_k_heads(rotate_k)
 
-        k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v))
+            k = cat((k, rotate_k), dim = 1)
 
         # if previous values passed in for residual, either invoke resformer
 
@@ -1694,7 +1723,7 @@ def forward(
 
         # merge heads
 
-        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.merge_heads(out)
 
         # hybrid module
 
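
For reviewers, here is a minimal, self-contained sketch of what the new decoupled-rope key path in this diff does, written with plain torch/einops rather than the library's internals. The projection and rearrange names mirror the diff; the sizes are illustrative, `nn.Linear(..., bias = False)` stands in for `LinearNoBias`, and the rotary embedding step is only indicated by a comment since it is applied elsewhere in the attention module.

```python
import torch
from torch import nn, cat
from einops.layers.torch import Rearrange

# illustrative sizes, not library defaults
dim, dim_head, kv_heads = 512, 64, 8
dim_latent_kv = 256          # latent bottleneck for keys / values (multi-latent attention)
latent_rope_subheads = 2     # key subheads that bypass the latent and can receive rope

# keys derived from the latent do not carry rope, so only the remaining subheads come from it
k_dim = dim_head * (kv_heads - latent_rope_subheads)

to_latent_kv = nn.Linear(dim, dim_latent_kv, bias = False)
to_k = nn.Linear(dim_latent_kv, k_dim, bias = False)

# decoupled path: these key subheads are projected straight from the base sequence
to_rotateable_k = nn.Linear(dim, dim_head * latent_rope_subheads, bias = False)

split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

x = torch.randn(2, 1024, dim)                              # (batch, seq, dim)

latents = to_latent_kv(x)
k = split_k_heads(to_k(latents))                           # (2, kv_heads - subheads, 1024, dim_head)

rotate_k = split_rotateable_k_heads(to_rotateable_k(x))    # (2, subheads, 1024, dim_head)
# ... rotary embedding would be applied to rotate_k here (rotate_num_heads = latent_rope_subheads) ...

k = cat((k, rotate_k), dim = 1)                            # (2, kv_heads, 1024, dim_head)
```

As the new assert notes, this path is not yet compatible with grouped query attention (`kv_heads < heads`) when only a subset of heads is rotated.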