Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032

Open · wants to merge 11 commits into base: master
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
@@ -158,6 +158,7 @@ option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp on
option(GGML_HIP "ggml: use HIP" OFF)
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
10 changes: 8 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
@@ -196,6 +196,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +227,14 @@ static bool fast_fp16_hardware_available(const int cc) {

// Any FP16 tensor core instructions are available for ggml code.
static bool fp16_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3;
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3;
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
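A side note for readers of the updated fp16_mma_available()/fp16_mma_hardware_available() checks above: && binds tighter than ||, so each one-line condition parses as three architecture groups. A minimal sketch with explicit grouping (the helper name is illustrative; the constants and ggml_cuda_highest_compiled_arch() are the ones already used in common.cuh):

```cpp
// Sketch only: the same logic as the updated fp16_mma_available(), spelled out.
static bool fp16_mma_available_sketch(const int cc) {
    const bool nvidia_volta_plus = cc < GGML_CUDA_CC_OFFSET_AMD &&
                                   ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
    const bool amd_cdna          = cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1;
    const bool amd_rdna3_plus    = cc >= GGML_CUDA_CC_RDNA3;
    return nvidia_volta_plus || amd_cdna || amd_rdna3_plus;
}
```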
65 changes: 63 additions & 2 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -7,7 +7,12 @@
#include "fattn-wmma-f16.cuh"

#ifdef FP16_MMA_AVAILABLE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // FP16_MMA_AVAILABLE

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@@ -68,11 +73,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#else
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Collaborator:
Is there any reason to do it like this and not with something like using namespace nvcuda::wmma?

Author:
Switched to namespace alias instead of these ifdefs. Does that look good to you?
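For context, a minimal sketch of the namespace-alias idea discussed here (illustrative only, not the exact code that ended up in the PR): the backend is chosen once, and every fragment typedef and sync call then uses a single wmma:: spelling instead of per-call #ifdef blocks.

```cpp
// Sketch: select the WMMA implementation once.
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
namespace wmma = rocwmma;
#else
namespace wmma = nvcuda::wmma;
#endif

// The kernel body can then be written once for both backends, e.g.:
// typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
// wmma::load_matrix_sync(K_a, K_h + ..., stride_KV);
// wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
// wmma::store_matrix_sync(..., wmma::mem_col_major);
```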


constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -162,7 +175,11 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
#else
rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}

@@ -176,20 +193,36 @@
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
#else
rocwmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#else
rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
#else
rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
#else
rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}

@@ -308,10 +341,17 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
#else
rocwmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}

@@ -320,18 +360,30 @@
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
#else
rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}

#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

frag_a_V v_a;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#else
rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
#else
rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}
}
@@ -343,10 +395,17 @@
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
#else
rocwmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, rocwmma::mem_col_major);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}
}

@@ -425,7 +484,7 @@ static __global__ void flash_attn_ext_f16(
}
#else
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}

constexpr int get_max_power_of_2(int x) {
@@ -574,6 +633,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
@@ -586,6 +646,7 @@
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
default:
GGML_ABORT("fatal error");
break;
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
@@ -254,6 +254,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)

if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
10 changes: 10 additions & 0 deletions ggml/src/ggml-hip/CMakeLists.txt
@@ -39,6 +39,12 @@ endif()
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
if (GGML_HIP_ROCWMMA_FATTN)
CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
if (NOT ${FOUND_ROCWMMA})
message(FATAL_ERROR "rocwmma is not found")
endif()
endif()

if (${hip_VERSION} VERSION_LESS 5.5)
message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM)
add_compile_definitions(GGML_HIP_NO_VMM)
endif()

if (GGML_HIP_ROCWMMA_FATTN)
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
endif()

if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
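Usage note: the new option only has an effect in a HIP build, so a configuration along the lines of cmake -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON (plus the usual ROCm toolchain settings for the target GPU) should enable the rocWMMA FlashAttention path, provided rocwmma/rocwmma.hpp is installed; otherwise the header check above fails the configure step.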