diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 221fc2ce..487b488d 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -20,6 +20,7 @@ jobs:
           fetch-depth: 0
 
       - name: Run clang-format
+        shell: /usr/bin/bash {0}
         run: |
           diff=`git-clang-format --extensions="c,h,cc,cp,cpp,c++,cxx,hh,hpp,hxx,inc,cu,cuh" --commit ${{ github.event.pull_request.base.sha }} --diff`
           [ "$diff" = "no modified files to format" ] && exit 0
diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh
index 36d7bd7a..354a54c8 100644
--- a/src/kernels/attention/cute_extensions.cuh
+++ b/src/kernels/attention/cute_extensions.cuh
@@ -18,8 +18,45 @@ constexpr bool has_with_bool()
                              .with(declval<bool>()))>> = true;
 
+template <class Layout, class Shape>
+CUTE_HOST_DEVICE constexpr auto with_shape(Layout l, Shape s) {
+  if constexpr (is_underscore<Shape>::value) {
+    return l;
+  } else {
+    return l.with_shape(s);
+  }
+}
+
 }  // namespace detail
 
+// returns a fragment with the shape (MMA, mma_m, mma_k)
+template <class ThrMMA, class ATensor, class ShapeM, class ShapeK>
+CUTE_HOST_DEVICE constexpr auto partition_fragment_A(const ThrMMA& thr_mma,
+                                                     ATensor&& atensor,
+                                                     const ShapeM& mma_m,
+                                                     const ShapeK& mma_k) {
+  auto a = thr_mma.partition_A(atensor);
+  auto l = get_nonswizzle_portion(a.layout());
+  auto a_l = make_layout(get<0>(l),
+                         detail::with_shape(get<1>(l), mma_m),
+                         detail::with_shape(get<2>(l), mma_k));
+  return thr_mma.make_fragment_A(a_l);
+}
+
+template <class ThrMMA, class BTensor, class ShapeN, class ShapeK>
+CUTE_HOST_DEVICE constexpr auto partition_fragment_B(const ThrMMA& thr_mma,
+                                                     BTensor&& btensor,
+                                                     const ShapeN& mma_n,
+                                                     const ShapeK& mma_k) {
+  auto b = thr_mma.partition_B(btensor);
+  auto l = get_nonswizzle_portion(b.layout());
+  auto b_l = make_layout(get<0>(l),
+                         detail::with_shape(get<1>(l), mma_n),
+                         detail::with_shape(get<2>(l), mma_k));
+  return thr_mma.make_fragment_B(b_l);
+}
+
 template <size_t... Is, class SwizzleFn, class Offset, class LayoutB>
 CUTE_HOST_DEVICE constexpr auto permute(
     const ComposedLayout<SwizzleFn, Offset, LayoutB>& c) {
diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh
index 6c8f05a9..b9406935 100644
--- a/src/kernels/attention/mla_kernel_sm80.cuh
+++ b/src/kernels/attention/mla_kernel_sm80.cuh
@@ -23,7 +23,8 @@ template <typename Traits, typename Params>
-__global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
+__global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
+    __grid_constant__ const Params params) {
   using namespace cute;
 
   constexpr int kBlockM = Traits::kBlockM;
@@ -44,11 +45,13 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
 
   // type alias
   using DType = typename Traits::DType;
-  using TiledMma = typename Traits::TiledMma;
+  using TiledMma_QK = typename Traits::TiledMma_QK;
+  using TiledMma_PV = typename Traits::TiledMma_PV;
   using Layout = typename Traits::LayoutConvertor;
 
   using SmemLayoutQ = typename Traits::SmemLayoutQ;
   using SmemLayoutKV = typename Traits::SmemLayoutKV;
+  using SmemLayoutP = typename Traits::SmemLayoutP;
   using SmemLayoutQRope = typename Traits::SmemLayoutQRope;
   using SmemLayoutKRope = typename Traits::SmemLayoutKRope;
   using SmemLayoutVt = typename Traits::SmemLayoutVt;
@@ -60,33 +63,41 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
 
   using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
   using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
+  using SmemTiledCopyS = typename Traits::SmemTiledCopyS;
+  using SmemTiledCopyP = typename Traits::SmemTiledCopyP;
   using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt;
   using SmemTiledCopyO = typename Traits::SmemTiledCopyO;
 
+  const int m_block_idx = blockIdx.x;
+  const int batch_idx = blockIdx.y;
+  const int tidx = threadIdx.x;
+
   MLATile tile(params);
 
   // ProblemShape
   // Q/O: (q_packed_len, HEAD_DIM)
   // KV: (kv_len, HEAD_DIM)
   // Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM)
-  auto [Q, Q_ROPE, O] = tile.template get_qo_tile<DType>(blockIdx.y);
-  auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(blockIdx.y);
+  auto [Q, Q_ROPE, O] = tile.template get_qo_tile<DType>(batch_idx);
+  auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(batch_idx);
 
-  if (blockIdx.x * kBlockM >= size<0>(Q)) {
+  if (m_block_idx * kBlockM >= size<0>(Q)) {
     // m out of bound, return
     return;
   }
 
   // Gmem
   // (BLK_M, BLK_K, STAGES)
-  Tensor gQ = local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(blockIdx.x, _));
-  Tensor gO = local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(blockIdx.x, _));
+  Tensor gQ =
+      local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
+  Tensor gO =
+      local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
   // (BLK_N, BLK_K, n, STAGES)
   Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));
 
   // (BLK_M, ROPE_HEAD_DIM)
   Tensor gQ_rope = local_tile(
-      Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(blockIdx.x, _0{}));
+      Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
   // (BLK_N, ROPE_HEAD_DIM, n)
   Tensor gK_rope =
       local_tile(K_ROPE, Shape<_BLK_N, _ROPE_HEAD_DIM>{}, make_coord(_, _0{}));
@@ -95,7 +106,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
   extern __shared__ char smem[];
   DType* q_smem = (DType*)smem;
   DType* kv_smem = q_smem + cosize(SmemLayoutQ{});
-  DType* q_rope_smem = kv_smem + cosize(SmemLayoutKV{});
+  DType* p_smem = kv_smem + cosize(SmemLayoutKV{});
+  DType* q_rope_smem = p_smem + cosize(SmemLayoutP{});
   DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});
 
   // (BLK_M, BLK_K, STAGES), k-major
@@ -103,6 +115,9 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
   // (BLK_N, BLK_K, STAGES), k-major
   Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});
 
+  // (BLK_M, BLK_N), k-major
+  Tensor sP = make_tensor(make_smem_ptr(p_smem), SmemLayoutP{});
+
   // (BLK_M, ROPE_HEAD_DIM), k-major
   Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
   // (BLK_N, ROPE_HEAD_DIM), k-major
@@ -112,12 +127,15 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
 
   // (BLK_K, BLK_N, STAGES)
   Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{});
 
+  // (BLK_M, BLK_K, STAGES), reuse smem
+  Tensor sO = make_tensor(make_smem_ptr(q_smem), SmemLayoutO{});
+
   // Tiled Copy
   // g2s tiled copy for qkv
   GmemTiledCopyQ gmem_tiled_copy_Q;
   GmemTiledCopyKV gmem_tiled_copy_KV;
-  auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(threadIdx.x);
-  auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(threadIdx.x);
+  auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
+  auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
 
   auto produce_q = [&](int stage) {
     // gQ/sQ: (BLK_M, BLK_K, STAGES)
@@ -151,137 +169,182 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
     cp_async_fence();
   };
 
-  TiledMma tiled_mma;
-  auto thr_mma = tiled_mma.get_slice(threadIdx.x);
+  TiledMma_QK tiled_mma_qk;
+  auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
   // GEMM-I: S = Q@K.T
-  // gQ/sQ: (BLK_M, BLK_K, STAGES)
-  auto tSrQ = thr_mma.partition_fragment_A(sQ(_, _, _0{}));
-  auto tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));
+  // sQ/sK: (BLK_M, BLK_K, STAGES)
+  auto tSrQ = partition_fragment_A(
+      thr_mma_qk, sQ(_, _, _0{}), _, _2{});  // (MMA, MMA_M, _2)
+  auto tSrK = partition_fragment_B(
+      thr_mma_qk, sK(_, _, _0{}), _, _2{});  // (MMA, MMA_N, _2)
 
-  auto tSrQ_rope = thr_mma.partition_fragment_A(sQ_rope);
-  auto tSrK_rope = thr_mma.partition_fragment_B(sK_rope);
+  auto tSrQ_rope =
+      partition_fragment_A(thr_mma_qk, sQ_rope, _, _2{});  // (MMA, MMA_M, _2)
+  auto tSrK_rope =
+      partition_fragment_B(thr_mma_qk, sK_rope, _, _2{});  // (MMA, MMA_N, _2)
 
   // s2r tiled copy for qkv
   SmemTiledCopyQ smem_tiled_copy_Q;
-  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(threadIdx.x);
+  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
   // (CPY, CPY_M, CPY_K, STAGES)
   auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
-  // (CPY, CPY_M, CPY_K)
+  // (CPY, CPY_M, _2)
   auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ);
 
+  // (CPY, CPY_M, CPY_K)
   auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
+  // (CPY, CPY_M, _2)
   auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
 
   SmemTiledCopyK smem_tiled_copy_K;
-  auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(threadIdx.x);
+  auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
   // (CPY, CPY_N, CPY_K, STAGES)
   auto tCsK = smem_thr_copy_K.partition_S(sK);
-  // (CPY, CPY_M, CPY_K)
+  // (CPY, CPY_N, _2)
   auto tCrK = smem_thr_copy_K.retile_D(tSrK);
 
+  // (CPY, CPY_N, CPY_K)
   auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
+  // (CPY, CPY_N, _2)
   auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
 
   // S = Q@K.T
   // tSrS: (MMA,MMA_M,MMA_N)
-  auto compute_qk = [&](auto& tSrS, int stage) {
+  auto compute_qk = [&](auto& tSrS, int s) {
     // (CPY, CPY_M, CPY_K, STAGES)
-    auto tCsQ_ = tCsQ(_, _, _, stage);
-    auto tCsK_ = tCsK(_, _, _, stage);
+    auto tCsQ_s = tCsQ(_, _, _, s);
+    auto tCsK_s = tCsK(_, _, _, s);
     // prefetch kv
-    cute::copy(smem_tiled_copy_Q, tCsQ_(_, _, _0{}), tCrQ(_, _, _0{}));
-    cute::copy(smem_tiled_copy_K, tCsK_(_, _, _0{}), tCrK(_, _, _0{}));
+    cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{}));
+    cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{}));
 
     CUTE_UNROLL
-    for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
+    for (int k = 0; k < size<2>(tCsQ_s); ++k) {
       // prefetch next kv
-      if (ki != size<2>(tSrQ) - 1) {
-        const auto next_ki = ki + 1;
+      if (k != size<2>(tCsQ_s) - 1) {
+        const auto next_k = k + 1;
         cute::copy(
-            smem_tiled_copy_Q, tCsQ_(_, _, next_ki), tCrQ(_, _, next_ki));
+            smem_tiled_copy_Q, tCsQ_s(_, _, next_k), tCrQ(_, _, (next_k & 1)));
         cute::copy(
-            smem_tiled_copy_K, tCsK_(_, _, next_ki), tCrK(_, _, next_ki));
+            smem_tiled_copy_K, tCsK_s(_, _, next_k), tCrK(_, _, (next_k & 1)));
       }
-      cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS);
+      cute::gemm(tiled_mma_qk, tSrQ(_, _, (k & 1)), tSrK(_, _, (k & 1)), tSrS);
     }
   };
 
   auto compute_qk_rope = [&](auto& tSrS) {
-    // prefetch qk_rope
+    // tCsQ_rope: (CPY, CPY_M, CPY_K) => tCrQ_rope: (CPY, CPY_M, _2)
     cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{}));
     cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{}));
 
     CUTE_UNROLL
-    for (int ki = 0; ki < size<2>(tSrQ_rope); ++ki) {
-      // prefetch next qk_rope
-      if (ki != size<2>(tSrQ_rope) - 1) {
-        const auto next_ki = ki + 1;
+    for (int k = 0; k < size<2>(tCsQ_rope); ++k) {
+      if (k != size<2>(tCsQ_rope) - 1) {
+        const auto next_k = k + 1;
         cute::copy(smem_tiled_copy_Q,
-                   tCsQ_rope(_, _, next_ki),
-                   tCrQ_rope(_, _, next_ki));
+                   tCsQ_rope(_, _, next_k),
+                   tCrQ_rope(_, _, (next_k & 1)));
         cute::copy(smem_tiled_copy_K,
-                   tCsK_rope(_, _, next_ki),
-                   tCrK_rope(_, _, next_ki));
+                   tCsK_rope(_, _, next_k),
+                   tCrK_rope(_, _, (next_k & 1)));
       }
-      cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrS);
+      cute::gemm(tiled_mma_qk,
+                 tSrQ_rope(_, _, (k & 1)),
+                 tSrK_rope(_, _, (k & 1)),
+                 tSrS);
     }
   };
 
   // GEMM-II: O = softmax(S)@V
-  // (MMA, MMA_M, MMA_N)
-  auto tOrVt = thr_mma.partition_fragment_B(sVt(_, _, _0{}));
+  TiledMma_PV tiled_mma_pv;
+  auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
+  // sP: (BLK_M, BLK_N)
+  // (MMA, MMA_M, _2)
+  auto tOrP = partition_fragment_A(thr_mma_pv, sP, _, _2{});
+  // sVt: (BLK_K, BLK_N, STAGES)
+  // (MMA, MMA_N, _2)
+  auto tOrVt = partition_fragment_B(thr_mma_pv, sVt(_, _, _0{}), _, _2{});
+
+  SmemTiledCopyP smem_tiled_copy_P;
+  auto smem_thr_copy_P = smem_tiled_copy_P.get_slice(tidx);
+  // (CPY, CPY_M, CPY_K)
+  auto tCsP = smem_thr_copy_P.partition_S(sP);
+  // (CPY, CPY_M, _2)
+  auto tCrP = smem_thr_copy_P.retile_D(tOrP);
 
   SmemTiledCopyVt smem_tiled_copy_Vt;
-  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(threadIdx.x);
+  auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
   // (CPY, CPY_N, CPY_K, STAGES)
   auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
-  // (CPY, CPY_M, CPY_K)
+  // (CPY, CPY_N, _2)
   auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);
 
   // O = softmax(S)*V
-  // tOrS: (MMA,MMA_M,MMA_K)
-  // tOrO: (MMA,MMA_M,MMA_N, STAGES)
-  auto compute_sv = [&](const auto& tOrS, auto& tOrO, int stage) {
+  auto compute_pv = [&](auto& tOrO, int s) {
+    // (MMA,MMA_M,MMA_N, STAGES)
+    auto tOrO_s = tOrO(_, _, _, s);
     // (CPY, CPY_N, CPY_K, STAGES)
-    auto tCsVt_ = tCsVt(_, _, _, stage);
-    auto tOrO_ = tOrO(_, _, _, stage);
-    // prefetch V^t
-    cute::copy(smem_tiled_copy_Vt, tCsVt_(_, _, _0{}), tCrVt(_, _, _0{}));
+    auto tCsVt_s = tCsVt(_, _, _, s);
+    // tCsVt_s: (CPY, CPY_N, CPY_K) => tCrVt: (CPY, CPY_N, _2)
+    cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
+    cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));
+
     CUTE_UNROLL
-    for (int ki = 0; ki < size<2>(tOrS); ++ki) {
-      // prefetch next V^t
-      if (ki != size<2>(tOrS) - 1) {
-        const auto next_ki = ki + 1;
+    for (int k = 0; k < size<2>(tCsVt_s); ++k) {
+      if (k != size<2>(tCsVt_s) - 1) {
+        const auto next_k = k + 1;
         cute::copy(
-            smem_tiled_copy_Vt, tCsVt_(_, _, next_ki), tCrVt(_, _, next_ki));
+            smem_tiled_copy_P, tCsP(_, _, next_k), tCrP(_, _, (next_k & 1)));
+        cute::copy(smem_tiled_copy_Vt,
+                   tCsVt_s(_, _, next_k),
+                   tCrVt(_, _, (next_k & 1)));
       }
-      cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO_);
+      cute::gemm(
+          tiled_mma_pv, tOrP(_, _, (k & 1)), tOrVt(_, _, (k & 1)), tOrO_s);
    }
  };
 
+  SmemTiledCopyS smem_tiled_copy_S;
+  auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(tidx);
+
+  auto save_scores = [&](const auto& tSrS) {
+    // cast Accumulator to Element type
+    auto tSrS_ = make_tensor_like<DType>(tSrS);
+    fast_cast(tSrS, tSrS_);
+    // copy scores from rmem to smem
+    auto tCrS = smem_thr_copy_S.retile_S(tSrS_);
+    auto tCsS = smem_thr_copy_S.partition_D(sP);
+    cute::copy(smem_tiled_copy_S, tCrS, tCsS);
+  };
+
   // tOrO: (MMA,MMA_M,MMA_K,STAGES)
   auto epilogue = [&](const auto& tOrO) {
     // write output to gmem
-    // 1. cast output from ElementAccumulator to Element
-    auto tOrO_ = make_tensor_like<DType>(tOrO);
-    fast_cast(tOrO, tOrO_);
+    // 1. copy output from reg to smem (reuse sQ)
+    SmemTiledCopyO smem_tiled_copy_O;
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
+    CUTE_UNROLL
+    for (int s = 0; s < kStages; ++s) {
+      auto tOrO_s = tOrO(_, _, _, s);
+      auto sO_s = sO(_, _, s);
+
+      // cast Accumulator to Element type
+      auto tOrO_ = make_tensor_like<DType>(tOrO_s);
+      fast_cast(tOrO_s, tOrO_);
 
-    auto sO = make_tensor(sQ.data(), SmemLayoutO{});
-    // 2. copy output from reg to smem (reuse sQ)
-    {
-      SmemTiledCopyO smem_tiled_copy_O;
-      auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(threadIdx.x);
       auto tCrO = smem_thr_copy_O.retile_S(tOrO_);
-      auto tCsO = smem_thr_copy_O.partition_D(sO);
+      auto tCsO = smem_thr_copy_O.partition_D(sO_s);
       cute::copy(smem_tiled_copy_O, tCrO, tCsO);
     }
 
     // wait for smem copy done before gmem copy
     __syncthreads();
 
-    // 3. copy output from smem to gmem
+    // 2. copy output from smem to gmem
     {
       GmemTiledCopyO gmem_tiled_copy_O;
-      auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(threadIdx.x);
+      auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx);
       auto tCsO = gmem_thr_copy_O.partition_S(sO);
       auto tCgO = gmem_thr_copy_O.partition_D(gO);
@@ -290,7 +353,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
   };
 
   // output accumulator: (MMA, MMA_M, MMA_K, STAGES)
-  auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_K, _STAGES>{});
+  auto tOrO =
+      partition_fragment_C(thr_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
   auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
   clear(tOrO);
@@ -320,7 +384,7 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
     // attention score accumulator, (MMA,MMA_M,MMA_N)
-    auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
+    auto tSrS = partition_fragment_C(tiled_mma_qk, Shape<_BLK_M, _BLK_N>{});
     auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_mn(tSrS.layout()));
     clear(tSrS);
@@ -344,18 +408,24 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) {
 
     softmax.rescale(tSrS_mn, tOrO_mn);
 
+    // save tSrS from rmem to smem
+    save_scores(tSrS);
+    __syncthreads();
+
     // 3> O = softmax(S)*V
-    // cast scores from Accumulator to Element
-    auto tSrS_ = make_tensor_like<DType>(tSrS);
-    fast_cast(tSrS, tSrS_);
-    // convert layout from gemm-I C to gemm-II A
-    auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout()));
-    const auto next_ni = (ni + 1 < n_block_max) ? ni + 1 : ni;
-    produce_k_rope(next_ni);
-    CUTE_UNROLL
-    for (int s = 0; s < kStages; ++s) {
-      compute_sv(tOrS, tOrO, s);
-      produce_kv(next_ni, s);
+    const auto next_ni = ni + 1;
+    if (next_ni != n_block_max) {
+      produce_k_rope(next_ni);
+      CUTE_UNROLL
+      for (int s = 0; s < kStages; ++s) {
+        compute_pv(tOrO, s);
+        produce_kv(next_ni, s);
+      }
+    } else {
+      CUTE_UNROLL
+      for (int s = 0; s < kStages; ++s) {
+        compute_pv(tOrO, s);
+      }
     }
   }
diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu
index 37a75b71..f24683ac 100644
--- a/src/kernels/attention/mla_sm80_bench.cu
+++ b/src/kernels/attention/mla_sm80_bench.cu
@@ -14,28 +14,18 @@ using namespace llm;
   [&] {                                           \
     if (HEAD_DIM_V <= 64) {                       \
       constexpr static int HEAD_DIM_NAME = 64;    \
-      constexpr static int BLK_N = 64;            \
-      constexpr static int BLK_K = 64;            \
       return __VA_ARGS__();                       \
     } else if (HEAD_DIM_V <= 128) {               \
       constexpr static int HEAD_DIM_NAME = 128;   \
-      constexpr static int BLK_N = 64;            \
-      constexpr static int BLK_K = 128;           \
       return __VA_ARGS__();                       \
     } else if (HEAD_DIM_V <= 256) {               \
       constexpr static int HEAD_DIM_NAME = 256;   \
-      constexpr static int BLK_N = 64;            \
-      constexpr static int BLK_K = 128;           \
       return __VA_ARGS__();                       \
     } else if (HEAD_DIM_V <= 384) {               \
       constexpr static int HEAD_DIM_NAME = 384;   \
-      constexpr static int BLK_N = 64;            \
-      constexpr static int BLK_K = 128;           \
       return __VA_ARGS__();                       \
     } else if (HEAD_DIM_V <= 512) {               \
       constexpr static int HEAD_DIM_NAME = 512;   \
-      constexpr static int BLK_N = 32;            \
-      constexpr static int BLK_K = 128;           \
       return __VA_ARGS__();                       \
     } else {                                      \
       assert(false);                              \
     }                                             \
   }()
@@ -97,8 +87,8 @@ void mla_bench_sm80(nvbench::state& state) {
                                    HEAD_DIM,
                                    /*ROPE_HEAD_DIM=*/64,
                                    /*BLK_M=*/64,
-                                   BLK_N,
-                                   BLK_K>;
+                                   /*BLK_N=*/64,
+                                   /*BLK_K=*/64>;
 
       launch_mla_kernel_sm80<Traits>(params, launch.get_stream());
     });
diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h
index 92c44895..ee6c325a 100644
--- a/src/kernels/attention/mla_traits_sm80.h
+++ b/src/kernels/attention/mla_traits_sm80.h
@@ -74,9 +74,16 @@ struct MLATraitsSM80 {
       std::conditional_t<std::is_same_v<DType, cute::half_t>,
                          MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
                          MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
-  using TiledMma = TiledMMA<MMA_Atom_,
-                            Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
-                            Tile<_64, _16, _16>>;       // Prom Shape 64x16x16
+  using TiledMma_QK = TiledMMA<MMA_Atom_,
+                               Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
+                               Tile<_64, _16, _16>>;  // Prom Shape 64x16x16
+
+  using TiledMma_PV = TiledMMA<MMA_Atom_,
+                               Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
+                               Tile<_64, _16, _16>>;  // Prom Shape 64x16x16
+
+  // use 128-bit vectorizing copy
+  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
 
   // Layout convertor for TiledMMA (64x16x16)
   using LayoutConvertor = detail::LayoutConvertor;
@@ -96,6 +103,10 @@ struct MLATraitsSM80 {
       decltype(tile_to_shape(SmemLayoutAtom{},
                              Shape<_BLK_N, _BLK_K, _STAGES>{}));
 
+  // P smem: (BLK_M, BLK_N)
+  using SmemLayoutP =
+      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_N>{}));
+
   // V^T smem: (BLK_K, BLK_N, STAGES)
   using SmemLayoutVt = decltype(permute<1, 0, 2>(SmemLayoutKV{}));
 
@@ -126,18 +137,31 @@ struct MLATraitsSM80 {
   // g2s tiled copy for kv
   using GmemTiledCopyKV = GmemTiledCopyQ;
 
-  // s2r tiled copy for gemm-I
+  // s2r tiled copy for gemm-I S = Q*K^T
   using SmemTiledCopyQ =
       decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
-                                 TiledMma{}));
+                                 TiledMma_QK{}));
   using SmemTiledCopyK =
       decltype(make_tiled_copy_B(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
-                                 TiledMma{}));
+                                 TiledMma_QK{}));
 
-  // s2r tiled copy for gemm-II
+  // r2s tiled copy for gemm-I S
+  using SmemTiledCopyS =
+      decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
+                                 TiledMma_QK{}));
+
+  // s2r tiled copy for gemm-II: O = P*V^T
+  using SmemTiledCopyP =
+      decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
+                                 TiledMma_PV{}));
   using SmemTiledCopyVt =
       decltype(make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, DType>{},
-                                 TiledMma{}));
+                                 TiledMma_PV{}));
+
+  // r2s tiled copy for gemm-II O
+  using SmemTiledCopyO =
+      decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
+                                 TiledMma_PV{}));
 
   // ******* Epilogue *******
 
@@ -145,9 +169,6 @@ struct MLATraitsSM80 {
   using SmemLayoutO =
       decltype(tile_to_shape(SmemLayoutAtom{},
                              Shape<_BLK_M, _BLK_K, _STAGES>{}));
 
-  // use 128-bit vectorizing copy
-  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
-
   // s2g tiled copy for O
   using GmemTiledCopyO = decltype(make_tiled_copy(
       Copy_Atom<VectorizingCopy, DType>{},
       Layout<Shape<_16, _8>, Stride<_8, _1>>{},  // Thr layout: 16x8 k-major
       Layout<Shape<_1, _8>>{}                    // Val layout: 8 vals per read
       ));
@@ -155,17 +176,14 @@ struct MLATraitsSM80 {
 
-  // r2s tiled copy for O
-  using SmemTiledCopyO =
-      decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
-                                 TiledMma{}));
-
   // constexpr values for kernel launch
   static constexpr size_t kSmemSize =
-      sizeof(DType) * (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) +
-                       cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{}));
+      sizeof(DType) *
+      (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) + cosize(SmemLayoutP{}) +
+       cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{}));
 
-  static constexpr size_t kThreadNum = size(TiledMma{});
+  static constexpr size_t kThreadNum =
+      std::max(size(TiledMma_QK{}), size(TiledMma_PV{}));
 };
 
 }  // namespace llm
\ No newline at end of file
diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp
index 94805893..c094d6be 100644
--- a/src/kernels/attention/mla_traits_test.cpp
+++ b/src/kernels/attention/mla_traits_test.cpp
@@ -13,7 +13,7 @@ using namespace cute;
 template <typename Traits>
 void test_mla_traits() {
   // type alias
-  using TiledMma = typename Traits::TiledMma;
+  using TiledMma_QK = typename Traits::TiledMma_QK;
   using Layout = typename Traits::LayoutConvertor;
 
   using SmemLayoutQ = typename Traits::SmemLayoutQ;
@@ -47,9 +47,9 @@ void test_mla_traits() {
   // print("sQ_rope:"); print(sQ_rope);print("\n");
   // print("sKV_rope:"); print(sKV_rope);print("\n");
 
-  TiledMma tiled_mma;
-  auto thr_mma = tiled_mma.get_slice(0);
-  auto tOrVt = thr_mma.partition_fragment_B(sVt);
+  TiledMma_QK tiled_mma_qk;
+  auto thr_mma_qk = tiled_mma_qk.get_slice(0);
+  // auto tOrVt = thr_mma_qk.partition_fragment_B(sVt);
 
   // TODO: add tests for layout conformance
 }
@@ -58,8 +58,8 @@ TEST(MLATraitsTest, TraitsSM80) {
   test_mla_traits<MLATraitsSM80<cute::half_t,
                                 /*HEAD_DIM=*/256,
                                 /*ROPE_HEAD_DIM=*/64,
                                 /*BLK_M=*/64,
-                                /*BLK_N=*/32,
-                                /*BLK_K=*/256>>();
+                                /*BLK_N=*/64,
+                                /*BLK_K=*/64>>();
 }
 
 }  // namespace llm
\ No newline at end of file
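
Reviewer note on the register pipelining this patch introduces: partition_fragment_A/B(thr_mma, tensor, _, _2{}) shrinks a fragment's k-mode to two MMA blocks, and compute_qk/compute_qk_rope/compute_pv then ping-pong between the two slots, filling slot (next_k & 1) from shared memory while the MMA consumes slot (k & 1). This keeps the same one-block prefetch distance as a full (MMA, MMA_M, MMA_K) fragment while holding only two k-blocks of Q/K/P/V in registers. A minimal host-side sketch of just that indexing discipline; load_block/mma_block are hypothetical stand-ins for the cute::copy and cute::gemm calls, and on the GPU the actual overlap comes from the instruction scheduler, not the loop shape:

#include <cstdio>

constexpr int kBlocksK = 8;  // k-blocks per tile (hypothetical)

// stand-ins for the smem -> rmem copy and the per-block MMA
void load_block(int k, float& slot) { slot = float(k); }
void mma_block(float slot, float& acc) { acc += slot; }

int main() {
  float reg[2];     // two-slot fragment, like tCrQ shaped (CPY, CPY_M, _2)
  float acc = 0.f;  // accumulator, like tSrS

  load_block(0, reg[0]);  // prefetch the first k-block into slot 0
  for (int k = 0; k < kBlocksK; ++k) {
    if (k != kBlocksK - 1) {
      // prefetch the next k-block into the idle slot...
      load_block(k + 1, reg[(k + 1) & 1]);
    }
    // ...while the current slot feeds the MMA
    mma_block(reg[k & 1], acc);
  }
  printf("acc = %g\n", acc);  // 0+1+...+7 = 28
  return 0;
}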
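On the traits side, the new SmemLayoutP staging buffer is charged to kSmemSize, while sO costs nothing because it aliases the sQ buffer. A back-of-envelope check of the budget under one assumed configuration (fp16, BLK_M = BLK_N = BLK_K = 64, ROPE_HEAD_DIM = 64, HEAD_DIM = 512 so STAGES = 8, and treating cosize() as the plain element count, which these tile_to_shape layouts satisfy); only the summation formula comes from the diff:

#include <cstdio>

int main() {
  const long kBlkM = 64, kBlkN = 64, kBlkK = 64, kRope = 64, kStages = 8;
  const long kBytes = 2;  // sizeof(half)

  const long q  = kBlkM * kBlkK * kStages;  // SmemLayoutQ (reused by sO)
  const long kv = kBlkN * kBlkK * kStages;  // SmemLayoutKV (aliased by sVt)
  const long p  = kBlkM * kBlkN;            // SmemLayoutP (new in this patch)
  const long qr = kBlkM * kRope;            // SmemLayoutQRope
  const long kr = kBlkN * kRope;            // SmemLayoutKRope

  const long total = (q + kv + p + qr + kr) * kBytes;
  printf("kSmemSize ~= %ld bytes (%.1f KB)\n", total, total / 1024.0);
  return 0;
}

Under these assumptions the total is about 152 KB, inside the 163 KB of dynamic shared memory SM80 permits with opt-in, so staging P through smem adds only BLK_M * BLK_N extra elements.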