identifier "__pack_nv_bfloat162" is undefined #115

Open
falkaer opened this issue Feb 24, 2025 · 3 comments

falkaer commented Feb 24, 2025

When using bfloat16 and certain shapes, the compiler tries to pack two bfloat16 values together using __pack_nv_bfloat162, but this helper is not defined anywhere (on my systems, at least).
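
For context, here is a minimal sketch (my assumption, not actual tilelang code) of what the missing helper would presumably need to do, judging from the call sites in the generated kernel below: reinterpret each bfloat16 as 16 bits and pack the pair into one 32-bit word.

#include <cuda_bf16.h>

// Hypothetical definition of the missing helper, with the signature implied by
// the generated kernel. Each bf16 value is reinterpreted as 16 bits and the
// pair is combined into a single 32-bit word (y in the high half, x in the low
// half); the exact lane order would have to match whatever the codegen expects.
__device__ __forceinline__ unsigned int __pack_nv_bfloat162(const __nv_bfloat16 x,
                                                            const __nv_bfloat16 y) {
  unsigned short bits_x = *reinterpret_cast<const unsigned short *>(&x);
  unsigned short bits_y = *reinterpret_cast<const unsigned short *>(&y);
  return (static_cast<unsigned int>(bits_y) << 16) | bits_x;
}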

Here is a modified version of the matmul tutorial that reproduces the problem on an A100 and a 4070 Ti (I changed the dtype to bfloat16, the shapes, and num_stages):

import tilelang
import tilelang.language as T

# `make_mma_swizzle_layout` is a Python-defined layout function
# designed specifically for MMA operations. It keeps layouts
# consistent with the NVIDIA CUTLASS library to avoid bank
# conflicts and maximize performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,
)


def matmul(M, N, K, block_M, block_N, block_K, dtype="bfloat16", accum_dtype="float"):
    # Add the @tilelang.jit decorator if you want to return a torch function
    @T.prim_func
    def main(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
            # T.annotate_layout({
            #     A_shared: make_swizzle_layout(A_shared),
            #     B_shared: make_swizzle_layout(B_shared),
            # })

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy a tile of A
                # This is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Demonstrate parallelized copy from global to shared for B
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                # Perform a tile-level GEMM on the shared buffers
                # Currently this dispatches to the cute/hip backends on NVIDIA/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) with the desired dimensions
M, N, K = 1000, 64, 64
func = matmul(M, N, K, 32, 64, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list.
# If out_idx is specified, the output tensor will be created at runtime.
# target can currently be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)


# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with the profiler
profiler = jit_kernel.get_profiler()

latency = profiler.do_bench()

print(f"Latency: {latency} ms")

When run, this produces the following traceback:

Traceback (most recent call last):
  File "/home/falkaer/testvenv/repro2.py", line 70, in <module>
    jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/jit/__init__.py", line 121, in compile
    return JITKernel(
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/jit/kernel.py", line 84, in __init__
    adapter = self._compile_and_create_adapter(func)
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/jit/kernel.py", line 130, in _compile_and_create_adapter
    rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/engine/lower.py", line 245, in lower
    device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "Python/ceval.c", line 5888, in call_function
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/engine/lower.py", line 74, in tilelang_callback_cuda_compile
    ptx = nvcc.compile_cuda(
  File "/home/falkaer/testvenv/.venv/lib/python3.10/site-packages/tilelang/contrib/nvcc.py", line 118, in compile_cuda
    raise RuntimeError(msg)
RuntimeError: #include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

extern "C" __global__ void __launch_bounds__(128) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_local[16];
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    *(float2*)(C_local + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  for (int ko = 0; ko < 2; ++ko) {
    __syncthreads();
    uint4 condval;
    if (((((((int)blockIdx.y) * 4) + (((int)threadIdx.x) >> 5)) < 125) && (((((int)blockIdx.y) * 4) + (((int)threadIdx.x) >> 5)) < 125))) {
      condval = *(uint4*)(A + ((((((int)blockIdx.y) * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (ko * 32)) + ((((int)threadIdx.x) & 3) * 8)));
    } else {
      condval = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
    }
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + ((((((int)threadIdx.x) >> 2) * 32) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8))) = condval;
    #pragma unroll
    for (int i_1 = 0; i_1 < 2; ++i_1) {
      *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + ((((((i_1 * 1024) + ((((int)threadIdx.x) >> 3) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 1024)) = *(uint4*)(B + (((ko * 2048) + (i_1 * 1024)) + (((int)threadIdx.x) * 8)));
    }
    __syncthreads();
    tl::gemm_ss<32, 64, 32, 2, 2, 0, 0>((&(((bfloat16_t*)buf_dyn_shmem)[0])), (&(((bfloat16_t*)buf_dyn_shmem)[1024])), (&(C_local[0])));
  }
  #pragma unroll
  for (int i_2 = 0; i_2 < 8; ++i_2) {
    if ((((((int)blockIdx.y) * 4) + (((((int)threadIdx.x) & 63) >> 5) * 2)) + (i_2 & 1)) < 125) {
      uint1 __1;
      float2 v_ = *(float2*)(C_local + (i_2 * 2));
      ((nv_bfloat162*)(&(__1.x)))->x = (bfloat16_t)(v_.x);
      ((nv_bfloat162*)(&(__1.x)))->y = (bfloat16_t)(v_.y);
      *(uint1*)(C + (((((((((int)blockIdx.y) * 2048) + (((((int)threadIdx.x) & 63) >> 5) * 1024)) + ((i_2 & 1) * 512)) + (((((int)threadIdx.x) & 31) >> 2) * 64)) + ((i_2 >> 1) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __1;
    }
  }
}


Compilation error:
/tmp/tmpyr134y_f/tvm_kernels.cu(21): error: identifier "__pack_nv_bfloat162" is undefined
        condval = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
                             ^

falkaer commented Feb 24, 2025

This also happens when using atomic_addx2 with bfloat16.
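
Presumably the same missing packing path is hit there. As I understand it (an assumption on my part, with a hypothetical helper name), a two-element bf16 atomic add ultimately boils down to a vectorized atomicAdd on a packed __nv_bfloat162, roughly:

#include <cuda_bf16.h>

// Rough sketch of the assumed lowering of a two-element bf16 atomic add:
// pack the two operands and issue a single vectorized atomicAdd.
// The __nv_bfloat162 atomicAdd overload requires compute capability 8.0+.
__device__ void atomic_addx2_bf16(__nv_bfloat16 *dst, const __nv_bfloat16 *val) {
  __nv_bfloat162 packed = __halves2bfloat162(val[0], val[1]);
  atomicAdd(reinterpret_cast<__nv_bfloat162 *>(dst), packed);
}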

LeiWang1999 (Contributor) commented

Thanks for the report @falkaer, I'll take a look.

LeiWang1999 (Contributor) commented

PR #116 may fix this issue.
