[Feature] Add CTypes JIT kernel support (#100)
* [Feature] Add CTypes JIT kernel support for dynamic shapes and multi-stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in CTypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add symbolic shape utility function in tilelang.language (see the usage sketch after this list)
- Update profiler to improve flexibility in benchmark selection
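
A minimal usage sketch (not part of this commit's diff) of the pattern these changes enable, mirroring the new tests added below (`run_ctypes_dynamic_shape`, `run_ctypes_kernel_multi_stream`, `run_ctypes_kernel_do_bench`). Here `matmul` stands for the test helper defined in `testing/python/jit/test_tilelang_jit_gemm_ctypes.py`; the concrete shapes and variable names are illustrative assumptions.

```python
import torch
import tilelang
import tilelang.language as T

# Build a GEMM program whose M dimension is symbolic. Arguments follow the
# test helper's signature: (M, N, K, block_M, block_N, block_K, trans_A,
# trans_B, in_dtype, out_dtype, accum_dtype, num_stages, num_threads).
program = matmul(
    T.symbolic("m"), 1024, 768, 128, 256, 32,
    False, False, "float16", "float16", "float16", 2, 128)

# Compile once through the ctypes execution backend.
kernel = tilelang.JITKernel(program, execution_backend="ctypes")

# The compiled kernel serves any runtime value of the symbolic "m".
a = torch.randn(512, 768, dtype=torch.float16).cuda()
b = torch.randn(768, 1024, dtype=torch.float16).cuda()
c = torch.empty(512, 1024, dtype=torch.float16).cuda()
kernel(a, b, c)

# The same kernel can also be launched on several CUDA streams.
for _ in range(4):
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        kernel(a, b, c)

# The updated profiler can benchmark either the ctypes wrapper or the
# default path.
profiler = kernel.get_profiler()
ctypes_ms = profiler.do_bench(func=kernel, profiler="torch")
tvm_ms = profiler.do_bench()
```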

* Remove redundant thread binding in GEMM kernel implementations

- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation

* Fix indentation in int4 GEMM kernel test file

- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability
LeiWang1999 authored Feb 20, 2025
1 parent 15b926a commit 778dbd2
Showing 8 changed files with 326 additions and 253 deletions.
2 changes: 0 additions & 2 deletions examples/gemm/README.md
@@ -339,8 +339,6 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
169 changes: 167 additions & 2 deletions testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
@@ -27,8 +28,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
@@ -235,5 +234,171 @@ def test_gemm_jit_kernel():
)


def run_ctypes_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")

profiler = matmul_kernel.get_profiler()

ctypes_latency = profiler.do_bench(func=matmul_kernel, profiler="torch")
print(f"Ctypes Latency: {ctypes_latency} ms")

assert ctypes_latency is not None

tvm_latency = profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")

assert tvm_latency is not None


def test_ctypes_kernel_do_bench():
run_ctypes_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)


def run_ctypes_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")

tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()

if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()

num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)


def test_ctypes_kernel_multi_stream():
run_ctypes_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)


def run_ctypes_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()

if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()

matmul_kernel(tensor_a, tensor_b, tensor_c)

tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_ctypes_dynamic_shape():
run_ctypes_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)

run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)

run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)


if __name__ == "__main__":
tilelang.testing.main()
12 changes: 4 additions & 8 deletions testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
@@ -109,8 +109,6 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
@@ -138,14 +136,14 @@ def main(
A_local,
A_shared,
ki,
)
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
@@ -294,8 +292,6 @@ def main(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_binding = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
@@ -325,14 +321,14 @@ def main(
A_local,
A_shared,
ki,
)
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
)

# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
(Diffs for the remaining 5 changed files are not shown here.)
