diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index a928ec1ec..4af31737e 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -551,7 +551,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                int window_size_right,
                const float softcap,
                const bool return_softmax,
-               c10::optional<at::Generator> gen_) {
+               c10::optional<at::Generator> gen_,
+               c10::optional<at::Tensor> &tree_end_position_id_k_,
+               c10::optional<at::Tensor> &tree_start_position_id_q_) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -591,6 +593,22 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     CHECK_CONTIGUOUS(cu_seqlens_q);
     CHECK_CONTIGUOUS(cu_seqlens_k);
 
+    TORCH_CHECK(tree_start_position_id_q_.has_value() == tree_end_position_id_k_.has_value(), "tree_start_position_id and tree_end_position_id must be passed together");
+    if (tree_end_position_id_k_.has_value()) {
+        const at::Tensor tree_end_position_id_k = tree_end_position_id_k_.value(), tree_start_position_id_q = tree_start_position_id_q_.value();
+        TORCH_CHECK(is_causal, "In tree attention, is_causal must be True");
+        TORCH_CHECK(window_size_left == -1 && window_size_right == -1, "In tree attention, is_local must be False");
+        TORCH_CHECK(!alibi_slopes_.has_value(), "tree attention does not support alibi");
+        TORCH_CHECK(tree_start_position_id_q.dtype() == torch::kInt32, "tree_start_position_id_q must have dtype int32");
+        TORCH_CHECK(tree_end_position_id_k.dtype() == torch::kInt32, "tree_end_position_id_k must have dtype int32");
+        TORCH_CHECK(tree_start_position_id_q.sizes().size() == 1, "tree_start_position_id_q must be 1D tensor");
+        TORCH_CHECK(tree_end_position_id_k.sizes().size() == 1, "tree_end_position_id_k must be 1D tensor");
+        TORCH_CHECK(tree_start_position_id_q.sizes()[0] == q.sizes()[0], "tree_start_position_id_q and q must have the same length");
+        TORCH_CHECK(tree_end_position_id_k.sizes()[0] == k.sizes()[0], "tree_end_position_id_k and k must have the same length");
+        CHECK_DEVICE(tree_start_position_id_q);
+        CHECK_DEVICE(tree_end_position_id_k);
+    }
+
     const auto sizes = q.sizes();
 
     const int batch_size = cu_seqlens_q.numel() - 1;
@@ -770,6 +788,16 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 
     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
+    if (tree_end_position_id_k_.has_value()) {
+        params.is_tree_attention = true;
+        params.tree_end_position_id_k = static_cast<int *>(tree_end_position_id_k_.value().data_ptr());
+        params.tree_start_position_id_q = static_cast<int *>(tree_start_position_id_q_.value().data_ptr());
+    } else {
+        params.is_tree_attention = false;
+        params.tree_end_position_id_k = nullptr;
+        params.tree_start_position_id_q = nullptr;
+    }
+
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream, paged_KV);
@@ -1062,7 +1090,9 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const float softcap,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
-               c10::optional<at::Tensor> &rng_state) {
+               c10::optional<at::Tensor> &rng_state,
+               c10::optional<at::Tensor> &tree_end_position_id_k_,
+               c10::optional<at::Tensor> &tree_start_position_id_q_) {
 
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1105,6 +1135,20 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     CHECK_CONTIGUOUS(cu_seqlens_q);
     CHECK_CONTIGUOUS(cu_seqlens_k);
 
+    TORCH_CHECK(tree_start_position_id_q_.has_value() == tree_end_position_id_k_.has_value(), "tree_start_position_id and tree_end_position_id must be passed together");
+    if (tree_end_position_id_k_.has_value()) {
+        const at::Tensor tree_end_position_id_k = tree_end_position_id_k_.value(), tree_start_position_id_q = tree_start_position_id_q_.value();
+        TORCH_CHECK(is_causal, "In tree attention, is_causal must be True");
+        TORCH_CHECK(tree_start_position_id_q.dtype() == torch::kInt32, "tree_start_position_id_q must have dtype int32");
+        TORCH_CHECK(tree_end_position_id_k.dtype() == torch::kInt32, "tree_end_position_id_k must have dtype int32");
+        TORCH_CHECK(tree_start_position_id_q.sizes().size() == 1, "tree_start_position_id_q must be 1D tensor");
+        TORCH_CHECK(tree_end_position_id_k.sizes().size() == 1, "tree_end_position_id_k must be 1D tensor");
+        TORCH_CHECK(tree_start_position_id_q.sizes()[0] == q.sizes()[0], "tree_start_position_id_q and q must have the same length");
+        TORCH_CHECK(tree_end_position_id_k.sizes()[0] == k.sizes()[0], "tree_end_position_id_k and k must have the same length");
+        CHECK_DEVICE(tree_start_position_id_q);
+        CHECK_DEVICE(tree_end_position_id_k);
+    }
+
     const auto sizes = q.sizes();
 
     const int total_q = sizes[0];
@@ -1270,6 +1314,16 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
+    if (tree_end_position_id_k_.has_value()) {
+        params.is_tree_attention = true;
+        params.tree_end_position_id_k = static_cast<int *>(tree_end_position_id_k_.value().data_ptr());
+        params.tree_start_position_id_q = static_cast<int *>(tree_start_position_id_q_.value().data_ptr());
+    } else {
+        params.is_tree_attention = false;
+        params.tree_end_position_id_k = nullptr;
+        params.tree_start_position_id_q = nullptr;
+    }
+
     if (max_seqlen_q > 0) {
         launch(params, stream);
     } else {
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 6f597fbee..257fe42d3 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -78,6 +78,10 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_k;
     int * __restrict__ leftpad_k;
 
+    // tree attention
+    int * __restrict__ tree_end_position_id_k;
+    int * __restrict__ tree_start_position_id_q;
+
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
 
@@ -129,6 +133,7 @@ struct Flash_fwd_params : public Qkv_params {
 
     bool is_bf16;
     bool is_causal;
+    bool is_tree_attention;
 
     // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
     // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
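Reviewer note on the two new `Flash_fwd_params` fields: combined with the causal mask, they encode the rule "query token `i` may attend to key token `j` only if `tree_start_position_id_q[i] <= tree_end_position_id_k[j]`", i.e. the complement of the `tree_start_position_id_q[row_idx] > tree_end_position_id_k[col_idx]` condition that the `mask.h` changes below use to write `-INFINITY`. A minimal PyTorch reference of that rule may help when reviewing the kernel hunks; the `build_tree_mask` helper and the example ids are illustrative only, not part of this PR:

```python
import torch

def build_tree_mask(start_q: torch.Tensor, end_k: torch.Tensor) -> torch.Tensor:
    """Reference boolean mask for one packed sequence (True = may attend).

    start_q: int32 [seqlen_q] -- tree_start_position_id_q for this sequence
    end_k:   int32 [seqlen_k] -- tree_end_position_id_k for this sequence
    """
    seqlen_q, seqlen_k = start_q.numel(), end_k.numel()
    # Ordinary causal mask, bottom-right aligned as in FlashAttention.
    causal = torch.ones(seqlen_q, seqlen_k, dtype=torch.int32).tril(seqlen_k - seqlen_q).bool()
    # Tree rule: keep (i, j) only when start_q[i] <= end_k[j].
    tree = start_q[:, None] <= end_k[None, :]
    return causal & tree

# Shared 3-token prefix (segment 0) followed by two sibling branches
# (segments 1 and 2): both branches see the prefix, neither sees the other.
start_q = torch.tensor([0, 0, 0, 1, 1, 2, 2], dtype=torch.int32)
end_k   = torch.tensor([2, 2, 2, 1, 1, 2, 2], dtype=torch.int32)
print(build_tree_mask(start_q, end_k).int())
```

With the ids used by the new test (`[0, 1, 2]` for starts, `[2, 1, 2]` for ends), branch 2's queries are blocked from branch 1's keys by the tree rule, and the reverse direction is already blocked by causality.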
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 4f95bd34a..60293b519 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -503,6 +503,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in flash::apply_mask(scores, binfo.actual_seqlen_k, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); } + } else if (Is_tree_attention) { + flash::apply_mask_causal_tree_attention(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16, + params.tree_end_position_id_k + binfo.sum_s_k, + params.tree_start_position_id_q + binfo.sum_s_q); } else if (Is_causal) { // Putting this causal masking right after acc_s is *much* slower for some reason. // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short @@ -510,11 +518,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // But we still want to mask out elements beyond actual_seqlen_k. if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, - AtomLayoutMS * 16); + flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16); } } else if (Is_local) { if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right @@ -820,7 +828,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -830,7 +838,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. 
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 727d87e93..7b6093ffc 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_tree_attention) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_dq_dk_dv_seqk_parallel(params); + flash::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + TREE_ATTENTION_SWITCH(params.is_tree_attention, Is_tree_attention, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 788f3790e..342efda66 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -45,7 +45,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -282,7 +282,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::Softmax<2 * size<1>(acc_o)> softmax; const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope, binfo.sum_s_k, binfo.sum_s_q); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. @@ -323,7 +323,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, + params.tree_end_position_id_k, params.tree_start_position_id_q ); flash::cp_async_wait<0>(); @@ -337,8 +338,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) - : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 Tensor rP = flash::convert_type(acc_s); @@ -398,10 +399,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, + params.tree_end_position_id_k, params.tree_start_position_id_q ); - softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; @@ -1069,7 +1071,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1085,7 +1087,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. 
- flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 900cf4671..5301b08b5 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -26,10 +26,10 @@ template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, bool Is_tree_attention) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); + flash::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -68,25 +68,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + TREE_ATTENTION_SWITCH(params.is_tree_attention, Is_tree_attention, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 7ba435a37..ccb8325ba 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -80,6 +80,48 @@ __forceinline__ __device__ void apply_mask_causal(Tensor &tensor max_seqlen_q, warp_row_stride, -1, 0); } +template +__forceinline__ __device__ void apply_mask_causal_tree_attention(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int *tree_end_position_id_k, const int *tree_start_position_id_q) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + // apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + // max_seqlen_q, warp_row_stride, -1, 0); + const int window_size_left = -1; + const int window_size_right = 0; + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || row_idx < max_seqlen_q && tree_start_position_id_q[row_idx] > tree_end_position_id_k[col_idx]) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + template __forceinline__ __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, @@ -107,20 +149,24 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( } } -template +template struct Mask { const int max_seqlen_k, max_seqlen_q; const int 
window_size_left, window_size_right; + const int sum_s_k, sum_s_q; const float alibi_slope; __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, const int window_size_left, const int window_size_right, - const float alibi_slope=0.f) + const float alibi_slope=0.f, + const int sum_s_k = -1, const int sum_s_q = -1) : max_seqlen_k(max_seqlen_k) , max_seqlen_q(max_seqlen_q) , window_size_left(window_size_left) , window_size_right(window_size_right) + , sum_s_k(sum_s_k) + , sum_s_q(sum_s_q) , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { }; @@ -129,17 +175,19 @@ struct Mask { __forceinline__ __device__ void apply_mask(Tensor &tensor_, const int col_idx_offset_, const int row_idx_offset, - const int warp_row_stride) { + const int warp_row_stride, + const int *tree_end_position_id_k = nullptr, + const int *tree_start_position_id_q = nullptr) { static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); static_assert(Layout::rank == 3, "Only support 3D Tensor"); static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); - static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN || Is_tree_attention; // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } if constexpr (Need_masking) { // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); // Do we need both row and column indices, or just column incides? - static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask && !Is_tree_attention; const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; if constexpr (Col_idx_only) { @@ -162,6 +210,10 @@ struct Mask { } } } else { + if constexpr (Is_tree_attention) { + tree_end_position_id_k += sum_s_k; + tree_start_position_id_q += sum_s_q; + } #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -194,6 +246,11 @@ struct Mask { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } + if constexpr (Is_tree_attention) { + if (col_idx < max_seqlen_k && row_idx < max_seqlen_q && tree_start_position_id_q[row_idx] > tree_end_position_id_k[col_idx]) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { // Causal and Local already handles MN masking if (col_idx >= max_seqlen_k) { diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index a57702f6c..d9add0966 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -76,6 +76,16 @@ #define LOCAL_SWITCH BOOL_SWITCH #endif +#ifdef FLASHATTENTION_DISABLE_TREE_ATTENTION + #define TREE_ATTENTION_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define TREE_ATTENTION_SWITCH BOOL_SWITCH +#endif + #define FP16_SWITCH(COND, ...) 
\
  [&] { \
    if (COND) { \
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index ecb3515c0..a8710b6b2 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -85,6 +85,8 @@ def _flash_attn_varlen_forward(
     block_table=None,
     leftpad_k=None,
     seqused_k=None,
+    tree_end_position_id_k=None,
+    tree_start_position_id_q=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
@@ -109,6 +111,8 @@ def _flash_attn_varlen_forward(
         softcap,
         return_softmax,
         None,
+        tree_end_position_id_k,
+        tree_start_position_id_q
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
@@ -187,6 +191,8 @@ def _flash_attn_varlen_backward(
     alibi_slopes,
     deterministic,
     rng_state=None,
+    tree_end_position_id_k=None,
+    tree_start_position_id_q=None
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -220,6 +226,8 @@ def _flash_attn_varlen_backward(
         deterministic,
         None,
         rng_state,
+        tree_end_position_id_k,
+        tree_start_position_id_q
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
@@ -614,6 +622,8 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        tree_end_position_id_k,
+        tree_start_position_id_q
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -633,9 +643,12 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            tree_end_position_id_k=tree_end_position_id_k,
+            tree_start_position_id_q=tree_start_position_id_q
         )
         ctx.save_for_backward(
-            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state,
+            tree_end_position_id_k, tree_start_position_id_q
         )
         ctx.dropout_p = dropout_p
         ctx.max_seqlen_q = max_seqlen_q
@@ -650,7 +663,7 @@ def forward(
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, tree_end_position_id_k, tree_start_position_id_q = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout,
@@ -674,11 +687,13 @@ def backward(ctx, dout, *args):
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
+            tree_end_position_id_k=tree_end_position_id_k,
+            tree_start_position_id_q=tree_start_position_id_q
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -1065,6 +1080,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    tree_end_position_id_k=None,
+    tree_start_position_id_q=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1138,6 +1155,8 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        tree_end_position_id_k,
+        tree_start_position_id_q
     )
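Before the new test: an illustrative helper (hypothetical, not part of this PR) for building the per-token id tensors that `flash_attn_varlen_func` now accepts. It encodes the same prefix-plus-branches layout that the test below constructs by hand with `create_pos_id`:

```python
import torch

def make_tree_position_ids(segment_lens, segment_end_ids, device="cuda"):
    """Hypothetical helper: flat int32 start/end ids for one packed sequence.

    segment_lens:    token count of each segment in packing order,
                     e.g. [prefix_len, branch_a_len, branch_b_len].
    segment_end_ids: for each segment, the largest segment index whose
                     queries may still attend to it (prefix -> index of the
                     last branch, each branch -> its own index).
    """
    starts, ends = [], []
    for seg_idx, (n, end_id) in enumerate(zip(segment_lens, segment_end_ids)):
        starts.append(torch.full((n,), seg_idx, dtype=torch.int32, device=device))
        ends.append(torch.full((n,), end_id, dtype=torch.int32, device=device))
    return torch.cat(starts), torch.cat(ends)

# A 4-token prefix shared by two branches of 3 and 2 tokens; equivalent to
# create_pos_id(lens, [0, 1, 2]) / create_pos_id(lens, [2, 1, 2]) in the test.
tree_start_position_id_q, tree_end_position_id_k = make_tree_position_ids(
    [4, 3, 2], [2, 1, 2])
```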
diff --git a/tests/test_flash_attn_tree_attention_ck.py b/tests/test_flash_attn_tree_attention_ck.py
new file mode 100644
index 000000000..4e993b3ec
--- /dev/null
+++ b/tests/test_flash_attn_tree_attention_ck.py
@@ -0,0 +1,118 @@
+import pytest
+import torch
+import torch.nn.functional as F
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from test_flash_attn import attention_ref
+
+
+def create_mask(seqlen, lens):
+    mask = torch.zeros((len(lens), seqlen), dtype=torch.bool, device="cuda")
+    for i in range(len(lens)):
+        mask[i, : lens[i]] = 1
+    return mask
+
+
+def create_pos_id(lens, ids=[0, 1, 2]):
+    pos_ids = []
+    for i in range(len(lens)):
+        for l, id in zip(lens[i], ids):
+            pos_ids.append(torch.ones((l, ), dtype=torch.int, device="cuda") * id)
+    return torch.cat(pos_ids)
+
+
+batch_size = 3
+causal = True
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("head_dim", [32, 64, 96, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("heads", [32, 16])
+@pytest.mark.parametrize(
+    "random_range",
+    [
+        (20, 70),
+        (128, 217),
+        (10, 211),
+        (108, 256),
+        (256, 512),
+    ],
+)
+def test_flash_attn_tree_attention_output(heads, mha_type, head_dim, dtype, random_range):
+    if mha_type == "mha":
+        kv_heads = heads
+    elif mha_type == "mqa":
+        kv_heads = 1
+    else:  # gqa
+        kv_heads = 8
+
+    prompt_lens = torch.randint(*random_range, (batch_size, ), device="cuda").repeat_interleave(2)
+    half_prompts_lens = prompt_lens.clone()
+    half_prompts_lens[0::2].zero_()
+    tot_lens = torch.randint(*random_range, (batch_size * 2, ), device="cuda") + prompt_lens
+    max_seqlen = tot_lens.max().item()
+
+    mask_ref = create_mask(max_seqlen, tot_lens)
+    mask = mask_ref & ~create_mask(max_seqlen, half_prompts_lens)
+
+    q = torch.randn((batch_size * 2, max_seqlen, heads, head_dim), dtype=dtype, device="cuda")
+    k = torch.randn((batch_size * 2, max_seqlen, kv_heads, head_dim), dtype=dtype, device="cuda")
+    v = torch.randn((batch_size * 2, max_seqlen, kv_heads, head_dim), dtype=dtype, device="cuda")
+
+    for i in range(0, len(prompt_lens), 2):
+        q[i, :prompt_lens[i]] = q[i+1, :prompt_lens[i]]
+        k[i, :prompt_lens[i]] = k[i+1, :prompt_lens[i]]
+        v[i, :prompt_lens[i]] = v[i+1, :prompt_lens[i]]
+
+    lens = torch.column_stack([prompt_lens[::2], tot_lens[::2] - prompt_lens[::2], tot_lens[1::2] - prompt_lens[::2]])
+    seqlens = lens.sum(1)
+    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).int()
+    max_seqlen_varlen = seqlens.max().item()
+
+    start_pos_id = create_pos_id(lens, [0, 1, 2])
+    end_pos_id = create_pos_id(lens, [2, 1, 2])
+
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+
+    output, _ = attention_ref(q, k, v, mask_ref, mask_ref, causal=True)
+    grad_output = torch.rand_like(output)
+    grad_output.masked_fill_(~mask[..., None, None], 0)
+    output.backward(grad_output)
+
+    q_f = q.detach().clone().requires_grad_(True)
+    k_f = k.detach().clone().requires_grad_(True)
+    v_f = v.detach().clone().requires_grad_(True)
+    output_f = flash_attn_func(q_f, k_f, v_f, causal=True)
+    output_f.backward(grad_output)
+
+    q_pack = q.detach().clone()[mask].requires_grad_(True)
+    k_pack = k.detach().clone()[mask].requires_grad_(True)
+    v_pack = v.detach().clone()[mask].requires_grad_(True)
+    output_pack = flash_attn_varlen_func(q_pack, k_pack, v_pack,
+                                         cu_seqlens, cu_seqlens,
+                                         max_seqlen_varlen, max_seqlen_varlen,
+                                         causal=True,
+                                         tree_end_position_id_k=end_pos_id,
+                                         tree_start_position_id_q=start_pos_id)
+    output_pack.backward(grad_output[mask])
+
+    def merge(x):
+        x = x.clone().detach()
+        for i in range(0, len(prompt_lens), 2):
+            x[i, :prompt_lens[i]] += x[i+1, :prompt_lens[i]]
+        return x
+
+    def ck(name, x, y, z):
+        x = x.detach().float()
+        y = y.detach().float()
+        z = z.detach().float()
+        print(f"{name} max diff: {(x - y).abs().max().item()}")
+        print(f"{name} mean diff: {(x - y).abs().mean().item()}")
+
+        print(f"{name} ref max diff: {(x - z).abs().max().item()}")
+        print(f"{name} ref mean diff: {(x - z).abs().mean().item()}")
+
+        assert (x - y).abs().max().item() <= 2 * (x - z).abs().max().item()
+
+    ck("Output", output[mask], output_pack, output_f[mask])
+    ck("dQ", merge(q.grad)[mask], q_pack.grad, merge(q_f.grad)[mask])
+    ck("dK", merge(k.grad)[mask], k_pack.grad, merge(k_f.grad)[mask])
+    ck("dV", merge(v.grad)[mask], v_pack.grad, merge(v_f.grad)[mask])
\ No newline at end of file
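For completeness, a minimal usage sketch of the extended API on this branch. The shapes and the tree layout are invented for illustration, and it assumes a CUDA device with this branch's extension built; it is not part of the PR itself.

```python
import torch
from flash_attn import flash_attn_varlen_func

device, dtype = "cuda", torch.float16
nheads, head_dim = 8, 64

# One packed sequence: a 4-token shared prefix followed by two candidate
# branches of 3 and 2 tokens (e.g. two speculative continuations verified
# in a single forward pass).
segment_lens = torch.tensor([4, 3, 2], device=device)
total = int(segment_lens.sum())
cu_seqlens = torch.tensor([0, total], dtype=torch.int32, device=device)

# Segment ids: prefix=0, branch A=1, branch B=2. A key stays visible to a
# query iff start_id(query) <= end_id(key): the prefix (end id 2) is seen by
# both branches, branch A's keys (end id 1) block branch B's queries, and
# causality blocks branch A's queries from branch B's keys.
start_ids = torch.repeat_interleave(
    torch.tensor([0, 1, 2], dtype=torch.int32, device=device), segment_lens)
end_ids = torch.repeat_interleave(
    torch.tensor([2, 1, 2], dtype=torch.int32, device=device), segment_lens)

q = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)

# causal=True is required by the new TORCH_CHECKs in mha_varlen_fwd.
out = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, total, total,
    causal=True,
    tree_end_position_id_k=end_ids,
    tree_start_position_id_q=start_ids,
)
print(out.shape)  # (total, nheads, head_dim)
```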