Do both layouts of grouped GEMMs need to be aligned to the GEMM M block size? #15

Open
CtfGo opened this issue Feb 26, 2025 · 5 comments


CtfGo commented Feb 26, 2025

  1. The m_grouped_gemm_fp8_fp8_bf16_nt_contiguous function documentation says that

"On the M axis, inputs are grouped into several batches, of which batch sizes aligned to get_m_alignment_for_contiguous_layout() (128)."

Does this mean the M of each group must be a multiple of 128? (A sketch of how I understand this alignment follows after this list.)

  2. I don't see such a declaration for m_grouped_gemm_fp8_fp8_bf16_nt_masked, but I tested several M values (8, 16, 64) that don't meet this requirement in test/test_core.py via test_m_grouped_gemm_masked and found that they can't pass the correctness validation. Does the masked-layout grouped GEMM also have this limitation implicitly?
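
For reference, here is a minimal sketch of what I mean by aligning each group's M for the contiguous layout. The pad_group_sizes helper below is hypothetical and only for illustration; the 128 value and get_m_alignment_for_contiguous_layout() come from the documentation:

```python
# Hypothetical illustration (not DeepGEMM code): round each group's M up to the
# contiguous-layout alignment before concatenating the groups along the M axis.
ALIGNMENT = 128  # per the docs, get_m_alignment_for_contiguous_layout() returns 128

def pad_group_sizes(group_ms, alignment=ALIGNMENT):
    """Round each group's M up to a multiple of `alignment`."""
    return [(m + alignment - 1) // alignment * alignment for m in group_ms]

print(pad_group_sizes([8, 200, 512]))  # -> [128, 256, 512]
```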
@LyricZhao
Collaborator

Does this mean the M of each group must be a multiple of 128?

Yes.

I tested several M values (8, 16, 64) that don't meet this requirement in test/test_core.py via test_m_grouped_gemm_masked and found that they can't pass the correctness validation. Does the masked-layout grouped GEMM also have this limitation implicitly?

There is no such limitation for the masked kernel. Could you please share your modified test script?


CtfGo commented Feb 27, 2025

@LyricZhao Hi, thank you for your reply, and I appreciate your awesome work.

There is no such limitation for the masked kernel. Could you please share your modified test script?

I just modified two lines in test/test_core.py's test_m_grouped_gemm_masked to test other values of num_groups and m; the detailed diff is:

[Screenshot: diff of the two modified lines in test/test_core.py]

Then run python test/test_core.py and you will see an AssertionError from the correctness check:

[Screenshot: AssertionError output from the correctness check]

Test machine environment: NVIDIA H800, CUDA 12.4, torch 2.6.0+cu124


LyricZhao commented Feb 27, 2025

Fixed in main, commits ca13ce0 and 6da94d2.

The reason was a wrong TMA store block size when shape_m is smaller than BLOCK_M. So currently we restrict shape_m <= BLOCK_M or shape_m % BLOCK_M == 0 for masked GEMM with more than one group.
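
In pseudo-terms, the restriction amounts to something like the following check (a sketch based on the sentence above, not the actual code in the repository; BLOCK_M is the kernel's M block size):

```python
def check_masked_shape_m(shape_m: int, block_m: int, num_groups: int) -> None:
    # Sketch of the restriction described above, not DeepGEMM's actual code:
    # with more than one group, each group's shape_m must either fit inside a
    # single M block or be an exact multiple of the M block size.
    if num_groups > 1:
        assert shape_m <= block_m or shape_m % block_m == 0, (
            f"masked grouped GEMM expects shape_m <= {block_m} "
            f"or shape_m % {block_m} == 0, got shape_m={shape_m}"
        )
```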

Thanks for the feedback!


CtfGo commented Feb 28, 2025

It works well now.

Are these limitations necessary to achieve good performance? Can we remove them to support any m, for example by passing extra params that indicate the batch size or data address of each group, like an exclusive_sum (a rough sketch of that idea is below)?
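
For clarity, a tiny illustration of the exclusive_sum idea (hypothetical, just to show how per-group offsets could be described; not an existing DeepGEMM parameter):

```python
def exclusive_sum(group_ms):
    # Exclusive prefix sum of per-group M sizes: offsets[i] is the starting row
    # of group i inside the concatenated [sum(group_ms), k] activation tensor.
    offsets, total = [], 0
    for m in group_ms:
        offsets.append(total)
        total += m
    return offsets, total

print(exclusive_sum([8, 16, 64]))  # -> ([0, 8, 24], 88)
```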

@LyricZhao
Collaborator

Not necessary, they are TMA limitations. If we are using a 2D [num_groups * m, k] TMA store, I guess there's no way to bypass the limitation (one group's store may overwrite the next group).

A possible solution is to use a 3D [num_groups, m, k] TMA store, but I don't think adding such support for these rare cases is worthwhile.
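
To illustrate the overwrite concern with a 2D [num_groups * m, k] layout, here is a back-of-the-envelope sketch (not the kernel's actual addressing logic):

```python
# With groups laid out contiguously as [num_groups * m, k], a full BLOCK_M-row
# TMA store tile that starts at a group boundary spills into the next group
# whenever m is smaller than (or not a multiple of) BLOCK_M.
BLOCK_M, num_groups, m = 128, 4, 96

for g in range(num_groups - 1):
    tile_start = g * m
    tile_end = tile_start + BLOCK_M        # rows written by one full store tile
    next_group_start = (g + 1) * m
    if tile_end > next_group_start:
        print(f"group {g}: tile rows [{tile_start}, {tile_end}) "
              f"spill past the next group's start at row {next_group_start}")
# (The last group's tile would similarly spill past the end of the tensor.)
```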
