Add GGML_HIP_ROCWMMA_FATTN to enable rocWMMA for FlashAttention #12032
base: master
Conversation
Adding @IMbackK for review |
As of right now I do not have the hardware necessary to test this code. Is anyone pledging to maintain it long-term if it gets merged? |
@JohannesGaessler So will we keep this code path and @adelj88 / @thamwangjun's optimized code path together in the future? |
ggml/src/ggml-cuda/fattn-wmma-f16.cu
Outdated
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#else
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
Is there any reason to do it like this and not with something like using namespace nvcuda::wmma?
Switched to a namespace alias instead of these ifdefs. Does that look good to you?
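For readers following the thread, here is a minimal sketch of the alias-based variant; the alias name wmma, the include lines, and the exact guard are assumptions and may differ from what the PR ends up using:

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;      // AMD: rocWMMA mirrors the CUDA WMMA fragment API
#else
#include <mma.h>
namespace wmma = nvcuda::wmma; // NVIDIA: tensor core WMMA intrinsics
#endif

// frag_m, frag_n, KQ_acc_t and half come from the surrounding kernel template,
// as in the snippet above; the typedefs now only need to be written once.
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;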
I am currently traveling and will take a look at this next week, thank you! |
Nice, +41% over master for pp2048 on gfx1100. Shame it doesn't affect text gen, where fa=1 is still ≥6% behind fa=0, and 22% behind radv |
@Beinsezii Maybe we need a new version of rocWMMA + ROCm to make it work on gfx12? |
@Headcrabed gfx12 is part of the current release 6.3 |
gfx12 requires rocWMMA from ROCm 6.4 to compile and work; the one from the latest 6.3.x release is broken. I may send another PR to enable that after ROCm 6.4 is released. |
@hjc4869 However, |
With this patch, on gfx908 I see prompt performance with FA=1 almost exactly equal to FA=0 (without the patch, FA=1 is both slower and scales worse). Token generation is somewhere between unpatched FA=0 and FA=1. Unfortunately, in some situations with long context, I get garbage output. Repro:
The model is from ollama. The FA=0 case (or the FA=1 unpatched case) prints something like:
While FA=1 patched (seemingly regardless of input content) produces:
Let me know if there are any other tests I can run.
|
Check the output of |
Here's the output:
|
Maybe this is obvious, but the failing cases above are just all of the cases where:
|
The crash is caused by a missing check in ggml_backend_cuda_device_supports_op:

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ebb2ccae..dc6ed6ab 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3216,6 +3216,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
+
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+ return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+ op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16 && op->src[0]->ne[1] > 8;
+#elif defined(GGML_USE_HIP)
+ return false;
+#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
    }

Differences of test results on W7900DS between the PR branch and master after the above fix:

diff --git a/home/david/Desktop/master.log b/home/david/Desktop/pr.log
index d4a3d4ee..47899baa 100644
--- a/home/david/Desktop/master.log
+++ b/home/david/Desktop/pr.log
@@ -196,16 +196,16 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -220,11 +220,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -244,16 +244,16 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,2,1,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,2,1,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,2,1,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -268,11 +268,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -284,11 +284,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -300,11 +300,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -316,11 +316,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -332,11 +332,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=1,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
@@ -348,11 +348,11 @@ Backend 1/2: ROCm0
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=32,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]
- FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [ROCm0]
+ FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): OK
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [ROCm0]
FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [ROCm0]
  FLASH_ATTN_EXT(hs=80,nh=4,nr=4,kv=512,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [ROCm0]

I currently only have RDNA3 GPUs, so I haven't tested the patch on CDNA. Perhaps we'll wait for @IMbackK to decide whether we should enable this code path on a per-device basis based on test results, or whether there's something wrong with my changes that caused issues on CDNA. |
Slightly off-topic, but I'm wondering whether we need to enable different code paths for RDNA3/4 and CDNA in the future, as well as whether we use rocWMMA in those future optimisations (we might eventually need to roll out something similar to …). In the scenario where the PR above has some issues on CDNA, I wouldn't mind if this PR enabled this code path specifically for RDNA3/4, at least until optimised code paths are introduced. |
ggml/src/ggml-cuda/fattn.cu
Outdated
#if defined(GGML_HIP_ROCWMMA_FATTN)
    if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
Is Q->ne[1] > 8 correct? In my experience, even for very small batch sizes like 2 the wmma kernel can be faster (on NVIDIA).
The check was indeed not needed; it seems to be an early hack from my private branches to get rid of the cols_per_block = 8 cases. I've removed them.
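For context, with that restriction dropped the HIP dispatch from the snippet above reduces to roughly the following sketch (not necessarily the exact final code):

#if defined(GGML_HIP_ROCWMMA_FATTN)
    // Any fp16-MMA-capable AMD target takes the rocWMMA FlashAttention path
    // when the build option is enabled, regardless of batch size.
    if (fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)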
…d GGML_HIP_ROCWMMA_FATTN is disabled.
With the
Before:
After:
|
… for AMD and GGML_HIP_ROCWMMA_FATTN not enabled.
I pulled the latest changes and can confirm that the test doesn't crash, and now the garbage output appears even at small prompt sizes. It also improved (garbage) token generation performance.

Output of test-backend-ops -o FLASH_ATTN_EXT on gfx908/MI100 with the latest code
llama-batched-bench output before/after PR

The numbers for with/without FA (post-PR) are so close that I have triple-checked them. I feel like I'm missing something.

Without this PR, FA=0
Without this PR, FA=1
With this PR, FA=1
Just for fun, the same bench run on a 3090

Without FA:
With FA:
|
…compiled for AMD and GGML_HIP_ROCWMMA_FATTN not enabled." This reverts commit 5516909.
- Add build option GGML_HIP_ROCWMMA_FATTN, which defaults to OFF
- Use rocWMMA for the FlashAttention WMMA kernel when GGML_HIP_ROCWMMA_FATTN is enabled
- Define FP16_MMA_AVAILABLE when GGML_HIP_ROCWMMA_FATTN is enabled and the target is supported by rocWMMA (CDNA / RDNA3)

Related issue: #10439
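As an illustration of the last point, the AMD side of such a gate could look roughly like the sketch below; the per-target macros CDNA and RDNA3 are assumptions here, and the actual gating in the source may be structured differently:

#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
// Advertise fp16 MMA support on AMD only when the build option is enabled
// and the compilation target is one that rocWMMA supports.
#define FP16_MMA_AVAILABLE
#endif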