Is q's sequence length also arbitrary, or are only small values supported? #28
Comments
In theory, arbitrary lengths are supported. But when batch * s_q * head_num is small, flash decode raises SM occupancy, so its speedup is significant. If s_q is large (e.g. in the prefill stage), two thread blocks can already run on each SM, and flash decode would then …
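A minimal sketch of that scheduling trade-off, assuming an H100-class GPU with 132 SMs; `choose_num_splits`, `num_sms`, and `blocks_per_sm` are hypothetical names, not FlashMLA's actual heuristic:

```python
# Illustrative only: when does split-KV (flash decode) pay off?
# num_sms and blocks_per_sm are assumed device properties, not FlashMLA API.
def choose_num_splits(batch: int, s_q: int, head_num: int,
                      num_sms: int = 132, blocks_per_sm: int = 2) -> int:
    """Pick a KV-split count so the launch grid can fill every SM."""
    base_blocks = batch * s_q * head_num   # blocks without any KV splitting
    capacity = num_sms * blocks_per_sm     # blocks the GPU runs concurrently
    if base_blocks >= capacity:
        # Prefill-like shapes already saturate the SMs; extra splits would
        # only add the second-stage reduction overhead.
        return 1
    # Decode-like shapes: split along KV until the grid fills the GPU.
    return max(1, capacity // base_blocks)
```

For example, a decode shape like `choose_num_splits(1, 1, 32)` would ask for 8 splits to fill the assumed 264-block capacity, while a prefill shape returns 1.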
If s_q is very large, shouldn't MLA tile along q in blocks the way FlashAttention does? But that would just be FlashAttention with a very large head_dim, and I'm not sure whether there would be a risk of register spilling.
Analysis from DeepSeek R1:
```cpp
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
    const int n_split_idx = ...;
    const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
    const int n_block_min = ...;
    const int n_block_max = ...;
    flash::compute_attn_1rowblock_splitkv_mla(...);
}
```

The n_block range is computed dynamically per batch from seqlen_k, so the loop imposes no hard-coded sequence-length limit.
```cpp
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(...);
```

The head dimension 576 is currently hard-coded, but other dimensions can be added through the template parameter.
```python
q = torch.randn(b, s_q, h_q, d)  # s_q can be set freely
```

The s_q = 1 or 2 in the test code is only an example. As for the actually supported range:
```cpp
params.seqlen_q = seqlen_q_ori;
params.d = head_size;
params.d_v = head_size_v;
```

The input seqlen_q is passed through directly, with no length truncation. To verify arbitrary-length support, the following sweep over extreme lengths is suggested (a hedged reference check follows the sweep):
```python
# Test extreme lengths
for s_q in [1, 2, 63, 64, 65, 1023, 1024, 4095, 4096]:
    test_flash_mla(b=1, s_q=s_q, mean_sk=256, h_q=32, h_kv=8,
                   d=576, dv=576, causal=True, varlen=False)
```

Limits observed in actual deployment: …
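For each swept shape, the FlashMLA output could also be compared against a plain PyTorch reference. This is a hedged sketch, not the repository's own test; the GQA expansion and the bottom-right causal alignment are assumptions about how the kernel behaves when s_q != s_k:

```python
import torch

def reference_attention(q, k, v, causal: bool = True):
    """Plain PyTorch reference; a hypothetical stand-in for the repo's check.
    q: (b, s_q, h_q, d), k: (b, s_k, h_kv, d), v: (b, s_k, h_kv, dv)."""
    b, s_q, h_q, d = q.shape
    s_k, h_kv = k.shape[1], k.shape[2]
    # Expand KV heads for grouped-query attention (assumes h_q % h_kv == 0).
    k = k.repeat_interleave(h_q // h_kv, dim=2)
    v = v.repeat_interleave(h_q // h_kv, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * d ** -0.5
    if causal:
        # Bottom-right aligned causal mask (the flash-attention convention
        # when s_q != s_k); an assumption about this kernel's masking.
        q_idx = torch.arange(s_q, device=q.device)[:, None]
        k_idx = torch.arange(s_k, device=q.device)[None, :]
        scores.masked_fill_(k_idx > q_idx + (s_k - s_q), float("-inf"))
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(-1), v.float())
```

The kernel output could then be checked with `torch.testing.assert_close(out_flash.float(), out_ref, atol=1e-2, rtol=1e-2)`, with tolerances loose enough for bf16.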
I looked through the code: it is MLA in a two-stage flash-decode form. In the Python test code, s_q is 1 or 2; can s_q support arbitrary lengths?
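For context on the "two-stage" structure mentioned above: split-KV flash decode computes, per KV split, a partial output plus its log-sum-exp (stage 1), then merges the partials with a global log-sum-exp (stage 2). A minimal numerical sketch in PyTorch for a single (batch, head), not the CUDA kernel:

```python
import torch

def split_kv_attention(q, k, v, num_splits: int):
    """Two-stage flash-decode sketch for one (batch, head).
    q: (s_q, d), k: (s_k, d), v: (s_k, dv). Educational only."""
    scale = q.shape[-1] ** -0.5
    partials, lses = [], []
    # Stage 1: each KV split produces a partial output and its log-sum-exp.
    for k_i, v_i in zip(k.float().chunk(num_splits),
                        v.float().chunk(num_splits)):
        s = (q.float() @ k_i.T) * scale                 # (s_q, split_len)
        lse = torch.logsumexp(s, dim=-1, keepdim=True)  # (s_q, 1)
        partials.append((s - lse).exp() @ v_i)          # locally normalized
        lses.append(lse)
    # Stage 2: merge the splits by reweighting with the global log-sum-exp.
    lse = torch.cat(lses, dim=-1)                       # (s_q, num_splits)
    weights = (lse - torch.logsumexp(lse, -1, keepdim=True)).exp()
    return sum(w.unsqueeze(-1) * p
               for w, p in zip(weights.unbind(-1), partials))
```

A quick sanity check: `split_kv_attention(q, k, v, 4)` should match `torch.softmax(q @ k.T * q.shape[-1] ** -0.5, -1) @ v` up to floating-point error, which is why the second-stage merge adds overhead but no approximation.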