diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
index 10318e8766074..8835c442ecfd8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
@@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
 }
 vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
-    uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
-    uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
-    return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
+    const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
+    return vec4(v0.x, v0.y, v1.x, v1.y);
 }
 #endif
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
index 4770469eddcab..4ccbe613af2ce 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
@@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
     const uint iqs = idx;
 
     // Load 16b and select the byte for this element
-    int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
+    int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
     float16_t ret = float16_t(qs) * d;
     return ret;
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 39657195cfc8d..a8fd93fdeadee 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -32,6 +32,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
@@ -243,74 +250,100 @@ void main() {
 #endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
+
+            const float d = float(data_a_packed16[ib].d);
+            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+            const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
+            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
 #elif defined(DATA_A_Q4_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
+
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
+
+            const float d = float(data_a_packed16[ib].d);
+            const float m = float(data_a_packed16[ib].m);
+            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
+            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
 #elif defined(DATA_A_Q5_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+            const float d = float(data_a_packed16[ib].d);
+            const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
+            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+            const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
 
             buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q5_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint uint_qh = data_a[ib].qh;
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+            const float d = float(data_a_packed16[ib].d);
+            const float m = float(data_a_packed16[ib].m);
+            const uint uint_qh = data_a_packed16[ib].qh;
+            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
 
             buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q8_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 16;
-            const uint iqs = (idx & 0xF) * 2;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+            const float d = float(data_a_packed16[ib].d);
+            const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
+            const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
+            const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
 
             buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
 #elif defined(DATA_A_Q2_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -623,17 +656,18 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
 
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x07;
 
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
+            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
 
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
+            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
+            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
+            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
 #endif
         }
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
index dfa16cda516bd..907067d7fa56e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
@@ -139,7 +139,7 @@ struct block_q8_0
 struct block_q8_0_packed16
 {
     float16_t d;
-    uint16_t qs[32/2];
+    int16_t qs[32/2];
 };
 
 #if defined(DATA_A_Q8_0)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index c5e0bba82b260..4a81505565e89 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
 
     for (const auto& tname : type_names) {
+        std::string load_vec_quant = "2";
+        if ((tname == "q4_0") || (tname == "q4_1"))
+            load_vec_quant = "8";
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+            load_vec_quant = "4";
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {