Always insert wgmma.fence before MmaOp #3680

Open · wants to merge 1 commit into base: main
Conversation

@jacobhinkle (Collaborator) commented on Jan 8, 2025

We previously did not insert a wgmma.fence before each MMA op when its operands were guarded by an mbarrier. ptxas warned that these fences were needed and injected warpgroup.arrive automatically. This PR inserts the fences explicitly in order to silence that warning.

Before:

```
  #pragma unroll 3
  for(nvfuser_index_t i21 = 0; i21 < i4; ++i21) {
    nvfuser_index_t i22;
    i22 = (3 + i21) % 4;
    nvfuser_index_t i23;
    i23 = i21 % 4;
    unsigned i24;
    i24 = i11 + (4096 * i23);
    unsigned i25;
    i25 = i7 + (8192 * i23);
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::arriveExpectTX(toSmem((&T9[((3LL + i21) % 4)])), 8192U);
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(48 + (16 * i21)), i6}), toSmem((&T9[((3LL + i21) % 4)])) }), (i7 + (8192 * i22)));
      mbarrier::arriveExpectTX(toSmem((&T9[((3LL + i21) % 4)])), 4096U);
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(48 + (16 * i21)), i9}), toSmem((&T9[((3LL + i21) % 4)])) }), (i10 + (4096 * i22)));
    }
    mbarrier::waitParity(toSmem((&T9[(i21 % 4)])), (uint32_t)(((i21 / 4) % 2)));
    asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %130, 0;\n"
      "  wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, %128, %129, p0, %131, %132, %133, %134;\n"
```

After:

```
  for(nvfuser_index_t i21 = 0; i21 < i4; ++i21) {
    nvfuser_index_t i22;
    i22 = (3 + i21) % 4;
    nvfuser_index_t i23;
    i23 = i21 % 4;
    unsigned i24;
    i24 = i11 + (4096 * i23);
    unsigned i25;
    i25 = i7 + (8192 * i23);
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::arriveExpectTX(toSmem((&T9[((3LL + i21) % 4)])), 8192U);
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(48 + (16 * i21)), i6}), toSmem((&T9[((3LL + i21) % 4)])) }), (i7 + (8192 * i22)));
      mbarrier::arriveExpectTX(toSmem((&T9[((3LL + i21) % 4)])), 4096U);
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(48 + (16 * i21)), i9}), toSmem((&T9[((3LL + i21) % 4)])) }), (i10 + (4096 * i22)));
    }
    mbarrier::waitParity(toSmem((&T9[(i21 % 4)])), (uint32_t)(((i21 / 4) % 2)));
    asm volatile("wgmma.fence.sync.aligned;\n");
    asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %130, 0;\n"
      "  wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, %128, %129, p0, %131, %132, %133, %134;\n"
```

On H100 80GB HBM3 (prod machine) perf is unchanged: this kernel runs 2.222 ms before and 2.198 ms after this PR.

I confirmed that there is no change in the compiled SASS.

Before (ptxas output):
```
[ RUN      ] HopperMatmulTest.MLPBenchmarkFwdGEMM
ptxas info    : (C7519) warpgroup.arrive is injected in around line 723 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 827 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 932 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 1057 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 1118 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 1165 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : (C7519) warpgroup.arrive is injected in around line 1213 by compiler to allow use of registers in GMMA in function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_'
ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' for 'sm_90a'
ptxas info    : Function properties for _ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_c4bef0bf_32424nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 154 registers, used 1 barriers
ptxas info    : Compile time = 154.222 ms

[       OK ] HopperMatmulTest.MLPBenchmarkFwdGEMM (6443 ms)
```
After (ptxas output):
```
[ RUN      ] HopperMatmulTest.MLPBenchmarkFwdGEMM
ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_91517110_54524nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_' for 'sm_90a'
ptxas info    : Function properties for _ZN69_GLOBAL__N__00000000_32___tmp_kernel_none_f0_c0_r0_g0_cu_91517110_54524nvfuser_none_f0_c0_r0_g0ENS_6TensorINS_8__bfloatELi2ELi2EEES2_NS_9TensorMapES3_S3_S2_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 154 registers, used 1 barriers
ptxas info    : Compile time = 168.444 ms

[       OK ] HopperMatmulTest.MLPBenchmarkFwdGEMM (8376 ms)
```
@jacobhinkle (Collaborator, Author) commented:

!test

@rdspring1 (Collaborator) left a comment:

I left some suggested updates to the code comments.

```
// Makes sure that writes to operands in the generic proxy are visible
// to the async proxy
auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
registerInsertBefore(expr, wgmma_fence, scope);
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
```
@rdspring1 (Collaborator) suggested change:

```diff
-auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
+// fence.proxy.async makes sure that writes to operands in the generic proxy
+// are visible to the async proxy
+auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
```

@rdspring1 (Collaborator) commented:

Just realized that kir::FenceAsyncProxy permutes the words of the fence.proxy.async PTX instruction.
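For reference, the word order defined by the PTX ISA is `fence.proxy.async` (this is a generic illustration of how the instruction might be emitted via inline asm, not code from this PR):

```
// PTX ISA word order is "fence.proxy.async;": it makes prior generic-proxy
// writes (e.g. shared-memory stores) visible to the async proxy.
asm volatile("fence.proxy.async;\n");
```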

```diff
@@ -393,11 +393,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
     if (auto mma = dynamic_cast<MmaOp*>(expr)) {
       if (mma->isHopper()) {
         auto scope = scope_.empty() ? nullptr : scope_.back();
         // Makes sure that writes to operands in the generic proxy are visible
         // to the async proxy
         auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
```
@rdspring1 (Collaborator) suggested change:

```diff
-        auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
+        // wgmma.fence must be issued by all warps of the warp group:
+        // 1) Before the first wgmma.mma_async operation in a warp group.
+        // 2) Between a register access by a thread in the warp group and any
+        //    wgmma.mma_async instruction that accesses the same registers,
+        //    either as accumulator or input register containing fragments of
+        //    matrix A, except when these are accumulator register accesses
+        //    across multiple wgmma.mma_async instructions of the same shape.
+        //    In the latter case, an ordering guarantee is provided by default.
+        auto wgmma_fence = IrBuilder::create<kir::WgMmaFence>();
```
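Putting the rules above together, the fence sequence around a single wgmma call in the generated kernel looks roughly like this (an illustrative sketch assembled from the snippets in this PR and the PTX ISA, not a complete kernel):

```
// Make generic-proxy writes to the smem operands visible to the async proxy.
asm volatile("fence.proxy.async;\n");
// Issued by all warps of the warp group before wgmma.mma_async touches the
// accumulator / matrix A registers.
asm volatile("wgmma.fence.sync.aligned;\n");
// ... wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 ...
// Commit the in-flight async MMAs to a group and wait for that group.
asm volatile("wgmma.commit_group.sync.aligned;\n");
asm volatile("wgmma.wait_group.sync.aligned 0;\n");
```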
