in multi latent attention, cache the lightweight latent kv
lucidrains committed Feb 4, 2025
1 parent 62237f8 commit 28818a9
Showing 1 changed file with 14 additions and 6 deletions.
x_transformers/x_transformers.py: 14 additions & 6 deletions
@@ -1458,7 +1458,7 @@ def forward(
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
         assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
@@ -1476,16 +1476,24 @@ def forward(
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
 
-        # maybe project to latent queries and cache-able latent key values
+        # multi-latent attention logic
+        # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
 
         if self.use_latent_q:
            q_input = self.to_latent_q(q_input)
 
-        if self.use_latent_kv:
+        if is_multi_latent_attn:
             assert not qkv_receive_diff_residuals
 
             latent_kv_input = self.to_latent_kv(k_input)
 
+            if exists(cache):
+                cached_latent_kv = cache.cached_kv
+                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+            if return_intermediates:
+                cached_kv = latent_kv_input
+
             k_input = v_input = latent_kv_input
 
         # query, key, value projection
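What the new block buys: during incremental decoding, only the compressed latent of shape (batch, seq, dim_latent) is stored and extended each step, and the per-head keys and values are re-derived from it on the fly. Below is a minimal, self-contained sketch of that idea; the class, helper names and dimensions (LatentKVCacheSketch, to_latent_kv, dim_latent = 64, and so on) are illustrative assumptions for this example, not the library's actual modules or hyperparameters.

# minimal sketch of latent-kv caching during incremental decoding (illustrative names / dims)

import torch
from torch import nn, cat

class LatentKVCacheSketch(nn.Module):
    def __init__(self, dim = 512, dim_latent = 64, heads = 8, dim_head = 64):
        super().__init__()
        self.to_latent_kv = nn.Linear(dim, dim_latent, bias = False)       # compress tokens to a small latent
        self.to_k = nn.Linear(dim_latent, heads * dim_head, bias = False)  # expand latent back out to keys
        self.to_v = nn.Linear(dim_latent, heads * dim_head, bias = False)  # expand latent back out to values

    def forward(self, x, cached_latent = None):
        # x: (batch, new_tokens, dim) - only the freshly decoded positions
        latent = self.to_latent_kv(x)

        # the cache holds only the small latent, never the per-head k / v
        if cached_latent is not None:
            latent = cat((cached_latent, latent), dim = -2)

        k, v = self.to_k(latent), self.to_v(latent)
        return k, v, latent  # latent is returned as the new cache

sketch = LatentKVCacheSketch()
cache = None

for step in torch.randn(4, 1, 1, 512).unbind(dim = 0):  # pretend 4 single-token decoding steps
    k, v, cache = sketch(step, cache)

print(cache.shape)  # torch.Size([1, 4, 64]) - the cache grows only along the 64-dim latent

Caching full keys and values for this configuration would store 2 * heads * dim_head = 1024 numbers per position; caching the latent stores 64, which is the "lightweight" part of the commit message.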
@@ -1519,7 +1527,7 @@ def forward(
 
         # take care of caching
 
-        if exists(cache):
+        if not is_multi_latent_attn and exists(cache):
             ck, cv = cache.cached_kv
 
             if exists(mem):
@@ -1533,7 +1541,7 @@ def forward(
                 k = cat((mk, k), dim = -2)
                 v = cat((mv, v), dim = -2)
 
-        if return_intermediates:
+        if not is_multi_latent_attn and return_intermediates:
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
@@ -2219,7 +2227,7 @@ def forward(
         attn_cache = []
 
         if exists(cache):
-            assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
+            assert self.causal and not any([*map(exists, (mask, attn_mask))])
 
             if exists(context):
                 context = context[:, :0]
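The not is_multi_latent_attn guards in the later hunks follow from the first change: once the latent itself is what gets cached and returned, the generic per-head kv caching path would double-cache, so it is skipped. For completeness, a hedged sketch of how one might exercise this end to end. TransformerWrapper, Decoder and AutoregressiveWrapper are the library's public classes, but the attn_use_latent_kv / attn_dim_latent_kv keyword names below are assumptions inferred from the self.use_latent_kv attribute in the diff, not confirmed against the README.

# hypothetical usage sketch - the attn_use_latent_kv / attn_dim_latent_kv kwargs are assumed
# from the attribute names in the diff and may not match the library's exact API

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_use_latent_kv = True,  # assumed kwarg mirroring self.use_latent_kv
        attn_dim_latent_kv = 128    # assumed kwarg for the latent bottleneck width
    )
)

wrapper = AutoregressiveWrapper(model)

prompt = torch.randint(0, 20000, (1, 4))
generated = wrapper.generate(prompt, 256)  # with this commit, decoding caches the small latent kv per layer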
