Commit

complete multi-latent attention
lucidrains committed Feb 4, 2025
1 parent d3420e2 commit 7ad71d3
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions x_transformers/x_transformers.py
@@ -1199,7 +1199,6 @@ def __init__(
kv_heads = None,
value_dim_head = None,
dim_out = None,
- tensor_product = False, # https://arxiv.org/abs/2208.06061
add_zero_kv = False, # same as add_zero_attn in pytorch
rotate_num_heads = None,
data_dependent_alibi = False,
@@ -1213,13 +1212,14 @@ def __init__(
softclamp_logits = False,
logit_softclamp_value = 50.,
learned_value_residual_mix = False,
laser = False, # https://arxiv.org/abs/2411.03493v1
laser_softclamp_value = 15.,
qkv_receive_diff_residuals = False,
use_latent_q = False,
dim_latent_q = None,
use_latent_kv = False,
dim_latent_kv = None,
+ latent_rope_subheads = None,
onnxable = False,
attend_sdp_kwargs: dict = dict(
enable_flash = True,
@@ -1256,6 +1256,7 @@ def __init__(

self.to_latent_q = None
self.to_latent_kv = None
+ self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that come directly from the base sequence (they do not go through the latents)

dim_q_input = dim
dim_kv_input = dim_kv
@@ -1270,6 +1271,15 @@ def __init__(
self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
dim_kv_input = dim_latent_kv

+ if exists(latent_rope_subheads):
+     assert not exists(rotate_num_heads)
+     rotate_num_heads = latent_rope_subheads
+
+     k_dim = dim_head * (kv_heads - latent_rope_subheads)
+
+     self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+     self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

self.use_latent_q = use_latent_q
self.use_latent_kv = use_latent_kv
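A quick sanity check on the dimension bookkeeping in the block above: the keys projected from the kv latents only cover the subheads that will not be rotated, while the remaining subheads come straight from the input tokens, so after concatenation the key heads add back up to kv_heads. The numbers below are made up purely for illustration.

# toy check of the decoupled-rope key split (illustrative numbers, not the library's code)
dim_head = 64
kv_heads = 8
latent_rope_subheads = 2

k_dim = dim_head * (kv_heads - latent_rope_subheads)    # 384: key subheads derived from the kv latents
rotateable_k_dim = dim_head * latent_rope_subheads      # 128: key subheads derived from the raw input

assert (k_dim + rotateable_k_dim) // dim_head == kv_heads   # 6 + 2 subheads -> 8 key heads in total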

@@ -1279,6 +1289,14 @@ def __init__(
self.to_k = LinearNoBias(dim_kv_input, k_dim)
self.to_v = LinearNoBias(dim_kv_input, v_dim)

+ # split and merge of attention heads
+
+ self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+ self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+ self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')

# whether qkv receives different residual stream combinations from hyper connections

self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
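The Rearrange layers registered above are plain einops reshape modules. Worth noting: the key and value splits fix the per-head dimension d rather than the head count h, so the same module handles whatever number of kv heads shows up (grouped-query attention, or keys that later gain the extra decoupled-rope subheads). A small standalone demonstration with made-up sizes:

import torch
from einops.layers.torch import Rearrange

dim_head = 64
split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)   # h is inferred from the input
merge_heads = Rearrange('b h n d -> b n (h d)')

k6 = torch.randn(2, 1024, 6 * dim_head)   # 6 kv heads
k8 = torch.randn(2, 1024, 8 * dim_head)   # 8 kv heads

print(split_k_heads(k6).shape)               # torch.Size([2, 6, 1024, 64])
print(split_k_heads(k8).shape)               # torch.Size([2, 8, 1024, 64])
print(merge_heads(split_k_heads(k8)).shape)  # torch.Size([2, 1024, 512])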
@@ -1429,7 +1447,10 @@ def __init__(
# the number of attention heads to rotate, for decoupled rope in multi-latent attention

rotate_num_heads = default(rotate_num_heads, heads)

+ assert 0 < rotate_num_heads <= heads
+ is_partial_rotate_heads = rotate_num_heads < heads
+ assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'

self.rotate_num_heads = rotate_num_heads
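With partial rotation, only rotate_num_heads of the subheads receive rotary position embeddings and the latent-derived subheads are left unrotated. The sketch below is only meant to convey the idea: it assumes some rotary(t) function and arbitrarily rotates the leading subheads, whereas which slice is actually rotated in the library depends on how the subheads were concatenated.

import torch

def rotate_partial_heads(t, rotate_num_heads, rotary):
    # t: (batch, heads, seq, dim_head); rotary applies rotary embeddings to its argument
    t_rotate, t_rest = t[:, :rotate_num_heads], t[:, rotate_num_heads:]
    return torch.cat((rotary(t_rotate), t_rest), dim = 1)

# identity stand-in for a real rotary embedding, just to exercise the shapes
q = torch.randn(2, 8, 1024, 64)
out = rotate_partial_heads(q, rotate_num_heads = 2, rotary = lambda t: t)
assert out.shape == q.shape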

@@ -1506,9 +1527,17 @@ def forward(
k = self.to_k(k_input)
v = self.to_v(v_input)

- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+ q = self.split_q_heads(q)
+ k = self.split_k_heads(k)
+ v = self.split_v_heads(v)

+ # take care of decoupled rope from multi-latent attention
+
+ if exists(self.to_rotateable_k):
+     rotate_k = self.to_rotateable_k(k_input)
+     rotate_k = self.split_rotateable_k_heads(rotate_k)
+
- k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v))
+     k = cat((k, rotate_k), dim = 1)

# if previous values passed in for residual, either invoke resformer
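Stepping back from this hunk, here is a condensed, self-contained sketch of the multi-latent attention data flow as the diff wires it up: queries and the rotateable key subheads are projected from the raw tokens, the remaining keys and all values from the compressed kv latents, and the two key pieces are concatenated before attention. Sizes and names are assumptions for illustration, not the library's code.

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

dim, dim_latent_kv, dim_head, heads, kv_heads, rope_subheads = 512, 128, 64, 8, 8, 2

to_q            = nn.Linear(dim, dim_head * heads, bias = False)
to_latent_kv    = nn.Linear(dim, dim_latent_kv, bias = False)
to_k            = nn.Linear(dim_latent_kv, dim_head * (kv_heads - rope_subheads), bias = False)
to_v            = nn.Linear(dim_latent_kv, dim_head * kv_heads, bias = False)
to_rotateable_k = nn.Linear(dim, dim_head * rope_subheads, bias = False)

x = torch.randn(2, 1024, dim)

latents = to_latent_kv(x)                                           # compressed kv representation
q = rearrange(to_q(x), 'b n (h d) -> b h n d', h = heads)
k = rearrange(to_k(latents), 'b n (h d) -> b h n d', d = dim_head)  # key subheads from the latents
v = rearrange(to_v(latents), 'b n (h d) -> b h n d', d = dim_head)
rotate_k = rearrange(to_rotateable_k(x), 'b n (h d) -> b h n d', h = rope_subheads)

# rotary embeddings would be applied to rotate_k (and the matching query subheads) here

k = torch.cat((k, rotate_k), dim = 1)                               # back to kv_heads key heads
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
out = rearrange(out, 'b h n d -> b n (h d)')
print(out.shape)                                                    # torch.Size([2, 1024, 512])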

@@ -1694,7 +1723,7 @@ def forward(

# merge heads

- out = rearrange(out, 'b h n d -> b n (h d)')
+ out = self.merge_heads(out)

# hybrid module
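And a hypothetical usage of the new options once the commit is in place; the dim / heads / dim_head keyword arguments, the import path, and the return convention are assumed from the rest of x-transformers and may differ slightly from a given version.

import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    use_latent_kv = True,        # compress keys / values through a low-dimensional latent
    dim_latent_kv = 128,
    latent_rope_subheads = 2     # decoupled rope: 2 key subheads bypass the latents entirely
)

x = torch.randn(1, 1024, 512)
out = attn(x)                    # depending on version, may return (out, intermediates)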

