-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Always insert wgmma.fence before MmaOp #3680
base: main
Are you sure you want to change the base?
Conversation
Before: ``` [ RUN ] HopperMatmulTest.MLPBenchmarkFwdGEMM ptxas info : (C7519) warpgroup.arrive is injected in around line 723 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 827 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 932 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 1057 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 1118 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 1165 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : (C7519) warpgroup.arrive is injected in around line 1213 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' ptxas info : 3 bytes gmem ptxas info : Compiling entry function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' for 'sm_90a' ptxas info : Function properties for _ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_ ptxas . 0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads ptxas info : Used 154 registers, used 1 barriers ptxas info : Compile time = 154.222 ms [ OK ] HopperMatmulTest.MLPBenchmarkFwdGEMM (6443 ms) ``` After: ``` [ RUN ] HopperMatmulTest.MLPBenchmarkFwdGEMM ptxas info : 3 bytes gmem ptxas info : Compiling entry function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_91517110_54524nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' for 'sm_90a' ptxas info : Function properties for _ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_91517110_54524nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_ ptxas . 0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads ptxas info : Used 154 registers, used 1 barriers ptxas info : Compile time = 168.444 ms [ OK ] HopperMatmulTest.MLPBenchmarkFwdGEMM (8376 ms) ``` On H100 80GB HBM3 (prod machine) perf is basically unchanged: this kernel runs 2.222 ms before and 2.198 ms after this PR.
!test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some updates to the code comments.
// Makes sure that writes to operands in the generic proxy are visible | ||
// to the async proxy | ||
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>(); | ||
registerInsertBefore(expr, wgmma_fence, scope); | ||
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>(); | |
// fence.proxy.async makes sure that writes to operands in the generic proxy are visible | |
// to the async proxy | |
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized that kir::FenceAsyncProxy
permutes the words of fence.proxy.async
ptx.
@@ -393,11 +393,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { | |||
if (auto mma = dynamic_cast<MmaOp*>(expr)) { | |||
if (mma->isHopper()) { | |||
auto scope = scope_.empty() ? nullptr : scope_.back(); | |||
// Makes sure that writes to operands in the generic proxy are visible | |||
// to the async proxy | |||
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>(); | |
// When wgmma_fence needs to be issued by all warps: | |
// 1) Before the first wgmma.mma_async operation in a warp group. | |
// 2) Between a register access by a thread in the warp group and any | |
// wgmma.mma_async instruction that accesses the same registers, either | |
// as accumulator or input register containing fragments of matrix A, | |
// except when these are accumulator register accesses across multiple | |
// wgmma.mma_async instructions of the same shape. In the latter case, | |
// an ordering guarantee is provided by default. | |
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>(); |
We previously did not insert fences before each MMA if they were guarded by mbarrier. The compiler warned us that these needed to be inserted and did so automatically. This PR inserts them in order to silence the warning.
Before:
After:
On H100 80GB HBM3 (prod machine) perf is unchanged: this kernel runs 2.222 ms before and 2.198 ms after this PR.
I confirmed that there is no change in the compiled SASS.