When using bfloat16 with certain shapes, the compiler tries to pack two bfloat16 values together using __pack_nv_bfloat162, but this utility is not defined anywhere (at least on my systems).
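For context, here is a minimal sketch of what such a helper could look like, assuming __pack_nv_bfloat162 is only meant to combine two bf16 scalars into one __nv_bfloat162 (that signature and behaviour are my guess, not taken from the tilelang codebase):

```cuda
// Sketch of a possible stop-gap definition, assuming __pack_nv_bfloat162 is
// expected to combine two bf16 scalars into a single packed __nv_bfloat162
// register (signature is an assumption, not from the tilelang sources).
#include <cuda_bf16.h>

__device__ __forceinline__ __nv_bfloat162
__pack_nv_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) {
    // __halves2bfloat162 (from cuda_bf16.h) places x in the low half and
    // y in the high half of the packed 32-bit value.
    return __halves2bfloat162(x, y);
}
```

If those semantics are right, prepending a definition like this to the generated source should sidestep the compile error, but the proper fix presumably belongs in the code generator or its header templates.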
Here is a modified version of the matmul tutorial that reproduces the problem on an A100 and a 4070 Ti (I changed the dtype to bfloat16, the shapes, and num_stages):
```python
import tilelang
import tilelang.language as T

# `make_mma_swizzle_layout` is a python defined layout function
# specifically designed for MMA operations
# which ensures the consistency with the nvidia CUTLASS Library.
# to avoid bank conflicts and maximize the performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,
)


def matmul(M, N, K, block_M, block_N, block_K, dtype="bfloat16", accum_dtype="float"):
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
            # T.annotate_layout({
            #     A_shared: make_swizzle_layout(A_shared),
            #     B_shared: make_swizzle_layout(B_shared),
            # })

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Demonstrate parallelized copy from global to shared for B
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) with the desired dimensions
M, N, K = 1000, 64, 64
func = matmul(M, N, K, 32, 64, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with the profiler
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
When run, this produces the following traceback: