Commit

[Release] Bump version to v0.1.1 (#107)
* Remove Torch CPP backend and update execution backend options

- Remove TorchCPPKernelAdapter and related code from JIT modules
- Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
- Remove "torch_cpp" from supported execution backend literals
- Simplify backend validation and remove unused torch_cpp-related code (a hypothetical sketch of the narrowed backend options follows below)
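
As a rough, hypothetical sketch only (the identifiers and the remaining backend names below are illustrative, not TileLang's actual API), dropping "torch_cpp" from a Literal-typed execution-backend option and its validation might look like this:

```python
# Hypothetical sketch: names are illustrative, not TileLang's real code.
from typing import Literal, get_args

# "torch_cpp" is no longer listed among the supported execution backends.
ExecutionBackend = Literal["dlpack", "ctypes", "cython"]


def validate_backend(backend: str) -> str:
    # With the torch_cpp branch removed, validation reduces to a membership check.
    supported = get_args(ExecutionBackend)
    if backend not in supported:
        raise ValueError(f"Unsupported execution backend {backend!r}; expected one of {supported}")
    return backend
```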

* lint fix

* Add block sparse attention implementations for TileLang and Triton

- Implement block sparse attention kernels for TileLang and Triton
- Add example scripts for block sparse attention with top-k and threshold-based masking
- Include utility functions for generating sparse attention masks
- Demonstrate causal attention with block-level sparsity
- Add test cases to validate sparse attention implementations against the PyTorch reference (a minimal mask-construction sketch follows below)
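
Grounded in the helpers added by this diff (get_sparse_attn_mask_from_topk and the torch.kron expansion in the test), here is a minimal, self-contained sketch of how a causal block-level mask is built and expanded to token granularity; the shapes and constants are illustrative, not the exact test configuration.

```python
# Minimal sketch of block-sparse mask construction: top-k selection at block
# granularity, causal tril, then expansion to a token-level mask via torch.kron,
# mirroring the reference path in the test.
import math
import torch

BATCH, HEADS, SEQ_LEN, BLOCK, TOPK = 1, 1, 256, 64, 2
downsample_len = math.ceil(SEQ_LEN / BLOCK)

# Random block-level scores; top-k per row picks which key blocks each query block keeps.
scores = torch.randn(BATCH, HEADS, downsample_len, downsample_len)
block_mask = torch.zeros_like(scores, dtype=torch.bool)
block_mask.scatter_(-1, torch.topk(scores, TOPK, dim=-1).indices, True)
block_mask.tril_()  # enforce causality at block level

# Expand each block entry into a BLOCK x BLOCK tile, then re-apply the causal mask per token.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK)).bool()
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN] & torch.tril(
    torch.ones(SEQ_LEN, SEQ_LEN, dtype=torch.bool))
```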

* Bump version to 0.1.1

* Refactor block sparse attention examples for improved code quality

- Apply consistent code formatting and style in TileLang and Triton block sparse attention implementations
- Add ruff linter ignore comment for a specific line in the Triton implementation (see the illustrative noqa example below this list)
- Improve readability by adjusting indentation and line breaks
- Standardize sparse mask generation and test function implementations
- Minor optimizations in test case configurations
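
For reference, a per-line ruff ignore is a trailing `# noqa` comment naming the rule; the rule code below is illustrative and not necessarily the one suppressed in the Triton example.

```python
# Illustrative only: suppress one ruff rule on a single line.
square = lambda x: x * x  # noqa: E731  (pycodestyle: assigned lambda)
```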

* lint
LeiWang1999 authored Feb 23, 2025
1 parent 0b1bcc5 commit 59342bb
Showing 4 changed files with 165 additions and 133 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -3,6 +3,7 @@ include CMakeLists.txt
include requirements.txt
include requirements-test.txt
include requirements-dev.txt
include tilelang/jit/adapter/cython/cython_wrapper.pyx
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.1.0
0.1.1
36 changes: 21 additions & 15 deletions examples/blocksparse_attention/block_sparse_attn_tilelang.py
@@ -7,24 +7,28 @@
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :,-2:,:] = True
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :,-2:,:] = True
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
return dense_mask


def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
@@ -136,7 +140,7 @@ def main(
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)

T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
@@ -165,6 +169,7 @@ def main(

return kernel_func(block_M, block_N, num_stages, threads)


def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)

sm_scale = 1.0 / (D_HEAD ** 0.5)
sm_scale = 1.0 / (D_HEAD**0.5)

# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
x_ds[:,:,:,0] = 100
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

# Run TileLang kernel
@@ -194,25 +201,24 @@ def test_topk_sparse_attention():

# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(),
torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal

# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)

print("ref_output", ref_output)
print("tilelang_output", tilelang_output)


# Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
"TileLang output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")


if __name__ == "__main__":
test_topk_sparse_attention()
