Commit
lint
LeiWang1999 committed Feb 23, 2025
1 parent 85b8e61 commit b102597
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions examples/blocksparse_attention/block_sparse_attn_triton.py
@@ -337,14 +337,14 @@ def test_topk_sparse_attention_qlt_kl():
     k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
     # softmax scale
-    sm_scale = 1.0 / (D_HEAD ** 0.5)
+    sm_scale = 1.0 / (D_HEAD**0.5)
 
     downsample_factor = BLOCK
     print("downsample_factor", downsample_factor)
     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
     print("downsample_len", downsample_len)
-    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len,
-                       device='cuda', dtype=torch.bfloat16)
+    x_ds = torch.randn(
+        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
     # Force the first column to be high so that the first block is always selected.
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -363,7 +363,7 @@ def test_topk_sparse_attention_qlt_kl():
     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
 
     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
-    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
+    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
     causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
 
     final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
@@ -378,6 +378,7 @@ def test_topk_sparse_attention_qlt_kl():
 
     print("Pass topk sparse attention test with qlen < klen")
 
+
 if __name__ == "__main__":
     test_topk_sparse_attention()
     test_topk_sparse_attention_qlt_kl()
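
For context: the test calls get_sparse_attn_mask_from_topk, whose implementation is not part of this diff. Below is a minimal sketch of what such a helper plausibly does, assuming a per-query-block top-k selection over the downsampled score map; the body is illustrative, not the repository's actual code.

import torch

def get_sparse_attn_mask_from_topk(x_ds: torch.Tensor, topk: int) -> torch.Tensor:
    # Hypothetical sketch: for each query block (a row of the downsampled
    # score map), keep the `topk` highest-scoring key blocks.
    # x_ds: (B, H, num_q_blocks, num_k_blocks); returns a boolean mask of
    # the same shape, True where a block is kept. Forcing x_ds[..., 0] high,
    # as the test does, guarantees the first key block is always selected.
    idx = x_ds.topk(topk, dim=-1).indices           # (B, H, Qb, topk)
    mask = torch.zeros_like(x_ds, dtype=torch.bool)
    mask.scatter_(-1, idx, True)                    # mark the selected blocks
    return mask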
