From 28818a9035394aa50201e7267d7e8ac508617f4f Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 4 Feb 2025 05:48:49 -0800
Subject: [PATCH] in multi latent attention, cache the lightweight latent kv

---
 x_transformers/x_transformers.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index a4292414..d4a65a82 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1458,7 +1458,7 @@ def forward(
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
         assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
@@ -1476,16 +1476,24 @@ def forward(
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
 
-        # maybe project to latent queries and cache-able latent key values
+        # multi-latent attention logic
         # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
 
         if self.use_latent_q:
            q_input = self.to_latent_q(q_input)
 
-        if self.use_latent_kv:
+        if is_multi_latent_attn:
            assert not qkv_receive_diff_residuals
 
            latent_kv_input = self.to_latent_kv(k_input)
+
+            if exists(cache):
+                cached_latent_kv = cache.cached_kv
+                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+            if return_intermediates:
+                cached_kv = latent_kv_input
+
            k_input = v_input = latent_kv_input
 
         # query, key, value projection
@@ -1519,7 +1527,7 @@ def forward(
 
         # take care of caching
 
-        if exists(cache):
+        if not is_multi_latent_attn and exists(cache):
             ck, cv = cache.cached_kv
 
             if exists(mem):
@@ -1533,7 +1541,7 @@ def forward(
                 k = cat((mk, k), dim = -2)
                 v = cat((mv, v), dim = -2)
 
-        if return_intermediates:
+        if not is_multi_latent_attn and return_intermediates:
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
@@ -2219,7 +2227,7 @@ def forward(
         attn_cache = []
 
         if exists(cache):
-            assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
+            assert self.causal and not any([*map(exists, (mask, attn_mask))])
 
             if exists(context):
                 context = context[:, :0]
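
Note (not part of the patch): with multi-latent attention (MLA, Deepseek-AI, https://arxiv.org/abs/2405.04434) the decoding cache can hold the small compressed latent rather than the expanded per-head keys and values, which is why the usual kv-caching and return_intermediates branches are now skipped when is_multi_latent_attn is set. Below is a minimal self-contained sketch of that caching scheme; the names (LatentKVCacheSketch, cached_latent, dim_latent) are illustrative assumptions, not the x-transformers API.

# sketch of caching the lightweight latent kv, written outside of x-transformers
# all names here are illustrative assumptions, not the library's API

import torch
from torch import nn, cat

class LatentKVCacheSketch(nn.Module):
    def __init__(self, dim = 512, dim_latent = 64, heads = 8, dim_head = 64):
        super().__init__()
        # down-projection to the shared latent - the only tensor that gets cached
        self.to_latent_kv = nn.Linear(dim, dim_latent, bias = False)
        # up-projections from the latent to per-head keys and values
        self.to_k = nn.Linear(dim_latent, heads * dim_head, bias = False)
        self.to_v = nn.Linear(dim_latent, heads * dim_head, bias = False)
        self.heads = heads
        self.dim_head = dim_head

    def forward(self, x, cached_latent = None):
        # x holds only the newly decoded tokens: (batch, new_tokens, dim)
        latent = self.to_latent_kv(x)

        # concatenate the lightweight latent cache instead of full per-head k / v
        if cached_latent is not None:
            latent = cat((cached_latent, latent), dim = -2)

        # keys / values for the whole sequence are re-expanded from the latent
        k, v = self.to_k(latent), self.to_v(latent)
        k, v = (t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2) for t in (k, v))

        # return the latent so the caller can cache it for the next step
        return k, v, latent

# usage

mla = LatentKVCacheSketch()

prompt = torch.randn(2, 4, 512)
k, v, latent_cache = mla(prompt)                                     # cache is (2, 4, 64)

next_token = torch.randn(2, 1, 512)
k, v, latent_cache = mla(next_token, cached_latent = latent_cache)   # cache grows to (2, 5, 64)

On each decode step only the new tokens are down-projected, the cached latent is concatenated, and keys / values for the full sequence are recovered by the up-projections, so the cache grows by dim_latent per token rather than heads * dim_head per token for each of k and v.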