
Why do head_dim and vhead_dim only support 576 and 512? #43

Open
liuchanchen opened this issue Feb 25, 2025 · 1 comment

Comments

@liuchanchen

In the code, I noticed there are restrictions on head_dim and vhead_dim. What is the reason behind these constraints?

[Screenshot of the code showing the head_dim / vhead_dim restriction]
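
For context, the kind of check being referred to looks roughly like the sketch below. The constant names and error message are illustrative, not FlashMLA's actual source:

```python
# Hypothetical sketch of the kind of dimension check the screenshot shows.
# Names and message are illustrative, not FlashMLA's code.
SUPPORTED_HEAD_DIM = 576   # q/k head dimension
SUPPORTED_VHEAD_DIM = 512  # v head dimension

def check_head_dims(head_dim: int, vhead_dim: int) -> None:
    if head_dim != SUPPORTED_HEAD_DIM or vhead_dim != SUPPORTED_VHEAD_DIM:
        raise ValueError(
            f"Unsupported dims: head_dim={head_dim}, vhead_dim={vhead_dim}; "
            f"only ({SUPPORTED_HEAD_DIM}, {SUPPORTED_VHEAD_DIM}) are supported."
        )
```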

@unchained369 commented Feb 26, 2025

Reason(s):

HARDWARE ARCHITECTURE ALIGNMENT (HOPPER SPECIFICS)
Tensor Core Math: Hopper’s FP16/BF16 tensor cores reach peak FLOPs when operating on 128x128 tiles. The constraints decompose cleanly (576 = 4*128 + 64, 512 = 4*128), so heads can be split across warps in a way that fills those tiles exactly, with no padding waste.

Shared Memory Banks: Head dimensions must avoid bank conflicts during K/V cache loads. Both 576 and 512 are multiples of Hopper’s 32-bank shared-memory layout, which matters for parallel atomic updates to the paged KV cache (block size = 64).
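
A quick way to sanity-check the tiling and bank arithmetic above (the 128-wide tile size and 32-bank count are the figures cited in this comment, not values read out of the kernel):

```python
# Sanity check of the tiling/bank arithmetic cited above (illustrative).
MMA_TILE = 128   # assumed tensor-core tile width from the comment
BANKS = 32       # shared-memory banks

for dim in (576, 512):
    tiles, residual = divmod(dim, MMA_TILE)
    print(f"head_dim={dim}: {tiles} full {MMA_TILE}-wide tiles + {residual} residual; "
          f"multiple of {BANKS} banks: {dim % BANKS == 0}")
# head_dim=576: 4 full 128-wide tiles + 64 residual; multiple of 32 banks: True
# head_dim=512: 4 full 128-wide tiles + 0 residual; multiple of 32 banks: True
```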

MEMORY COHERENCE FOR PAGED KVCACHE
Block Size = 64: Each KV cache block holds 64 tokens. With head_dim=576:

576-dim vectors per token → 64 tokens/block → 36,864 elements/block
36,864 elements * 2 bytes (bf16) = 73,728 bytes (72 KiB) per block → small enough to stage on chip with room for metadata, and each 1,152-byte token vector is an exact multiple of Hopper’s 128-byte L2 cache line.

Deviating from 576/512 would fragment blocks across cache lines → thrashing L2/TLB during page table walks.
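
The block-size arithmetic is easy to reproduce; the 600 case below is just a made-up off-spec dimension for contrast:

```python
# Per-block KV cache footprint, as described above (bf16 = 2 bytes/element,
# block_size = 64 tokens).
def kv_block_bytes(head_dim: int, block_size: int = 64, elem_bytes: int = 2) -> int:
    return head_dim * block_size * elem_bytes

for dim in (576, 512, 600):
    per_token = dim * 2
    print(f"head_dim={dim}: {kv_block_bytes(dim)} bytes/block, {per_token} bytes/token, "
          f"token vector aligned to 128-byte cache lines: {per_token % 128 == 0}")
# head_dim=576: 73728 bytes/block, 1152 bytes/token, token vector aligned to 128-byte cache lines: True
# head_dim=512: 65536 bytes/block, 1024 bytes/token, token vector aligned to 128-byte cache lines: True
# head_dim=600: 76800 bytes/block, 1200 bytes/token, token vector aligned to 128-byte cache lines: False
```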

WARP-LEVEL OPTIMIZATIONS: Each warp (32 threads) handles 16 query vectors (h_q // h_kv = 16 is common in GQA). For head_dim=576: 576 elements/vector ÷ 32 threads = 18 elements/thread → no remainder → coalesced loads via PTX ldmatrix.sync.aligned.
Arbitrary dimensions would force partial warps or wasted cycles on padding elements.
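
The per-thread split can be checked in a couple of lines; again, 600 is just an arbitrary off-spec dimension:

```python
# Elements handled per thread in a 32-thread warp (illustrative check).
WARP_SIZE = 32

for dim in (576, 512, 600):
    per_thread, remainder = divmod(dim, WARP_SIZE)
    print(f"head_dim={dim}: {per_thread} elements/thread, remainder {remainder}")
# head_dim=576: 18 elements/thread, remainder 0
# head_dim=512: 16 elements/thread, remainder 0
# head_dim=600: 18 elements/thread, remainder 24  <- would need padding or a partial warp
```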

CUTLASS TEMPLATE RESTRICTIONS: FlashMLA’s kernel inherits CUTLASS’s GemmUniversal tiling schemes optimized for:

Threadblock Shapes: 256x128x64 for Hopper MMAv9 instructions.

head_dim=512 → Fills threadblock tiles as (256 threads) * (2 elements/thread) = 512.

head_dim=576 → Uses residual tiles (512 + 64) mapped to async copy units via Hopper’s TMA.
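
Expressing that last mapping as arithmetic (the 256-thread tile and the TMA residual path are the figures from this comment, stated as a sketch rather than read from the kernel):

```python
# How 512 and 576 map onto a 256-thread main tile, per the description above.
THREADS = 256

for dim in (512, 576):
    per_thread, residual = divmod(dim, THREADS)
    note = "no residual tile" if residual == 0 else f"{residual}-wide residual tile (e.g. via TMA async copy)"
    print(f"head_dim={dim}: {per_thread} elements/thread, {note}")
# head_dim=512: 2 elements/thread, no residual tile
# head_dim=576: 2 elements/thread, 64-wide residual tile (e.g. via TMA async copy)
```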
