From 59342bbfa1503194a3fe52f3a3dbfb8446f540a7 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 23 Feb 2025 17:23:05 +0800 Subject: [PATCH] [Release] Bumpy version to v0.1.1 (#107) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove Torch CPP backend and update execution backend options - Remove TorchCPPKernelAdapter and related code from JIT modules - Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py - Remove "torch_cpp" from supported execution backend literals - Simplify backend validation and remove unused torch_cpp-related code 。 * lint fix * Add block sparse attention implementations for TileLang and Triton - Implement block sparse attention kernels for TileLang and Triton - Add example scripts for block sparse attention with top-k and threshold-based masking - Include utility functions for generating sparse attention masks - Demonstrate causal attention with block-level sparsity - Add test cases to validate sparse attention implementations against PyTorch reference * Bump version to 0.1.1 * Refactor block sparse attention examples for improved code quality - Apply consistent code formatting and style in TileLang and Triton block sparse attention implementations - Add ruff linter ignore comment for specific line in Triton implementation - Improve readability by adjusting indentation and line breaks - Standardize sparse mask generation and test function implementations - Minor optimizations in test case configurations * lint --- MANIFEST.in | 1 + VERSION | 2 +- .../block_sparse_attn_tilelang.py | 36 ++- .../block_sparse_attn_triton.py | 259 ++++++++++-------- 4 files changed, 165 insertions(+), 133 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index ba31202..88b2068 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ include CMakeLists.txt include requirements.txt include requirements-test.txt include requirements-dev.txt +include tilelang/jit/adapter/cython/cython_wrapper.pyx recursive-include src * recursive-include 3rdparty * recursive-exclude 3rdparty/clang* * diff --git a/VERSION b/VERSION index 6c6aa7c..6da28dd 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.0 \ No newline at end of file +0.1.1 \ No newline at end of file diff --git a/examples/blocksparse_attention/block_sparse_attn_tilelang.py b/examples/blocksparse_attention/block_sparse_attn_tilelang.py index 0237eec..912ec7b 100644 --- a/examples/blocksparse_attention/block_sparse_attn_tilelang.py +++ b/examples/blocksparse_attention/block_sparse_attn_tilelang.py @@ -7,24 +7,28 @@ import tilelang.language as T import torch.nn.functional as F + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: - dense_mask[:, :,-2:,:] = True + dense_mask[:, :, -2:, :] = True dense_mask.tril_() - return dense_mask + return dense_mask def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): - dense_mask = x > threshold + dense_mask = x > threshold if use_dense_for_last_block: - dense_mask[:, :,-2:,:] = True + 
dense_mask[:, :, -2:, :] = True dense_mask.tril_() - return dense_mask + return dense_mask def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): @@ -136,7 +140,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype) - + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -165,6 +169,7 @@ def main( return kernel_func(block_M, block_N, num_stages, threads) + def test_topk_sparse_attention(): # Config BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 @@ -177,13 +182,15 @@ def test_topk_sparse_attention(): k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) - sm_scale = 1.0 / (D_HEAD ** 0.5) + sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16) - x_ds[:,:,:,0] = 100 + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) # Run Triton kernel @@ -194,25 +201,24 @@ def test_topk_sparse_attention(): # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), - torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal - + # PyTorch reference implementation attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale attn = attn.masked_fill(~full_mask, float('-inf')) attn = F.softmax(attn, dim=-1) ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) - + print("ref_output", ref_output) print("tilelang_output", tilelang_output) - # Verify accuracy assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \ "TileLang output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") + if __name__ == "__main__": test_topk_sparse_attention() diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index e459800..907d42d 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# ruff: noqa: E712 import math import torch @@ -7,6 +8,7 @@ import triton.language as tl import torch.nn.functional as F + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" @@ -15,33 +17,40 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK sparse_index = torch.topk(x, topk, dim=-1).indices - dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) dense_mask.scatter_(-1, sparse_index, True) if use_dense_for_last_block: - dense_mask[:, :,-2:,:] = True + dense_mask[:, :, -2:, :] = True dense_mask.tril_() - return dense_mask + return dense_mask def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): - dense_mask = x > threshold + dense_mask = x > threshold if use_dense_for_last_block: - dense_mask[:, :,-2:,:] = True + dense_mask[:, :, -2:, :] = True dense_mask.tril_() - return dense_mask - - + return dense_mask @triton.jit def _fwd_kernel_inner( - acc, l_i, m_i, + acc, + l_i, + m_i, q, k_block_col_idx, block_mask_ptr, - k_ptrs, v_ptrs, - offs_m, offs_n, - stride_kt, stride_vt, stride_bmask_n, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, sm_scale, seqlen_k, past_len, @@ -51,8 +60,8 @@ def _fwd_kernel_inner( ): mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) - # print - + # print + if k_block_col_idx == 3: print("mask_val", mask_val) if mask_val == True: @@ -67,9 +76,9 @@ def _fwd_kernel_inner( qk *= sm_scale # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - if LAST_K_BLOCK : - qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) - + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, + float('-inf')) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -78,7 +87,7 @@ def _fwd_kernel_inner( alpha = tl.exp(m_i - m_ij) l_i = l_i * alpha + l_ij acc = acc * alpha[:, None] - + # update acc v = tl.load(v_ptrs + start_n * stride_vt) @@ -90,21 +99,38 @@ def _fwd_kernel_inner( return acc, l_i, m_i - - @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, + Q, + K, + V, + sm_scale, block_mask_ptr, Out, - stride_qz, stride_qh, stride_qm, stride_qd, - stride_kz, stride_kh, stride_kn, stride_kd, - stride_vz, stride_vh, stride_vn, stride_vd, - stride_bmz, stride_bmh, stride_bmm, stride_bmn, - stride_oz, stride_oh, stride_om, stride_od, - H, N_CTX, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, PAST_LEN, - BLOCK_M: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): @@ -144,13 +170,19 @@ def _fwd_kernel( # loop over k, v and update accumulator for col_idx in range(k_block_start, k_block_end): acc, l_i, m_i = _fwd_kernel_inner( - acc, l_i, m_i, + acc, + l_i, + m_i, q, col_idx, mask_ptrs, - k_ptrs, v_ptrs, - offs_m, offs_n, - stride_kn, stride_vn, stride_bmn, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, sm_scale, N_CTX, PAST_LEN, @@ -162,27 +194,25 @@ def _fwd_kernel( m_i += tl.math.log(l_i) 
l_recip = 1 / l_i[:, None] acc = acc * l_recip - acc = acc.to(Out.dtype.element_ty) + acc = acc.to(Out.dtype.element_ty) - - off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) -def _forward( - ctx, - q, - k, - v, - block_sparse_mask, - sm_scale, - BLOCK_M=64, - BLOCK_N=64, - num_warps=None, - num_stages=1, - out=None - ): +def _forward(ctx, + q, + k, + v, + block_sparse_mask, + sm_scale, + BLOCK_M=64, + BLOCK_N=64, + num_warps=None, + num_stages=1, + out=None): assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert k.shape[2] == v.shape[2] @@ -200,19 +230,22 @@ def _forward( N_CTX = k.shape[2] PAST_LEN = N_CTX - q.shape[2] - H = q.shape[1] _fwd_kernel[grid]( - q, k, v, sm_scale, + q, + k, + v, + sm_scale, block_sparse_mask, o, - *q.stride(), - *k.stride(), - *v.stride(), - *block_sparse_mask.stride(), + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), *o.stride(), - H, N_CTX, + H, + N_CTX, PAST_LEN, BLOCK_M, BLOCK_N, @@ -224,8 +257,6 @@ def _forward( return o - - class _sparse_attention(torch.autograd.Function): @staticmethod @@ -239,8 +270,8 @@ def backward(ctx, do): raise NotImplementedError("It does not support gradient propagation yet") return None, None, None, None, None -block_sparse_triton_fn = _sparse_attention.apply +block_sparse_triton_fn = _sparse_attention.apply def test_topk_sparse_attention(): @@ -254,106 +285,100 @@ def test_topk_sparse_attention(): q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) - sm_scale = 1.0 / (D_HEAD ** 0.5) + sm_scale = 1.0 / (D_HEAD**0.5) # Create sparse mask (downsampled to block level) downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) print("downsample_len", downsample_len) - - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16) - x_ds[:,:,:,0] = 100 + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 print("x_ds.shape", x_ds.shape) - block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=downsample_len) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) # print("block_mask", block_mask) print("block_mask.shape", block_mask.shape) # Run Triton kernel - triton_output = block_sparse_triton_fn( - q, k, v, - block_mask, - sm_scale - ) + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) # Compute reference # Expand block mask to full attention matrix - full_mask = torch.kron(block_mask.float(), - torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal - + # PyTorch reference implementation attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale attn = attn.masked_fill(~full_mask, float('-inf')) attn = F.softmax(attn, dim=-1) ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) - + # print("ref_output", ref_output) # print("triton_output", triton_output) - # Verify accuracy assert 
torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ "Triton output doesn't match reference" print("Pass topk sparse attention test with qlen == klen") +def test_topk_sparse_attention_qlt_kl(): + BATCH, N_HEADS = 2, 4 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) -# def test_topk_sparse_attention_qlt_kl(): -# BATCH, N_HEADS = 2, 4 -# Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. -# TOPK = 1 -# BLOCK = 64 # block size used in downsampling -# torch.manual_seed(0) - -# # Create inputs. -# q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) -# k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) -# v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) -# # softmax scale -# sm_scale = 1.0 / (D_HEAD ** 0.5) + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) + # softmax scale + sm_scale = 1.0 / (D_HEAD**0.5) -# downsample_factor = BLOCK -# print("downsample_factor", downsample_factor) -# downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension -# print("downsample_len", downsample_len) -# x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, -# device='cuda', dtype=torch.bfloat16) -# # Force the first column to be high so that the first block is always selected. -# x_ds[:, :, :, 0] = 100 -# block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) -# print("block_mask", block_mask) -# print("block_mask.shape", block_mask.shape) -# # Run Triton kernel. -# triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + downsample_factor = BLOCK + print("downsample_factor", downsample_factor) + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + print("downsample_len", downsample_len) + x_ds = torch.randn( + BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + # Run Triton kernel. 
+ triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) -# past_len = K_LEN - Q_LEN + past_len = K_LEN - Q_LEN -# attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale -# full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() -# full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] -# effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) -# i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) -# j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) -# causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN) + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) -# final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + attn = attn.masked_fill(~final_mask, float('-inf')) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) -# attn = attn.masked_fill(~final_mask, float('-inf')) -# attn = F.softmax(attn, dim=-1) -# ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + # Verify accuracy. + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ + "Triton output doesn't match reference when qlen < klen" -# # Verify accuracy. -# assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ -# "Triton output doesn't match reference when qlen < klen" - -# print("Pass topk sparse attention test with qlen < klen") + print("Pass topk sparse attention test with qlen < klen") if __name__ == "__main__": test_topk_sparse_attention() - # test_topk_sparse_attention_qlt_kl() \ No newline at end of file + test_topk_sparse_attention_qlt_kl()
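
For readers following the new examples, the reference path exercised by both test functions can be reproduced outside the kernels. The sketch below is a minimal CPU-only restatement of that path, not code from the repository: it builds a top-k block mask (same idea as get_sparse_attn_mask_from_topk), expands it to token level with torch.kron, applies the causal constraint, and runs a dense masked softmax. The float32 toy shapes, the helper name topk_block_mask, and the printed shape check are assumptions made for portability; the tests themselves run on CUDA in fp16/bf16.

```python
import math
import torch
import torch.nn.functional as F

def topk_block_mask(scores, topk):
    # Same idea as get_sparse_attn_mask_from_topk: keep the top-k key blocks per
    # query block, then restrict to the causal (lower-triangular) block region.
    idx = torch.topk(scores, topk, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask.tril_()

BATCH, HEADS, SEQ_LEN, D_HEAD, BLOCK, TOPK = 1, 1, 256, 64, 64, 2
q = torch.randn(BATCH, HEADS, SEQ_LEN, D_HEAD)
k = torch.randn(BATCH, HEADS, SEQ_LEN, D_HEAD)
v = torch.randn(BATCH, HEADS, SEQ_LEN, D_HEAD)
sm_scale = 1.0 / math.sqrt(D_HEAD)

n_blocks = math.ceil(SEQ_LEN / BLOCK)
scores = torch.randn(BATCH, HEADS, n_blocks, n_blocks)
scores[..., 0] = 100  # force the first key block on, as the tests do with x_ds
block_mask = topk_block_mask(scores, TOPK)

# Expand the block-level mask to token level with a Kronecker product, then AND
# it with the token-level causal mask before the masked softmax.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK)).bool()
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN]
full_mask = full_mask & torch.tril(torch.ones(SEQ_LEN, SEQ_LEN, dtype=torch.bool))

attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
ref_output = torch.einsum('bhst,bhtd->bhsd', F.softmax(attn, dim=-1), v)
print(ref_output.shape)  # torch.Size([1, 1, 256, 64])
```

Forcing the first key-block column on before the top-k selection guarantees that every query row keeps at least one visible key after the causal tril, which is what keeps the softmax rows in the tests free of all-(-inf) inputs.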
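
The Triton kernels accumulate attention with the usual streaming-softmax recurrence: _fwd_kernel_inner visits only the key blocks whose mask_val is set, rescales its running state when the row maximum grows, and _fwd_kernel normalises by 1/l_i at the end. The plain-PyTorch sketch below restates that recurrence for a single query tile; the function name one_query_block and the parameters start_m and past_len are local to this sketch, and the causal correction is applied on every kept block here for simplicity, whereas the kernel only needs it on the last (diagonal) key block because the block mask is already lower-triangular.

```python
import torch

def one_query_block(q_blk, start_m, k, v, mask_row, sm_scale, past_len=0, BLOCK_N=64):
    """Streaming-softmax attention for one BLOCK_M x d query tile (illustrative sketch)."""
    BLOCK_M, d = q_blk.shape
    m_i = torch.full((BLOCK_M,), float('-inf'))   # running row-wise maxima
    l_i = torch.zeros(BLOCK_M)                    # running softmax denominators
    acc = torch.zeros(BLOCK_M, d)                 # running un-normalised output
    offs_m = start_m + torch.arange(BLOCK_M)      # global query positions of this tile
    for col_idx, keep in enumerate(mask_row.tolist()):
        if not keep:                              # masked-out key blocks are skipped entirely
            continue
        start_n = col_idx * BLOCK_N
        k_blk, v_blk = k[start_n:start_n + BLOCK_N], v[start_n:start_n + BLOCK_N]
        offs_n = start_n + torch.arange(k_blk.shape[0])
        qk = (q_blk @ k_blk.T) * sm_scale
        # causal term: a query at absolute position offs_m + past_len may not look ahead
        qk = qk.masked_fill(offs_m[:, None] + past_len < offs_n[None, :], float('-inf'))
        m_ij = torch.maximum(m_i, qk.max(dim=1).values)
        p = torch.exp(qk - m_ij[:, None])
        alpha = torch.exp(m_i - m_ij)             # rescale old state to the new maximum
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v_blk
        m_i = m_ij
    return acc / l_i[:, None]                     # final 1 / l_i normalisation, as in _fwd_kernel
```

When the running maximum m_i grows to m_ij, both the accumulated numerator acc and the denominator l_i are scaled by alpha = exp(m_i - m_ij), so earlier contributions stay expressed relative to the new maximum before the new p @ v term is added; this is the same rescaling _fwd_kernel_inner performs, and past_len is how the qlen < klen test keeps the causal comparison in absolute key coordinates.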