Commit
change template parameter for SeqLenTraits for ease of further extension
jayhshah committed Aug 19, 2024
1 parent 7aeccc8 commit 62f4fe9
Showing 5 changed files with 26 additions and 24 deletions.
8 changes: 4 additions & 4 deletions hopper/epilogue_fwd_sm90_tma.hpp
@@ -113,7 +113,7 @@ struct CollectiveEpilogueFwd {
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {
- if constexpr (!Seqlen_traits::kUseVarSeqLen) {
+ if constexpr (!Seqlen_traits::UseVarSeqLen) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
}
}
@@ -173,7 +173,7 @@
);
}
TiledCopyO gmem_tiled_copy_O;
- flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
+ flash::write_O<!Seqlen_traits::UseVarSeqLen, NumCopyThreads>(
epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
@@ -222,7 +222,7 @@
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
int const seqlen_q = [&] {
- if constexpr(Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
+ if constexpr(Seqlen_traits::UseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
else { return shape<2>(epilogue_params.layout_LSE); }
}();
if (get<1>(taccOcO_row(_0{})) == 0) {
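[Note] The seqlen_q initializer above uses an immediately invoked lambda so that a const local can take its value from an if constexpr branch, with the untaken branch never instantiated. A minimal, self-contained sketch of the idiom (illustrative names, not from this commit):

    #include <cstdio>

    template <bool UseVar>
    int pick_seqlen(int actual, int padded) {
        // The lambda is invoked immediately; only the taken branch is compiled in.
        int const seqlen = [&] {
            if constexpr (UseVar) { return actual; }  // var-seq-len: per-batch length
            else { return padded; }                   // fixed: layout-derived length
        }();
        return seqlen;
    }

    int main() {
        std::printf("%d %d\n", pick_seqlen<true>(100, 128), pick_seqlen<false>(100, 128));  // 100 128
    }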
@@ -240,7 +240,7 @@
}
TiledCopyO gmem_tiled_copy_O;
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
- flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
+ flash::write_O<!Seqlen_traits::UseVarSeqLen, NumCopyThreads>(
epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
10 changes: 5 additions & 5 deletions hopper/flash_fwd_kernel.h
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
- if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
+ if ((Is_causal || seqlen_traits_k.UseVarSeqLen) && n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
continue;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
- if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
+ if ((Is_causal || seqlen_traits_k.UseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
continue;
}
@@ -205,7 +205,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

static_assert(Ktraits::Is_WS);
static constexpr bool Is_WS = Ktraits::Is_WS;
- static constexpr bool kUseVarSeqLen = Seqlen_traits::kUseVarSeqLen;
+ static constexpr bool UseVarSeqLen = Seqlen_traits::UseVarSeqLen;

static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
@@ -293,7 +293,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;

- if constexpr(kUseVarSeqLen) {
+ if constexpr(UseVarSeqLen) {
seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
@@ -346,7 +346,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;

- if constexpr(kUseVarSeqLen) {
+ if constexpr(UseVarSeqLen) {
seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
12 changes: 6 additions & 6 deletions hopper/flash_fwd_launch_template.h
@@ -29,7 +29,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
using Scheduler = std::conditional_t<
- Seqlen_traits::kUseVarSeqLen,
+ Seqlen_traits::UseVarSeqLen,
flash::SingleTileScheduler,
std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler,
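[Note] The nested std::conditional_t above selects a tile scheduler at compile time (the causal branch is truncated in this view). A self-contained sketch of the same selection pattern, with placeholder scheduler types rather than the flash ones:

    #include <type_traits>

    struct SingleTile {};        // one tile per work unit
    struct StaticPersistent {};  // persistent CTAs, static schedule
    struct ThirdScheduler {};    // stands in for the branch truncated above

    template <bool VarSeqLen, bool IsCausal>
    using Scheduler = std::conditional_t<
        VarSeqLen,
        SingleTile,
        std::conditional_t<!IsCausal, StaticPersistent, ThirdScheduler>>;

    static_assert(std::is_same_v<Scheduler<true, true>, SingleTile>);
    static_assert(std::is_same_v<Scheduler<false, false>, StaticPersistent>);
    static_assert(std::is_same_v<Scheduler<false, true>, ThirdScheduler>);

    int main() {}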
@@ -128,7 +128,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
- BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+ BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Seqlen_traits
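[Note] SEQLEN_SWITCH is defined elsewhere in the repo (static_switch.h); the sketch below is a plausible reconstruction, assuming it dispatches on whether a cumulative-seqlen pointer was provided, and is not taken from this commit:

    // Hypothetical reconstruction of SEQLEN_SWITCH: picks a Seqlen_traits type
    // based on whether cu_seqlens_q was passed, then invokes the body with it.
    #define SEQLEN_SWITCH(CU_SEQLENS_Q, NAME, ...)      \
      [&] {                                             \
        if ((CU_SEQLENS_Q) != nullptr) {                \
          using NAME = flash::VarSeqLenTraits;          \
          return __VA_ARGS__();                         \
        } else {                                        \
          using NAME = flash::FixedSeqLenTraits;        \
          return __VA_ARGS__();                         \
        }                                               \
      }()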
@@ -144,7 +144,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
- BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+ BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Seqlen_traits
@@ -165,7 +165,7 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
- !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+ !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
@@ -184,7 +184,7 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
- !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+ !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
@@ -203,7 +203,7 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
- !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+ !Seqlen_traits::UseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
8 changes: 4 additions & 4 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -217,8 +217,8 @@ struct CollectiveMainloopFwd {
) {
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
- // int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
- // int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
+ // int const seqlen_q = Seqlen_traits::UseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
+ // int const seqlen_k = Seqlen_traits::UseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
int const seqlen_q = seqlen_traits_q.actual_seq_len;
int const seqlen_k = seqlen_traits_k.actual_seq_len;
int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
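[Note] The tile upper bound above is a plain ceiling division over the key sequence. A tiny worked example with illustrative numbers (not from the diff):

    #include <cstdio>

    // Mirrors cute::ceil_div for positive ints: number of kBlockN-wide tiles
    // needed to cover seqlen_k.
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        std::printf("%d\n", ceil_div(1000, 128));  // prints 8: 7*128 = 896 < 1000 <= 8*128
    }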
@@ -842,8 +842,8 @@ struct CollectiveMainloopFwd {

tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
// workaround for fp8 only perf regression pending change to seqlen traits class
- int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
- int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
+ int const seqlen_q = Seqlen_traits::UseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
+ int const seqlen_k = Seqlen_traits::UseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
int n_block = n_block_count - 1;

cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
12 changes: 7 additions & 5 deletions hopper/seq_len.h
@@ -11,8 +11,11 @@ namespace flash {

static constexpr int kMaxTileSize = 128;

- template <bool UseVarSeqLen> class SeqLenTraits {
+ template <int SeqLenType> class SeqLenTraits {
public:
+ static_assert(SeqLenType == 0 || SeqLenType == 1 || SeqLenType == 2,
+               "SeqLenType must be 0, 1, or 2");
+
// Total number of queries / keys. Unpadded.
int sum_s = 0;
// seq len offsets.
@@ -23,7 +26,7 @@ template <bool UseVarSeqLen> class SeqLenTraits {
int actual_seq_len = -1;

// Whether this is for fixed-seq-len or var-seq-len.
- static constexpr bool kUseVarSeqLen = UseVarSeqLen;
+ static constexpr bool UseVarSeqLen = SeqLenType == 1;

using ShapeT = std::conditional_t<
UseVarSeqLen,
@@ -103,9 +106,8 @@ template <bool UseVarSeqLen> class SeqLenTraits {
}
};

- using FixedSeqLenTraits = SeqLenTraits<false>;
-
- using VarSeqLenTraits = SeqLenTraits<true>;
+ using FixedSeqLenTraits = SeqLenTraits<0>;
+ using VarSeqLenTraits = SeqLenTraits<1>;

// Returns the static layout of a var-seq-len tensor in global memory based on
// max_seq_len and max_batch_size.
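[Note] The int parameter is what "ease of further extension" buys: the static_assert above already admits SeqLenType == 2, so a third sequence-length mode can be added without disturbing the two existing aliases or their call sites. A hypothetical sketch of such an extension (the third mode's name and meaning are assumptions, not part of this commit):

    // Hypothetical future extension of the traits class (not in this commit):
    template <int SeqLenType> struct SeqLenTraitsSketch {
        static_assert(SeqLenType == 0 || SeqLenType == 1 || SeqLenType == 2,
                      "SeqLenType must be 0, 1, or 2");
        static constexpr bool UseVarSeqLen = SeqLenType == 1;
        static constexpr bool UseThirdMode = SeqLenType == 2;  // e.g. a GQA/decode layout
    };

    using FixedSketch = SeqLenTraitsSketch<0>;
    using VarSketch   = SeqLenTraitsSketch<1>;
    using ThirdSketch = SeqLenTraitsSketch<2>;  // new call sites only; 0 and 1 unchanged

    int main() {}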
