diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 7dad5b9c7..afb86a98e 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -271,7 +271,13 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                         if (!params.is_e4m3) {
                             if (params.is_bf16) {
                                 #ifndef FLASHATTENTION_DISABLE_HDIM64
-                                if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
+                                if (params.d <= 64) {
+                                    if (params.dv > 64 && Arch == 90) {
+                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
+                                    } else {
+                                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
+                                    }
+                                }
                                 #endif
                                 #ifndef FLASHATTENTION_DISABLE_HDIM96
                                 if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
@@ -580,7 +586,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (head_size_v != head_size) {
-        TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]");
+        TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128)
+                    || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]");
         TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
         if (head_size_v > 256) {
             TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
@@ -757,7 +764,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     }
 
     if (q_v_.has_value()) {
-        TORCH_CHECK(false, "q_v should be None for now");
+        // TORCH_CHECK(false, "q_v should be None for now");
         TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
         TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
                     "q_v is only supported for fp16 and bf16 data type");
diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h
index b4f80a04e..26e28f8bd 100644
--- a/hopper/flash_fwd_launch_template.h
+++ b/hopper/flash_fwd_launch_template.h
@@ -198,7 +198,7 @@ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
     // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS)
     static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen;
     BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
-        static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false;
+        static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and true;
         APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
             // Only use Cluster if number of tiles along seqlen_q is even and not varlen
             CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py
index 7a5eb47d0..a80c08957 100644
--- a/hopper/generate_kernels.py
+++ b/hopper/generate_kernels.py
@@ -138,6 +138,8 @@ def get_all_kernels() -> List[Kernel]:
         yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
         if sm == 90 and head_dim == 192:
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
+        if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]:
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
 
     for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM):
         yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd")
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu
new file mode 100644
index 000000000..2f4ceaaed
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu
new file mode 100644
index 000000000..5fd59af34
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu
new file mode 100644
index 000000000..e0f885b0f
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu
new file mode 100644
index 000000000..6dcda0196
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu
new file mode 100644
index 000000000..5d20be6d2
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu
new file mode 100644
index 000000000..47463a715
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu
new file mode 100644
index 000000000..622b5533c
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu
new file mode 100644
index 000000000..c83f44722
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu
new file mode 100644
index 000000000..5c9130f86
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu
new file mode 100644
index 000000000..a152022cb
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu
new file mode 100644
index 000000000..ef05aa203
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu
new file mode 100644
index 000000000..19fe6d94f
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu
new file mode 100644
index 000000000..6eb2d3d13
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu
new file mode 100644
index 000000000..ffbc99821
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu
new file mode 100644
index 000000000..3d35075b4
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu
new file mode 100644
index 000000000..c2af33cf5
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu
new file mode 100644
index 000000000..e07547c92
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu
new file mode 100644
index 000000000..1a04eb01f
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu
new file mode 100644
index 000000000..da9afc115
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu
new file mode 100644
index 000000000..5e63a1551
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu
index e8ed21cda..8b8ab6552 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_packgqa_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu"
 #include "flash_fwd_hdim96_bf16_packgqa_sm90.cu"
 #include "flash_fwd_hdim128_bf16_packgqa_sm90.cu"
 #include "flash_fwd_hdim192_bf16_packgqa_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu
index f7de8fa20..b93a766c0 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_paged_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu"
 #include "flash_fwd_hdim96_bf16_paged_sm90.cu"
 #include "flash_fwd_hdim128_bf16_paged_sm90.cu"
 #include "flash_fwd_hdim192_bf16_paged_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
index 64e5ce4a3..3cc545947 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_paged_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu
index 44619cce5..cdad1a517 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_paged_split_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu"
 #include "flash_fwd_hdim96_bf16_paged_split_sm90.cu"
 #include "flash_fwd_hdim128_bf16_paged_split_sm90.cu"
 #include "flash_fwd_hdim192_bf16_paged_split_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
index a05973582..deff0fa0b 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu
index daea288fe..7a3eae6b0 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_sm90.cu"
 #include "flash_fwd_hdim96_bf16_sm90.cu"
 #include "flash_fwd_hdim128_bf16_sm90.cu"
 #include "flash_fwd_hdim192_bf16_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
index 62640192c..a94bf60d0 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu
index 79b0d52fa..d8db86752 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu"
 #include "flash_fwd_hdim96_bf16_softcap_sm90.cu"
 #include "flash_fwd_hdim128_bf16_softcap_sm90.cu"
 #include "flash_fwd_hdim192_bf16_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu
index 333406cb4..657679ec2 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_split_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_split_sm90.cu"
 #include "flash_fwd_hdim96_bf16_split_sm90.cu"
 #include "flash_fwd_hdim128_bf16_split_sm90.cu"
 #include "flash_fwd_hdim192_bf16_split_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu
index b6c1fb54c..1e162b661 100644
--- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_bf16_split_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu
index 34d176348..0ef211157 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_packgqa_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu"
 #include "flash_fwd_hdim96_fp16_packgqa_sm90.cu"
 #include "flash_fwd_hdim128_fp16_packgqa_sm90.cu"
 #include "flash_fwd_hdim192_fp16_packgqa_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu
index 326a2ea90..c490b02a4 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_paged_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu"
 #include "flash_fwd_hdim96_fp16_paged_sm90.cu"
 #include "flash_fwd_hdim128_fp16_paged_sm90.cu"
 #include "flash_fwd_hdim192_fp16_paged_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
index a9e032a07..4ec61214a 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_paged_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu"
 #include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu
index d7cc300b8..38c498e17 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_paged_split_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu"
 #include "flash_fwd_hdim96_fp16_paged_split_sm90.cu"
 #include "flash_fwd_hdim128_fp16_paged_split_sm90.cu"
 #include "flash_fwd_hdim192_fp16_paged_split_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
index fa4de4e29..7b3fe0a50 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu"
 #include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu
index cb3455866..bf6cf4b57 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_sm90.cu"
 #include "flash_fwd_hdim96_fp16_sm90.cu"
 #include "flash_fwd_hdim128_fp16_sm90.cu"
 #include "flash_fwd_hdim192_fp16_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
index 5dbd70ec5..3a695964a 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu"
 #include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu
index 9a97b9604..71f55e133 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu"
 #include "flash_fwd_hdim96_fp16_softcap_sm90.cu"
 #include "flash_fwd_hdim128_fp16_softcap_sm90.cu"
 #include "flash_fwd_hdim192_fp16_softcap_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu
index 5aacbf026..79bda2b3c 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_split_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_split_sm90.cu"
 #include "flash_fwd_hdim96_fp16_split_sm90.cu"
 #include "flash_fwd_hdim128_fp16_split_sm90.cu"
 #include "flash_fwd_hdim192_fp16_split_sm90.cu"
diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu
index cfaabd990..ae0968eae 100644
--- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu
@@ -3,6 +3,7 @@
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_hdim64_fp16_split_softcap_sm90.cu"
+#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu"
 #include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu"
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index e9cd8c9d6..2a64acb01 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -50,8 +50,8 @@
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
-# @pytest.mark.parametrize("has_qv", [False, True])
-@pytest.mark.parametrize("has_qv", [False])
+@pytest.mark.parametrize("has_qv", [False, True])
+# @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("deterministic", [False])
 @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
@@ -115,7 +115,8 @@ def test_flash_attn_output(
     # nheads = 1
     nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
-    for dv in [128, d] if d > 128 and d <= 192 else [d]:
+    # for dv in [128, d] if d > 128 and d <= 192 else [d]:
+    for dv in [512]:
         q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
         if softcap > 0.0:
             # Ensure the values of qk are at least within softcap range.
@@ -276,8 +277,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
-# @pytest.mark.parametrize("has_qv", [False, True])
-@pytest.mark.parametrize("has_qv", [False])
+@pytest.mark.parametrize("has_qv", [False, True])
+# @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("deterministic", [False])
 @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
@@ -335,7 +336,8 @@ def test_flash_attn_varlen_output(
     # nheads = 1
     nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
-    for dv in [128, d] if d > 128 and d <= 192 else [d]:
+    # for dv in [128, d] if d > 128 and d <= 192 else [d]:
+    for dv in [512]:
         q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
         if softcap > 0.0:
             # Ensure the values of qk are at least within softcap range.
@@ -584,7 +586,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [128])
-@pytest.mark.parametrize("d", [192])
+@pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -641,8 +643,10 @@ def test_flash_attn_kvcache(
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
-    dv_vals = [128, d] if d > 128 and d <= 192 else [d]
-    has_qv_vals = [False]
+    # dv_vals = [128, d] if d > 128 and d <= 192 else [d]
+    dv_vals = [512]
+    # has_qv_vals = [False]
+    has_qv_vals = [False, True]
     for dv, has_qv in itertools.product(dv_vals, has_qv_vals):
         q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
         if has_qv:
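Context for the test changes above: the new kernels handle Q/K head dim <= 64 with V head dim up to 512 (dispatched to the 64/512 instantiations), plus an optional per-query qv tensor with the V head dim. Below is a minimal PyTorch reference sketch of that shape pattern, assuming qv simply adds an extra qv·v^T term to the attention logits (mirroring how the repo's attention_ref treats qv); the function name, softmax-scale choice, and causal-mask handling here are illustrative assumptions, not taken from this diff.

# Reference sketch only; names and scale placement are assumptions, not the kernel's definition.
import math
import torch

def attention_ref_sketch(q, k, v, qv=None, causal=False):
    # q, k: (batch, seqlen_q/k, nheads, d);  v and qv: head dim dv (e.g. d=64, dv=512).
    d_total = q.shape[-1] + (qv.shape[-1] if qv is not None else 0)
    scale = 1.0 / math.sqrt(d_total)  # assumed scale; the library may choose differently
    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
    if qv is not None:
        # Assumed qv semantics: an extra logit contribution of qv against V.
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * scale, v)
    if causal:
        t, s = scores.shape[-2], scores.shape[-1]
        # Bottom-right-aligned causal mask, as in FlashAttention's convention.
        mask = torch.ones(t, s, dtype=torch.bool, device=q.device).tril(diagonal=s - t)
        scores = scores.masked_fill(~mask, float("-inf"))
    attn = scores.softmax(dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)  # output head dim = dv

# Shapes matching the new hdim64 / hdimv512 instantiations:
q = torch.randn(2, 1, 16, 64)
k = torch.randn(2, 256, 16, 64)
v = torch.randn(2, 256, 16, 512)
qv = torch.randn(2, 1, 16, 512)
out = attention_ref_sketch(q, k, v, qv=qv, causal=True)  # (2, 1, 16, 512)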