Refactor block sparse attention benchmarks with code style improvements
- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks
LeiWang1999 committed Feb 24, 2025
1 parent c2cffd9 commit 789781c
Showing 4 changed files with 34 additions and 17 deletions.
@@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import math
+# ruff: noqa
import torch

import torch.nn.functional as F
from tilelang.profiler import do_bench

def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
@@ -42,6 +40,7 @@ def benchmark_topk_sparse_attention():
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)

import flash_attn
+
def benchmark_fn():
flash_attn.flash_attn_func(q, k, v, causal=True)

@@ -50,7 +49,9 @@ def benchmark_fn():
warmup=10,
rep=100,
)
print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}")
print(
f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}"
)


if __name__ == "__main__":
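For orientation, the dense flash-attn baseline this first file times follows the pattern below. This is a minimal sketch assembled from the visible hunks; the concrete sizes are illustrative (the file's actual values are in the collapsed code), and note that flash_attn.flash_attn_func expects tensors in (batch, seqlen, nheads, headdim) layout.

import torch
import flash_attn
from tilelang.profiler import do_bench

BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 16, 4096, 64  # illustrative sizes, not the file's

# flash_attn_func expects (batch, seqlen, nheads, headdim) layout.
q = torch.randn(BATCH, SEQ_LEN, N_HEADS, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def benchmark_fn():
    flash_attn.flash_attn_func(q, k, v, causal=True)

# warmup/rep values mirror the do_bench calls throughout these benchmarks.
ref_latency = do_bench(benchmark_fn, warmup=10, rep=100)
print(f"dense flash-attn latency: {ref_latency}")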
@@ -1,15 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+# ruff: noqa
import math
import torch

import tilelang
from tilelang import language as T
from tilelang.profiler import do_bench


def is_hip():
return False


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
@@ -190,12 +193,14 @@ def benchmark_topk_sparse_attention():
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-    device='cuda',
-    dtype=torch.bfloat16)
+    device='cuda',
+    dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
-    program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
+    program = blocksparse_flashattn(
+        BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=4)
+
def benchmark_fn():
# Compute reference
# Expand block mask to full attention matrix
@@ -206,7 +211,9 @@ def benchmark_fn():
warmup=10,
rep=100,
)
print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}")
print(
f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}"
)


if __name__ == "__main__":
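The body of get_sparse_attn_mask_from_topk is collapsed in every file above, but its signature, the shape unpacking, and the call sites constrain it well: block-level scores come in as [bsz, num_head, downsample_len, downsample_len], and the benchmarks force x_ds[:, :, :, 0] = 100 so the first key block is always selected. A plausible reconstruction, offered only as a sketch, not the repository's actual code:

import torch

def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    # x: [bsz, num_head, downsample_len, downsample_len] block-level scores,
    # where downsample_len = ceil(N_CTX / BLOCK).
    bsz, num_head, downsample_len, _ = x.shape
    sparse_index = torch.topk(x, topk, dim=-1).indices
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(-1, sparse_index, True)  # True = attend to this key block
    if use_dense_for_last_block:
        mask[:, :, -1, :] = True  # keep the final query-block row dense
    return mask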
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+# ruff: noqa
import math
import torch

import torch.nn.functional as F
from tilelang.profiler import do_bench


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
@@ -47,11 +49,11 @@ def benchmark_topk_sparse_attention():
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-    device='cuda',
-    dtype=torch.bfloat16)
+    device='cuda',
+    dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

def benchmark_fn():
# Compute reference
# Expand block mask to full attention matrix
@@ -71,7 +73,9 @@ def benchmark_fn():
warmup=10,
rep=100,
)
print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}")
print(
f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}"
)


if __name__ == "__main__":
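This third file is the plain-PyTorch reference; its collapsed benchmark_fn carries the comment "Expand block mask to full attention matrix". One way such a dense reference can be written, as a sketch under an assumed [B, H, L, D] layout (dense_reference and BLOCK are illustrative names, not the file's actual code):

import torch
import torch.nn.functional as F

BLOCK = 64  # assumed block size

def dense_reference(q, k, v, block_mask, sm_scale):
    # Blow the [B, H, L//BLOCK, L//BLOCK] block mask up to a full [B, H, L, L]
    # mask (assumes L is a multiple of BLOCK).
    full_mask = block_mask.repeat_interleave(BLOCK, dim=-2).repeat_interleave(BLOCK, dim=-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', q, k) * sm_scale
    scores = scores.masked_fill(~full_mask, float('-inf'))
    return torch.einsum('bhqk,bhkd->bhqd', F.softmax(scores, dim=-1), v)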
@@ -1,15 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+# ruff: noqa
import math
import torch

import triton
import triton.language as tl
from tilelang.profiler import do_bench


def is_hip():
return False


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
@@ -286,22 +289,24 @@ def benchmark_topk_sparse_attention():
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-    device='cuda',
-    dtype=torch.bfloat16)
+    device='cuda',
+    dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

def benchmark_fn():
# Compute reference
# Expand block mask to full attention matrix
-        block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
+        block_sparse_triton_fn(q, k, v, block_mask, sm_scale)  # noqa: B023

ref_latency = do_bench(
benchmark_fn,
warmup=10,
rep=100,
)
print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}")
print(
f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}"
)


if __name__ == "__main__":
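A note on the one targeted suppression above: B023 is flake8-bugbear's "function definition does not bind loop variable". It fires because benchmark_fn closes over names that, presumably in the collapsed code, are reassigned by an enclosing loop over configurations; Python closures look variables up at call time, not definition time. That is harmless here since do_bench runs the closure before the loop advances, which is likely why it is suppressed rather than restructured. The classic pitfall, for reference:

# Late binding: every closure sees the *current* value of `cfg` when called.
fns = [lambda: cfg * 2 for cfg in (64, 128, 256)]
print([f() for f in fns])  # [512, 512, 512] -- all three see the last cfg

# The fix B023 nudges you toward: bind the value as a default argument.
fns = [lambda cfg=cfg: cfg * 2 for cfg in (64, 128, 256)]
print([f() for f in fns])  # [128, 256, 512]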
