
up matrix is not involved? #16

Open
yejunguo opened this issue Feb 24, 2025 · 6 comments

@yejunguo

yejunguo commented Feb 24, 2025

Hi,

Very roughly, MLA compresses the input into a latent tensor via the DOWN matrix, caches the latent tensor, and then converts the latent tensor back to 'normal' QKVs via the UP matrices before SDPA.

It looks like FlashMLA does not accept the UP matrices among its parameters, so the inputs to FlashMLA are 'normal' (MHA/GQA/MQA) QKVs?

IMHO, FlashMLA would be expected to accept the MLA caches, the UP matrices, etc., and perform the possible matrix absorption together with SDPA.
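For concreteness, here is a minimal PyTorch sketch of the naive decode path described above (cache the latent, decompress via the UP matrices, then run ordinary SDPA). The weight names W_UK/W_UV and all shapes are illustrative, loosely following DeepSeek-V2 notation; this is not FlashMLA's API.

```python
# "Naive" MLA decode path: decompress the latent KV cache back to per-head
# K/V via the UP matrices, then run ordinary SDPA on MHA-style tensors.
# Names and shapes are illustrative, not FlashMLA's API.
import torch
import torch.nn.functional as F

bs, n_heads, d_head, d_latent, seq = 2, 16, 128, 512, 1024

W_UK = torch.randn(n_heads, d_latent, d_head)   # UP matrix: latent -> key
W_UV = torch.randn(n_heads, d_latent, d_head)   # UP matrix: latent -> value

q    = torch.randn(bs, n_heads, 1, d_head)      # one decode-step query
c_kv = torch.randn(bs, seq, d_latent)           # cached latent tensor

# Decompress the latent cache into per-head K/V before attention
k = torch.einsum('btl,hld->bhtd', c_kv, W_UK)   # [bs, n_heads, seq, d_head]
v = torch.einsum('btl,hld->bhtd', c_kv, W_UV)

out = F.scaled_dot_product_attention(q, k, v)   # [bs, n_heads, 1, d_head]
```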

@MacavityT

Same question: without the UP matrix as an input, how can we use the MLA computation tricks from the DeepSeek-V2 paper?

@YLGH

YLGH commented Feb 24, 2025

For MLA, the q absorb and o absorb steps can be done separately from the attention, e.g.
q: [bs, num_q_heads, 128 (head dim)] -> q: [bs, num_q_heads, 512 (latent dim)], concatenated with q_rope: [bs, num_q_heads, 64].
The output of MLA will be [bs, num_q_heads, 512], which can then be down-projected independently.
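To make those shapes concrete, here is a small PyTorch sketch of this weight-absorbed decode path, with plain einsums standing in for the attention kernel. The names (W_UK, W_UV, c_kv, k_rope) are illustrative assumptions, and FlashMLA itself would replace the score/softmax/output lines.

```python
# Weight-absorbed MLA decode: fold W_UK into the query before attention,
# attend directly against the 512-dim latent cache (plus the 64-dim RoPE
# part), keep the output in latent space, and apply W_UV (foldable into
# W_O) outside the kernel. Illustrative shapes, not FlashMLA's API.
import torch

bs, n_heads, d_head, d_rope, d_latent, seq = 2, 16, 128, 64, 512, 1024

W_UK = torch.randn(n_heads, d_latent, d_head)   # latent -> key (per head)
W_UV = torch.randn(n_heads, d_latent, d_head)   # latent -> value (per head)

q_nope = torch.randn(bs, n_heads, 1, d_head)    # [bs, num_q_heads, 128]
q_rope = torch.randn(bs, n_heads, 1, d_rope)    # [bs, num_q_heads, 64]
c_kv   = torch.randn(bs, seq, d_latent)         # latent KV cache
k_rope = torch.randn(bs, seq, d_rope)           # shared RoPE key part

# q absorb: fold W_UK into the query -> [bs, num_q_heads, 1, 512]
q_abs = torch.einsum('bhqd,hld->bhql', q_nope, W_UK)

# Scores against the latent cache itself (no decompression of K)
scores = (torch.einsum('bhql,btl->bhqt', q_abs, c_kv)
          + torch.einsum('bhqr,btr->bhqt', q_rope, k_rope)) / (d_head + d_rope) ** 0.5
attn = scores.softmax(dim=-1)

# Output stays in the latent dim: [bs, num_q_heads, 1, 512]
o_latent = torch.einsum('bhqt,btl->bhql', attn, c_kv)

# o absorb: apply W_UV (foldable into W_O) outside the attention kernel
out = torch.einsum('bhql,hld->bhqd', o_latent, W_UV)
```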

@MacavityT

> For MLA, the q absorb and o absorb steps can be done separately from the attention, e.g. q: [bs, num_q_heads, 128 (head dim)] -> q: [bs, num_q_heads, 512 (latent dim)], concatenated with q_rope: [bs, num_q_heads, 64]. The output of MLA will be [bs, num_q_heads, 512], which can then be down-projected independently.

But this only seems to work for models that adopted MLA during training. If we want to apply FlashMLA to models like the LLaMA series to accelerate inference decoding, it doesn't seem to make sense.

@hhding

hhding commented Feb 24, 2025

> For MLA, the q absorb and o absorb steps can be done separately from the attention, e.g. q: [bs, num_q_heads, 128 (head dim)] -> q: [bs, num_q_heads, 512 (latent dim)], concatenated with q_rope: [bs, num_q_heads, 64]. The output of MLA will be [bs, num_q_heads, 512], which can then be down-projected independently.
>
> But this only seems to work for models that adopted MLA during training. If we want to apply FlashMLA to models like the LLaMA series to accelerate inference decoding, it doesn't seem to make sense.

You can convert MHA to MLA: https://arxiv.org/abs/2502.14837 (Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs)

@Thomas-MMJ

You can also convert GQA to MLA:

https://github.com/fxmeng/TransMLA
https://arxiv.org/abs/2502.07864

@yejunguo
Author

> For MLA, the q absorb and o absorb steps can be done separately from the attention, e.g. q: [bs, num_q_heads, 128 (head dim)] -> q: [bs, num_q_heads, 512 (latent dim)], concatenated with q_rope: [bs, num_q_heads, 64]. The output of MLA will be [bs, num_q_heads, 512], which can then be down-projected independently.

Agreed that the DOWN matrix is independent of FlashMLA, but it looks like my concern about the inputs of FlashMLA (the MLA caches and the UP matrices) still stands.
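As a sanity check on why the attention kernel itself never needs the UP matrix explicitly: scoring a query against a decompressed key equals scoring the W_UK-absorbed query against the latent directly, so the UP matrices can be applied entirely outside the kernel. A tiny numerical check (names and shapes illustrative):

```python
# Absorption identity: q . (W_UK^T c) == (W_UK q) . c, so W_UK can be folded
# into the query (and W_UV into the output projection) outside the kernel,
# which then only ever sees the latent cache. Illustrative shapes.
import torch

d_head, d_latent = 128, 512
q    = torch.randn(d_head)
c_kv = torch.randn(d_latent)          # one cached latent vector
W_UK = torch.randn(d_latent, d_head)  # UP matrix: latent -> key

score_decompressed = q @ (c_kv @ W_UK)   # decompress K first, then dot with q
score_absorbed     = (W_UK @ q) @ c_kv   # absorb W_UK into q, dot with latent

assert torch.allclose(score_decompressed, score_absorbed, atol=1e-3)
```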
