* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

  This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
  - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
  - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang
  Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing against PyTorch reference implementations.

* Refactor GEMM SplitK and StreamK example implementations

  Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
  - Remove an unused import (Profiler) in the SplitK example
  - Simplify line breaks and improve code readability
  - Standardize indentation and remove unnecessary whitespace
  - Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

  This commit introduces comprehensive block sparse attention benchmarks for different libraries:
  - TileLang block sparse FMHA implementation
  - Triton block sparse FMHA implementation
  - PyTorch reference block sparse FMHA implementation
  - FlashAttention dense FMHA reference implementation
  The benchmarks include:
  - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
  - Sparse mask generation using top-k and threshold methods
  - Performance measurement for different sparse attention configurations
  - Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements

  - Add Ruff linter ignore comments to benchmark files
  - Improve code formatting and line breaks
  - Remove unused imports
  - Standardize print statement formatting
  - Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

  - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in the CUDA common header
  - Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
  - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
  - Update kernel and language customization to use the new function names
  - Add return type annotations in the profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

  This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
  - Group Query Attention (GQA) implementation
  - Flash Attention forward pass
  - Performance benchmarking
  - Configurable parameters for batch, heads, sequence length, and dimension
  - Autotuning support
  - Reference implementation comparison

* Refactor IR lowering pipeline into modular phases

  This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
  - `LowerAndLegalize`: Handles initial IR legalization and transformation
  - `OptimizeForTarget`: Applies target-specific optimizations
  The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lint fix
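The Split-K strategy named in the commit message partitions the reduction (K) dimension of a GEMM so that several workers each compute a partial product, and the partials are then summed into the final output. The actual `example_tilelang_gemm_splitk.py` kernel is not shown in this commit view; the PyTorch snippet below is only a minimal sketch of the idea (the function name `splitk_matmul_sketch` and the `split_k=4` choice are illustrative, not taken from the example).

import torch

def splitk_matmul_sketch(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    """Illustrative Split-K GEMM: partition the K dimension, accumulate partial sums."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and K % split_k == 0, "K must be divisible by split_k in this sketch"
    chunk = K // split_k
    C = torch.zeros(M, N, dtype=torch.float32)
    for s in range(split_k):  # each iteration models one independent K-slice worker
        a = A[:, s * chunk:(s + 1) * chunk].float()
        b = B[s * chunk:(s + 1) * chunk, :].float()
        C += a @ b  # partial product reduced into the shared accumulator
    return C

# Quick check against a plain matmul reference
A = torch.randn(128, 256, dtype=torch.float16)
B = torch.randn(256, 64, dtype=torch.float16)
torch.testing.assert_close(splitk_matmul_sketch(A, B), A.float() @ B.float(), rtol=1e-3, atol=1e-3)

In a real GPU kernel the per-slice partial products are produced by independent thread blocks and typically combined with atomic adds or a separate reduction pass rather than a Python loop.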
1 parent 540aef4 · commit e93c7c4 · Showing 5 changed files with 348 additions and 168 deletions.
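One more note before the diff: the block sparse attention benchmarks listed in the commit message generate their sparsity patterns with top-k and threshold methods. Those benchmark files are not shown here; the snippet below is only a hedged sketch of what top-k block mask generation over pooled attention scores might look like (the helper name `topk_block_mask` and its shapes are assumptions, not code from the benchmarks).

import torch

def topk_block_mask(block_scores: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k highest-scoring key blocks per query block (illustrative only).

    block_scores: [batch, heads, num_q_blocks, num_kv_blocks] pooled attention scores.
    Returns a boolean mask of the same shape marking which blocks are kept.
    """
    k = min(k, block_scores.size(-1))
    topk_idx = block_scores.topk(k, dim=-1).indices
    mask = torch.zeros_like(block_scores, dtype=torch.bool)
    mask.scatter_(-1, topk_idx, True)
    return mask

# Example: keep 4 of 32 key blocks for each query block
scores = torch.randn(1, 8, 32, 32)
mask = topk_block_mask(scores, k=4)
assert mask.sum(dim=-1).eq(4).all()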
@@ -0,0 +1,241 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    block_M = [128]
    block_N = [128]
    num_stages = [2]
    threads = [256]
    _configs = list(itertools.product(block_M, block_N, num_stages, threads))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'num_stages': c[2],
        'threads': c[3]
    } for c in _configs]
    return configs

def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
                K: T.Buffer(kv_shape, dtype),
                Q_shared: T.Buffer([block_M, dim], dtype),
                K_shared: T.Buffer([block_N, dim], dtype),
                acc_s: T.Buffer([block_M, block_N], accum_dtype),
                k: T.int32,
                bx: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            # Load the current K tile; `by // groups` maps each query head to its shared KV head (GQA).
            T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
                V: T.Buffer(kv_shape, dtype),
                V_shared: T.Buffer([block_N, dim], dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                k: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
                acc_s: T.Buffer([block_M, block_N], accum_dtype),
                acc_s_cast: T.Buffer([block_M, block_N], dtype),
                scores_max: T.Buffer([block_M], accum_dtype),
                scores_max_prev: T.Buffer([block_M], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
                scores_sum: T.Buffer([block_M], accum_dtype),
                logsum: T.Buffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set scores_max to 0 if it is -inf.
            # This step is called Check_inf in the FlashAttention3 code, and it only needs
            # to be done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
                acc_o: T.Buffer([block_M, dim], accum_dtype),
                scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            # Rescale the running output by the correction factor from the updated row max.
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
                Q: T.Buffer(q_shape, dtype),
                K: T.Buffer(kv_shape, dtype),
                V: T.Buffer(kv_shape, dtype),
                Output: T.Buffer(q_shape, dtype),
        ):
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[3],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel

def ref_program(Q, K, V, is_causal, groups=1):
    # Q: [B, T, HQ, D]
    # K: [B, T, HK, D]
    # V: [B, T, HV, D]
    # HQ = HKV * groups
    assert Q.size(2) == K.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
    assert Q.size(2) == V.size(
        2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    parser.add_argument('--groups', type=int, default=8, help='groups')
    args = parser.parse_args()
    batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if not args.tune:
        program = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
                block_M=128, block_N=128, num_stages=1, threads=128)
        ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
        mod, params = tilelang.lower(program)
        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = mod.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = mod.do_bench(mod.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")