vulkan: matmul dequantization improvements #12015

Open · wants to merge 2 commits into master

Conversation

netrunnereve (Collaborator)

This basically makes the mul_mm shaders load and dequantize 4 or 8 values at a time, the same way it's done in mat_vec (old quants only).
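
For a rough picture, here's a minimal sketch of the idea for q8_0 (it mirrors the snippet discussed in the review below; `buf_a`, `buf_idx`, and `FLOAT_TYPE` are stand-ins for the mul_mm shared tile and types, so treat them as assumptions about the surrounding code):

```glsl
// Sketch only: instead of loading and dequantizing one int8 quant per
// iteration, pull two 16-bit words (four packed int8 quants) from the
// packed view of the q8_0 block and dequantize all four at once.
const float d = float(data_a_packed16[ib].d);            // per-block scale
const uint v0 = uint(data_a_packed16[ib].qs[2*iqs]);      // quants 0..1
const uint v1 = uint(data_a_packed16[ib].qs[2*iqs + 1]);  // quants 2..3
const vec4 v  = vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8),
                     int8_t(v1 & 0xFF), int8_t(v1 >> 8)) * d;
// Store the four dequantized values into the shared tile (hypothetical names).
buf_a[buf_idx + 0] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
```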

Results on my RX 470:

PR

| model | size | params | backend | ngl | threads | main_gpu | sm | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 100 | 8 | 1 | none | pp512 | 158.37 ± 0.80 |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan | 100 | 8 | 1 | none | pp512 | 153.76 ± 0.52 |
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   38 runs - 26996.37 us/run -  60.13 GFLOP/run -   2.23 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   38 runs - 26764.32 us/run -  60.13 GFLOP/run -   2.25 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   34 runs - 30210.91 us/run -  60.13 GFLOP/run -   1.99 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   36 runs - 29015.64 us/run -  60.13 GFLOP/run -   2.07 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   36 runs - 27984.17 us/run -  60.13 GFLOP/run -   2.15 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 36 runs - 28179.08 us/run -  60.13 GFLOP/run -   2.13 TFLOPS

Master

| model | size | params | backend | ngl | threads | main_gpu | sm | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 100 | 8 | 1 | none | pp512 | 151.66 ± 0.86 |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | Vulkan | 100 | 8 | 1 | none | pp512 | 149.71 ± 0.14 |
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   36 runs - 28187.53 us/run -  60.13 GFLOP/run -   2.13 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   36 runs - 28343.00 us/run -  60.13 GFLOP/run -   2.12 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   32 runs - 31629.72 us/run -  60.13 GFLOP/run -   1.90 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   34 runs - 30898.97 us/run -  60.13 GFLOP/run -   1.95 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   36 runs - 28930.81 us/run -  60.13 GFLOP/run -   2.08 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 36 runs - 28959.25 us/run -  60.13 GFLOP/run -   2.08 TFLOPS

I'm only seeing a small improvement as most of the GPU time is spent doing the actual multiplication, and I think we'll see better results on something that supports coopmat.

github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Feb 21, 2025.
@jeffbolznv (Collaborator)

I did a quick run on RTX 4070 using the KHR_coopmat path (GGML_VK_DISABLE_COOPMAT2=1). Perf is about neutral on average, maybe down a tiny bit?

before
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   332 runs -  3023.55 us/run -  60.13 GFLOP/run -  19.89 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   322 runs -  3114.34 us/run -  60.13 GFLOP/run -  19.31 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  776 runs -  1289.13 us/run -  60.13 GFLOP/run -  46.64 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  748 runs -  1338.91 us/run -  60.13 GFLOP/run -  44.91 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  674 runs -  1485.07 us/run -  60.13 GFLOP/run -  40.49 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  670 runs -  1493.24 us/run -  60.13 GFLOP/run -  40.27 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  632 runs -  1585.79 us/run -  60.13 GFLOP/run -  37.92 TFLOPS
  
after
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   322 runs -  3118.21 us/run -  60.13 GFLOP/run -  19.28 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   320 runs -  3138.63 us/run -  60.13 GFLOP/run -  19.16 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  734 runs -  1365.62 us/run -  60.13 GFLOP/run -  44.03 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  660 runs -  1515.89 us/run -  60.13 GFLOP/run -  39.67 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  710 runs -  1409.35 us/run -  60.13 GFLOP/run -  42.66 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  708 runs -  1414.56 us/run -  60.13 GFLOP/run -  42.51 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  650 runs -  1542.48 us/run -  60.13 GFLOP/run -  38.98 TFLOPS

The backend tests all passed.

@netrunnereve (Collaborator, Author)

> Perf is about neutral on average, maybe down a tiny bit?

Interesting. Let's wait for some more results.

@0cc4m (Collaborator) commented Feb 25, 2025

Here are my results (first number is master, second is this PR, third is the difference, all in TFLOPS):

Nvidia RTX 3090
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 3090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: KHR_coopmat
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: NVIDIA GeForce RTX 3090
  Device memory: 24576 MB (24576 MB free)


MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     23.12 TFLOPS       23.37 TFLOPS        0.25 TFLOPS
MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     23.37 TFLOPS       23.68 TFLOPS        0.31 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]): not supported
MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    46.43 TFLOPS       46.16 TFLOPS       -0.27 TFLOPS
MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    44.80 TFLOPS       44.17 TFLOPS       -0.63 TFLOPS
MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    41.93 TFLOPS       46.39 TFLOPS        4.46 TFLOPS
MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    42.65 TFLOPS       46.68 TFLOPS        4.03 TFLOPS
MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    40.53 TFLOPS       46.34 TFLOPS        5.81 TFLOPS
MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    42.55 TFLOPS       42.92 TFLOPS        0.37 TFLOPS
MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    33.67 TFLOPS       34.42 TFLOPS        0.75 TFLOPS
MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    34.69 TFLOPS       35.68 TFLOPS        0.99 TFLOPS
MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    28.03 TFLOPS       28.10 TFLOPS        0.07 TFLOPS
MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    34.40 TFLOPS       34.75 TFLOPS        0.35 TFLOPS
MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 26.45 TFLOPS       28.34 TFLOPS        1.89 TFLOPS
MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  29.01 TFLOPS       27.52 TFLOPS       -1.49 TFLOPS
MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   26.05 TFLOPS       25.80 TFLOPS       -0.25 TFLOPS
MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 25.85 TFLOPS       28.14 TFLOPS        2.29 TFLOPS
MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   32.20 TFLOPS       27.11 TFLOPS       -5.09 TFLOPS
MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   22.22 TFLOPS       23.41 TFLOPS        1.19 TFLOPS
MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  41.81 TFLOPS       29.14 TFLOPS      -12.67 TFLOPS
MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   27.23 TFLOPS       26.45 TFLOPS       -0.78 TFLOPS
MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  35.02 TFLOPS       30.61 TFLOPS       -4.41 TFLOPS
AMD Radeon Pro VII
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon (TM) Pro VII (RADV VEGA20) (radv) | uma: 0 | fp16: 1 | warp size: 64 | shared memory: 65536 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: AMD Radeon (TM) Pro VII (RADV VEGA20)
  Device memory: 16368 MB (16368 MB free)


MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                      4.96 TFLOPS        4.93 TFLOPS       -0.03 TFLOPS
MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                      4.21 TFLOPS        4.19 TFLOPS       -0.02 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]): not supported
MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     4.15 TFLOPS        4.32 TFLOPS        0.17 TFLOPS
MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     4.16 TFLOPS        4.32 TFLOPS        0.16 TFLOPS
MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.95 TFLOPS        4.05 TFLOPS        0.10 TFLOPS
MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     4.01 TFLOPS        4.10 TFLOPS        0.09 TFLOPS
MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     4.11 TFLOPS        4.18 TFLOPS        0.07 TFLOPS
MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.92 TFLOPS        3.91 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.61 TFLOPS        3.60 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.59 TFLOPS        3.60 TFLOPS        0.01 TFLOPS
MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.44 TFLOPS        3.43 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     3.62 TFLOPS        3.60 TFLOPS       -0.02 TFLOPS
MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  3.67 TFLOPS        3.65 TFLOPS       -0.02 TFLOPS
MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   3.74 TFLOPS        3.71 TFLOPS       -0.03 TFLOPS
MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    3.07 TFLOPS        3.02 TFLOPS       -0.05 TFLOPS
MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  3.65 TFLOPS        3.67 TFLOPS        0.02 TFLOPS
MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    3.85 TFLOPS        3.87 TFLOPS        0.02 TFLOPS
MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    3.66 TFLOPS        3.66 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   4.14 TFLOPS        4.29 TFLOPS        0.15 TFLOPS
MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    3.71 TFLOPS        3.69 TFLOPS       -0.02 TFLOPS
MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   3.79 TFLOPS        3.79 TFLOPS        0.00 TFLOPS
Intel A770
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 65536 | matrix cores: none
Testing 2 devices

Backend 1/2: Vulkan0
  Device description: Intel(R) Arc(tm) A770 Graphics (DG2)
  Device memory: 16032 MB (16032 MB free)


MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                      1.52 TFLOPS        1.52 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                      1.44 TFLOPS        1.44 TFLOPS        0.00 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]): not supported
MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.22 TFLOPS        1.41 TFLOPS        0.19 TFLOPS
MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.18 TFLOPS        1.39 TFLOPS        0.21 TFLOPS
MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.14 TFLOPS        1.32 TFLOPS        0.18 TFLOPS
MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.14 TFLOPS        1.38 TFLOPS        0.24 TFLOPS
MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.14 TFLOPS        1.27 TFLOPS        0.13 TFLOPS
MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.34 TFLOPS        1.33 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.31 TFLOPS        1.32 TFLOPS        0.01 TFLOPS
MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.33 TFLOPS        1.33 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.32 TFLOPS        1.32 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                     1.30 TFLOPS        1.30 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  1.13 TFLOPS        1.12 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1.32 TFLOPS        1.31 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    1.17 TFLOPS        1.16 TFLOPS       -0.01 TFLOPS
MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  1.16 TFLOPS        1.18 TFLOPS        0.02 TFLOPS
MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    1.26 TFLOPS        1.26 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    1.33 TFLOPS        1.33 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1.19 TFLOPS        1.35 TFLOPS        0.16 TFLOPS
MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                    1.30 TFLOPS        1.30 TFLOPS        0.00 TFLOPS
MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1.32 TFLOPS        1.32 TFLOPS        0.00 TFLOPS

Looks like it's mostly perf-neutral on AMD and Intel, probably because they are compute-limited, with some minor improvements. But on the RTX 3090 it makes a much larger difference. Looks good for the legacy and k-quants, but the iq quants seem to regress. Any ideas?

const float d = float(data_a_packed16[ib].d);
const uint v0 = uint(data_a_packed16[ib].qs[2*iqs]);
const uint v1 = uint(data_a_packed16[ib].qs[2*iqs + 1]);
const vec4 v = vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)) * d;

Collaborator (inline review comment on the snippet above):
Why not just use data_a_packed32 here? Then you can directly get the vec4 from an unpack8.
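
For reference, a sketch of what that suggestion would look like, assuming a `data_a_packed32` view whose `qs[]` packs four int8 quants per 32-bit word and the `unpack8()` built-in from the explicit arithmetic types extension (this illustrates the proposal, not what the PR currently does):

```glsl
// Hypothetical packed32 variant: a single aligned 32-bit load yields four
// int8 quants, which unpack8() splits into an 8-bit vector in one step.
const float d  = float(data_a_packed32[ib].d);
const i8vec4 q = unpack8(int32_t(data_a_packed32[ib].qs[iqs]));
const vec4 v   = vec4(q) * d;
```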

netrunnereve (Collaborator, Author):

Since each block has a 16-bit delta, a q8_0 block takes up 34 bytes. That's not divisible by 4 bytes, so we'd end up with an unaligned 32-bit load that's slower than a 16-bit one (I think I tried that a long time ago when I did the inference optimizations).

Maybe there's a way to repack the blocks and stuff an extra 16 bits at the end to make them 36 bytes, but that would use more memory and it sounds like a lot of work.
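
To spell out the alignment issue (a sketch of the layout; the struct name follows the packed16 naming pattern of the existing shader types and is written from the description above):

```glsl
// q8_0 block as seen through the 16-bit packed view:
//   2 bytes : float16_t d  (the per-block scale, i.e. the 16-bit delta)
//  32 bytes : 32 x int8 quants, packed here as 16 x int16_t
// for a total of 34 bytes per block.
struct block_q8_0_packed16 {
    float16_t d;
    int16_t   qs[16];
};
// In a tightly packed array, block ib starts at byte offset 34*ib.
// Since 34 % 2 == 0, 16-bit loads of qs[] are always aligned, but
// 34 % 4 == 2, so for every other block a 32-bit load of qs[] would be
// misaligned (and correspondingly slow).
```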

Member:
You could repack the tensor so that the quants and the scales are stored separately; then it would be an aligned 32-byte load, plus a 2-byte load for the scale.
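
A sketch of that repacked layout (buffer names and binding numbers are made up for illustration):

```glsl
// Scales and quants kept in separate, individually aligned arrays instead
// of interleaved 34-byte blocks.
layout (binding = 0) readonly buffer Q8Scales { float16_t d[];  } scales; //  2 bytes per block
layout (binding = 1) readonly buffer Q8Quants { uint32_t  qs[]; } quants; // 32 bytes per block
// For block ib: the scale is scales.d[ib] and the 32 int8 quants live in
// quants.qs[8*ib .. 8*ib + 7], so every 32-bit quant load is aligned.
```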
