-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
70 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,74 +152,66 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { | |
TiledMma tiled_mma; | ||
auto thr_mma = tiled_mma.get_slice(tidx); | ||
// GEMM-I: S = [email protected] | ||
auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) | ||
auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) | ||
auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) | ||
auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) | ||
|
||
auto tSrQ_rope = thr_mma.partition_fragment_A(sQ_rope); // (MMA,MMA_M,MMA_K) | ||
auto tSrK_rope = thr_mma.partition_fragment_B(sK_rope); // (MMA,MMA_N,MMA_K) | ||
|
||
// s2r tiled copy for qkv | ||
SmemTiledCopyQ smem_tiled_copy_Q; | ||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); | ||
auto tSsQ = smem_thr_copy_Q.partition_S(sQ); | ||
auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); | ||
auto tCsQ = smem_thr_copy_Q.partition_S(sQ); | ||
auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ); | ||
|
||
auto tSsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); | ||
auto tSrQ_rope_copy_view = smem_thr_copy_Q.retile_D(tSrQ_rope); | ||
auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); | ||
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope); | ||
|
||
SmemTiledCopyK smem_tiled_copy_K; | ||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); | ||
auto tSsK = smem_thr_copy_K.partition_S(sK); | ||
auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); | ||
auto tCsK = smem_thr_copy_K.partition_S(sK); | ||
auto tCrK = smem_thr_copy_K.retile_D(tSrK); | ||
|
||
auto tSsK_rope = smem_thr_copy_K.partition_S(sK_rope); | ||
auto tSrK_rope_copy_view = smem_thr_copy_K.retile_D(tSrK_rope); | ||
auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope); | ||
auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope); | ||
|
||
// S = [email protected] | ||
// tSrAccS: (MMA,MMA_M,MMA_N) | ||
auto compute_qk = [&](auto& tSrAccS) { | ||
// tSrS: (MMA,MMA_M,MMA_N) | ||
auto compute_qk = [&](auto& tSrS) { | ||
// prefetch kv | ||
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, _0{}), tSrQ_copy_view(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_Q, tCsQ(_, _, _0{}), tCrQ(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_K, tCsK(_, _, _0{}), tCrK(_, _, _0{})); | ||
|
||
CUTE_UNROLL | ||
for (int ki = 0; ki < size<2>(tSrQ); ++ki) { | ||
// prefetch next kv | ||
if (ki != size<2>(tSrQ) - 1) { | ||
const auto next_ki = ki + 1; | ||
cute::copy(smem_tiled_copy_Q, | ||
tSsQ(_, _, next_ki), | ||
tSrQ_copy_view(_, _, next_ki)); | ||
cute::copy(smem_tiled_copy_K, | ||
tSsK(_, _, next_ki), | ||
tSrK_copy_view(_, _, next_ki)); | ||
cute::copy(smem_tiled_copy_Q, tCsQ(_, _, next_ki), tCrQ(_, _, next_ki)); | ||
cute::copy(smem_tiled_copy_K, tCsK(_, _, next_ki), tCrK(_, _, next_ki)); | ||
} | ||
cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS); | ||
cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS); | ||
} | ||
}; | ||
|
||
auto compute_qk_rope = [&](auto& tSrAccS) { | ||
auto compute_qk_rope = [&](auto& tSrS) { | ||
// prefetch qk_rope | ||
cute::copy(smem_tiled_copy_Q, | ||
tSsQ_rope(_, _, _0{}), | ||
tSrQ_rope_copy_view(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_K, | ||
tSsK_rope(_, _, _0{}), | ||
tSrK_rope_copy_view(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{})); | ||
|
||
CUTE_UNROLL | ||
for (int ki = 0; ki < size<2>(tSrQ_rope); ++ki) { | ||
// prefetch next qk_rope | ||
if (ki != size<2>(tSrQ_rope) - 1) { | ||
const auto next_ki = ki + 1; | ||
cute::copy(smem_tiled_copy_Q, | ||
tSsQ_rope(_, _, next_ki), | ||
tSrQ_rope_copy_view(_, _, next_ki)); | ||
tCsQ_rope(_, _, next_ki), | ||
tCrQ_rope(_, _, next_ki)); | ||
cute::copy(smem_tiled_copy_K, | ||
tSsK_rope(_, _, next_ki), | ||
tSrK_rope_copy_view(_, _, next_ki)); | ||
tCsK_rope(_, _, next_ki), | ||
tCrK_rope(_, _, next_ki)); | ||
} | ||
cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrAccS); | ||
cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrS); | ||
} | ||
}; | ||
|
||
|
@@ -228,69 +220,69 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { | |
|
||
SmemTiledCopyVt smem_tiled_copy_Vt; | ||
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx); | ||
auto tOsVt = smem_thr_copy_Vt.partition_S(sVt); | ||
auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt); | ||
auto tCsVt = smem_thr_copy_Vt.partition_S(sVt); | ||
auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt); | ||
|
||
// O = softmax(S)*V | ||
// tSrAccS: (MMA,MMA_M,MMA_N) | ||
// tSrS: (MMA,MMA_M,MMA_N) | ||
// tOrAccO: (MMA,MMA_M,MMA_K) | ||
auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) { | ||
auto compute_sv = [&](const auto& tSrS, auto& tOrO) { | ||
// cast scores from Accumulator to Element | ||
auto tSrS = make_tensor_like<DType>(tSrAccS); | ||
fast_cast(tSrAccS, tSrS); | ||
auto tSrS_ = make_tensor_like<DType>(tSrS); | ||
fast_cast(tSrS, tSrS_); | ||
|
||
// convert layout from gemm-I C to gemm-II A | ||
auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout())); | ||
auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout())); | ||
|
||
// prefetch V^t | ||
cute::copy( | ||
smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{})); | ||
cute::copy(smem_tiled_copy_Vt, tCsVt(_, _, _0{}), tCrVt(_, _, _0{})); | ||
CUTE_UNROLL | ||
for (int ki = 0; ki < size<2>(tOrS); ++ki) { | ||
// prefetch next V^t | ||
if (ki != size<2>(tOrS) - 1) { | ||
const auto next_ki = ki + 1; | ||
cute::copy(smem_tiled_copy_Vt, | ||
tOsVt(_, _, next_ki), | ||
tOrVt_copy_view(_, _, next_ki)); | ||
cute::copy( | ||
smem_tiled_copy_Vt, tCsVt(_, _, next_ki), tCrVt(_, _, next_ki)); | ||
} | ||
cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO); | ||
cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO); | ||
} | ||
}; | ||
|
||
// tOrAccO: (MMA,MMA_M,MMA_K) | ||
auto epilogue = [&](const auto& tOrAccO) { | ||
// tOrO: (MMA,MMA_M,MMA_K) | ||
auto epilogue = [&](const auto& tOrO) { | ||
// write output to gmem | ||
// 1> cast output from ElementAccumulator to Element | ||
auto tOrO = make_tensor_like<DType>(tOrAccO); | ||
fast_cast(tOrAccO, tOrO); | ||
auto tOrO_ = make_tensor_like<DType>(tOrO); | ||
fast_cast(tOrO, tOrO_); | ||
|
||
// 2. copy output from reg to smem (reuse sQ) | ||
auto sO = make_tensor(sQ.data(), SmemLayoutO{}); | ||
|
||
SmemTiledCopyO smem_tiled_copy_O; | ||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); | ||
auto taccOrO = smem_thr_copy_O.retile_S(tOrO); | ||
auto taccOsO = smem_thr_copy_O.partition_D(sO); | ||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); | ||
// 2. copy output from reg to smem (reuse sQ) | ||
{ | ||
SmemTiledCopyO smem_tiled_copy_O; | ||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); | ||
auto tCrO = smem_thr_copy_O.retile_S(tOrO_); | ||
auto tCsO = smem_thr_copy_O.partition_D(sO); | ||
cute::copy(smem_tiled_copy_O, tCrO, tCsO); | ||
} | ||
|
||
// 3. copy output from smem to gmem | ||
GmemTiledCopyO gmem_tiled_copy_O; | ||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); | ||
{ | ||
GmemTiledCopyO gmem_tiled_copy_O; | ||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); | ||
|
||
auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) | ||
auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) | ||
auto tCsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) | ||
auto tCgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) | ||
|
||
// wait for smem copy done before gmem copy | ||
__syncthreads(); | ||
cute::copy(gmem_tiled_copy_O, tOsO, tOgO); | ||
// wait for smem copy done before gmem copy | ||
__syncthreads(); | ||
cute::copy(gmem_tiled_copy_O, tCsO, tCgO); | ||
} | ||
}; | ||
|
||
// output accumulator, (MMA,MMA_M,MMA_K) | ||
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); | ||
auto tOrAccO_rc_view = | ||
make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout())); | ||
clear(tOrAccO); | ||
auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); | ||
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_rowcol(tOrO.layout())); | ||
clear(tOrO); | ||
|
||
const int n_block_min = 0; | ||
const int n_block_max = cute::ceil_div(kv_len, kBlockN); | ||
|
@@ -311,23 +303,22 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { | |
CUTE_NO_UNROLL | ||
for (int ni = n_block_min; ni < n_block_max; ++ni) { | ||
// attention score accumulator, (MMA,MMA_M,MMA_N) | ||
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); | ||
auto tSrAccS_rc_view = | ||
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); | ||
clear(tSrAccS); | ||
auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); | ||
auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_rowcol(tSrS.layout())); | ||
clear(tSrS); | ||
|
||
// wait key, queue: [q, q_rope, kv, k_rope] => [] | ||
cp_async_wait<0>(); | ||
__syncthreads(); | ||
|
||
// 1> S = [email protected] | ||
compute_qk(tSrAccS); | ||
compute_qk(tSrS); | ||
|
||
// 2> S = [email protected] + Q_rope@K_rope.T | ||
compute_qk_rope(tSrAccS); | ||
// 2> S += Q_rope@K_rope.T | ||
compute_qk_rope(tSrS); | ||
|
||
// 3> O = softmax(S)*V | ||
compute_sv(tSrAccS, tOrAccO); | ||
compute_sv(tSrS, tOrO); | ||
|
||
// produce next key: [] => [kv, k_rope] | ||
if (ni != n_block_max - 1) { | ||
|
@@ -339,7 +330,7 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { | |
|
||
// ############### Epilogue ############### | ||
// write output to gmem | ||
epilogue(tOrAccO); | ||
epilogue(tOrO); | ||
} | ||
|
||
template <typename Traits, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters