
Commit

Enable hdim=64, hdim_v=512, HasQv=true
tridao committed Feb 8, 2025
1 parent 5378bc3 commit 6055113
Showing 44 changed files with 226 additions and 13 deletions.
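For context: this commit enables the Hopper-only (SM90) forward configuration in which Q and K have head dimension at most 64 while V has head dimension up to 512, together with an optional q_v tensor that shares V's head dimension. Below is a minimal usage sketch from the Python side; the wrapper name, import path, qv keyword, and return convention are assumptions for illustration and are not shown in this diff, while the shape, dtype, and architecture constraints are taken from the flash_api.cpp checks below.

# Hedged sketch: exercise the newly enabled hdim=64 / hdim_v=512 forward path on SM90.
# Assumed (not confirmed by this diff): flash_attn_interface.flash_attn_func exists,
# accepts a `qv` keyword, and is run on an H100-class (SM90) GPU.
import torch
from flash_attn_interface import flash_attn_func  # assumed import path

b, s_q, s_k, h = 2, 1, 4096, 16
d_qk, d_v = 64, 512            # Q/K headdim <= 64, V headdim <= 512
dtype = torch.bfloat16         # fp16/bf16 are required when head_size_v > 256

q  = torch.randn(b, s_q, h, d_qk, device="cuda", dtype=dtype)
k  = torch.randn(b, s_k, h, d_qk, device="cuda", dtype=dtype)
v  = torch.randn(b, s_k, h, d_v,  device="cuda", dtype=dtype)
qv = torch.randn(b, s_q, h, d_v,  device="cuda", dtype=dtype)  # optional extra tensor (HasQv path)

out = flash_attn_func(q, k, v, causal=True, qv=qv)  # output takes V's head dim: (b, s_q, h, d_v)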
13 changes: 10 additions & 3 deletions hopper/flash_api.cpp
@@ -271,7 +271,13 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 if (!params.is_e4m3) {
     if (params.is_bf16) {
         #ifndef FLASHATTENTION_DISABLE_HDIM64
-        if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
+        if (params.d <= 64) {
+            if (params.dv > 64 && Arch == 90) {
+                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
+            } else {
+                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
+            }
+        }
         #endif
         #ifndef FLASHATTENTION_DISABLE_HDIM96
         if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
@@ -580,7 +586,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
 TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 if (head_size_v != head_size) {
-    TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]");
+    TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128)
+                || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]");
     TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
     if (head_size_v > 256) {
         TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
@@ -757,7 +764,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
 }

 if (q_v_.has_value()) {
-    TORCH_CHECK(false, "q_v should be None for now");
+    // TORCH_CHECK(false, "q_v should be None for now");
     TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
     TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
         "q_v is only supported for fp16 and bf16 data type");
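For reference, the relaxed head-dimension check in the mha_fwd hunk above can be restated in plain Python. This is only a hedged restatement of the condition in the diff, not an API of the library; the helper name is made up.

# Restatement of the head-dim check from the flash_api.cpp hunk above (hypothetical helper).
def headdims_supported(d_qk: int, d_v: int) -> bool:
    if d_qk == d_v:
        return True  # equal Q/K and V head dims are handled by the usual <= max_headdim check
    # Unequal head dims are Hopper-only and must fall in one of two regimes:
    return (128 < d_qk <= 192 and 96 < d_v <= 128) or (d_qk <= 64 and d_v <= 512)

assert headdims_supported(192, 128)      # previously supported combination
assert headdims_supported(64, 512)       # newly enabled by this commit
assert not headdims_supported(128, 512)  # still rejected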
2 changes: 1 addition & 1 deletion hopper/flash_fwd_launch_template.h
@@ -198,7 +198,7 @@ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
 // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS)
 static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen;
 BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
-    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false;
+    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and true;
     APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
         // Only use Cluster if number of tiles along seqlen_q is even and not varlen
         CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
2 changes: 2 additions & 0 deletions hopper/generate_kernels.py
@@ -138,6 +138,8 @@ def get_all_kernels() -> List[Kernel]:
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
             if sm == 90 and head_dim == 192:
                 yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
+            if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]:
+                yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
     for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM):
         yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd")

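The two added yield statements above are what produce the hdim64_512 instantiation files listed in the rest of this commit, one translation unit per combination of dtype and the split/paged/softcap/packgqa flags. The sketch below reconstructs the filename convention from the files added here; the helper is hypothetical, and the real naming logic lives in the Kernel class in generate_kernels.py.

# Hypothetical helper mirroring the naming of the generated forward-kernel files.
def fwd_filename(dtype, head_dim, head_dim_v, paged=False, split=False,
                 softcap=False, packgqa=False, sm=90):
    dims = f"hdim{head_dim}" + (f"_{head_dim_v}" if head_dim_v != head_dim else "")
    feats = ("_paged" if paged else "") + ("_split" if split else "") \
          + ("_softcap" if softcap else "") + ("_packgqa" if packgqa else "")
    return f"flash_fwd_{dims}_{dtype}{feats}_sm{sm}.cu"

print(fwd_filename("bf16", 64, 512, paged=True, split=True))
# -> flash_fwd_hdim64_512_bf16_paged_split_sm90.cu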
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
9 changes: 9 additions & 0 deletions hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim96_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim128_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim192_bf16_packgqa_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_split_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_sm90.cu"
#include "flash_fwd_hdim96_bf16_sm90.cu"
#include "flash_fwd_hdim128_bf16_sm90.cu"
#include "flash_fwd_hdim192_bf16_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_split_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_split_sm90.cu"
#include "flash_fwd_hdim96_bf16_split_sm90.cu"
#include "flash_fwd_hdim128_bf16_split_sm90.cu"
#include "flash_fwd_hdim192_bf16_split_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim96_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim128_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim192_fp16_packgqa_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_split_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu"
1 change: 1 addition & 0 deletions hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu
@@ -3,6 +3,7 @@
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_sm90.cu"
#include "flash_fwd_hdim64_512_fp16_sm90.cu"
#include "flash_fwd_hdim96_fp16_sm90.cu"
#include "flash_fwd_hdim128_fp16_sm90.cu"
#include "flash_fwd_hdim192_fp16_sm90.cu"
(The remaining changed files in this commit are not shown here.)