
Is q's sequence length also arbitrary, or are only small values supported? #28

Open
liuqi123123 opened this issue Feb 24, 2025 · 3 comments


@liuqi123123

Looking at the code, this is an MLA in two-stage flash-decode form. In the Python test code s_q is 1 or 2 — can s_q support arbitrary lengths?

@pzhao-eng

> Looking at the code, this is an MLA in two-stage flash-decode form. In the Python test code s_q is 1 or 2 — can s_q support arbitrary lengths?

In theory arbitrary lengths are supported. However, flash decode's speedup comes from raising SM occupancy, which matters when batch * s_q * head_num is small. If s_q is large (e.g. in the prefill stage), each SM can already run two thread blocks, and flash decode then performs relatively poorly because of the extra combine_kernel.
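The two-stage split-KV idea discussed here (partial attention per KV chunk, then a combine step) can be sketched in NumPy. This is a hypothetical illustration of the math only, not FlashMLA's kernel; `attend_chunk` and `combine` are made-up names:

```python
import numpy as np

def attend_chunk(q, k, v, scale):
    # Stage 1: unnormalized attention over one KV chunk. Returns the
    # partial output plus the per-row max and sum-exp needed to merge
    # chunks in a numerically stable way.
    s = (q @ k.T) * scale                     # (s_q, chunk_len) scores
    m = s.max(axis=-1, keepdims=True)         # per-row running max
    p = np.exp(s - m)                         # stabilized exponentials
    return p @ v, m, p.sum(axis=-1, keepdims=True)

def combine(partials):
    # Stage 2: the "combine kernel" -- merge per-chunk partial results
    # with log-sum-exp weights so the total equals full softmax attention.
    m_global = np.max([m for _, m, _ in partials], axis=0)
    num = sum(np.exp(m - m_global) * o for o, m, _ in partials)
    den = sum(np.exp(m - m_global) * l for _, m, l in partials)
    return num / den

rng = np.random.default_rng(0)
s_q, s_k, d = 2, 256, 64
q = rng.standard_normal((s_q, d))
k = rng.standard_normal((s_k, d))
v = rng.standard_normal((s_k, d))
scale = 1.0 / np.sqrt(d)

chunks = [attend_chunk(q, k[i:i+64], v[i:i+64], scale) for i in range(0, s_k, 64)]
out = combine(chunks)

# Reference: ordinary full-softmax attention over the whole sequence.
s = (q @ k.T) * scale
p = np.exp(s - s.max(axis=-1, keepdims=True))
ref = (p / p.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(out, ref, atol=1e-6)
```

The second stage is the extra work that makes split-KV decoding a net loss once the SMs are already saturated by a large s_q.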

@liuqi123123
Author

If s_q is very large, shouldn't MLA tile along q the way FlashAttention does? But that effectively becomes flash attention with a very large head_dim — I wonder whether there is a risk of register spilling.

@arnewc

arnewc commented Feb 24, 2025

Analysis from DeepSeek R1:

Judging from the code, FlashMLA's s_q (the query sequence length) supports arbitrary lengths, but the implementation handles long sequences via a tiling strategy. Details:

1. Core tiling mechanism (see csrc/flash_fwd_mla_kernel.h:445-480):

```cpp
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
    const int n_split_idx = ...;
    const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
    const int n_block_min = ...;
    const int n_block_max = ...;
    flash::compute_attn_1rowblock_splitkv_mla(...);
}
```

Arbitrary lengths are tiled by computing `n_block_min` and `n_block_max` dynamically.
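For intuition, the block-range computation could look like the following. This is a hypothetical Python sketch; `split_k_block_range` is not a FlashMLA function, and the real kernel derives these values on-device:

```python
def split_k_block_range(seqlen_k, block_n, num_splits, split_idx):
    # Partition the ceil(seqlen_k / block_n) KV blocks as evenly as
    # possible across num_splits partial-attention passes, returning the
    # half-open block range [n_block_min, n_block_max) for one split.
    n_blocks = (seqlen_k + block_n - 1) // block_n
    per_split = (n_blocks + num_splits - 1) // num_splits
    n_block_min = min(split_idx * per_split, n_blocks)
    n_block_max = min(n_block_min + per_split, n_blocks)
    return n_block_min, n_block_max

# Any seqlen_k is covered: the union of all splits is exactly [0, n_blocks).
ranges = [split_k_block_range(1000, 64, 4, i) for i in range(4)]
```

Because the ranges are derived from `seqlen_k` at runtime, no fixed upper bound on the sequence length is baked in.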

2. Hardware constraints (csrc/flash_fwd_mla_bf16_sm90.cu):

```cpp
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(...);
```

The head dimension is currently hard-coded to 576, but other dimensions could be added through the template parameter.

3. The test code is not restrictive (tests/test_flash_mla.py:35):

```python
q = torch.randn(b, s_q, h_q, d)  # s_q can be set freely
```

The s_q=1/2 in the test code is only an example. The actual supported range:

  • In theory: 1 ≤ s_q ≤ 2^31-1 (int32 range)
  • In practice: bounded by GPU memory
4. Long-sequence processing flow:

```
s_q length → tiling strategy → dynamic scheduling → multi-SM parallelism
                    ↓
TileSchedulerMetaData generation → load balancing
```
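The load-balancing step in this flow can be illustrated with a simple greedy scheduler. This is a hypothetical sketch of the kind of batch-to-SM assignment such scheduler metadata encodes; `schedule_tiles` does not exist in FlashMLA:

```python
import heapq

def schedule_tiles(seqlens_k, num_sms, block_n=64):
    # Greedy longest-processing-time balancing: give each batch's KV
    # blocks to the currently least-loaded SM. Returns sm -> [batch ids].
    heap = [(0, sm) for sm in range(num_sms)]
    heapq.heapify(heap)
    assignment = {sm: [] for sm in range(num_sms)}
    # Schedule long sequences first for a tighter balance.
    order = sorted(range(len(seqlens_k)), key=lambda b: -seqlens_k[b])
    for b in order:
        blocks = -(-seqlens_k[b] // block_n)   # ceil-div: KV blocks of work
        load, sm = heapq.heappop(heap)
        assignment[sm].append(b)
        heapq.heappush(heap, (load + blocks, sm))
    return assignment

sched = schedule_tiles([4096, 128, 2048, 64, 1024], num_sms=2)
```

Precomputing an assignment like this once per forward pass is what lets uneven per-batch KV lengths keep all SMs busy.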
5. Key verification point (csrc/flash_api.cpp:63-72):

```cpp
params.seqlen_q = seqlen_q_ori;
params.d = head_size;
params.d_v = head_size_v;
```

The input seqlen_q is passed through directly, with no length truncation.

To verify arbitrary-length support, a sweep like this is suggested:

```python
# Test extreme lengths
for s_q in [1, 2, 63, 64, 65, 1023, 1024, 4095, 4096]:
    test_flash_mla(b=1, s_q=s_q, mean_sk=256, h_q=32, h_kv=8, d=576, dv=576, causal=True, varlen=False)
```

Limitations observed in actual deployment:

  • Best-performance range: s_q ≤ 4096 (beyond that, multi-pass processing is needed)
  • Tiling granularity: aligned to kBlockN=64; tail-block handling costs roughly a 15% efficiency drop
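The tail-block cost follows from alignment arithmetic: a sequence length padded up to the next multiple of kBlockN leaves part of the last block doing wasted work. A quick sketch of that padding fraction (a made-up helper, not a measurement of the ~15% figure above):

```python
def tail_overhead(seqlen, block_n=64):
    # Fraction of wasted work when seqlen is padded up to a multiple of
    # block_n: the last KV block is only partially useful.
    padded = -(-seqlen // block_n) * block_n   # ceil to block boundary
    return (padded - seqlen) / padded

# An aligned length pads to itself and wastes nothing, while a length of
# 65 pads to 128, so nearly half of the second block is wasted.
print(tail_overhead(64))   # 0.0
print(tail_overhead(65))   # 0.4921875
```

The relative overhead shrinks as the sequence grows, which is why the tail cost matters mostly for short, misaligned lengths.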
