Skip to content

Commit 47a268e

Browse files
authored
Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900)
* vulkan: split mul_mmq_funcs for mul_mat_vecq use
* add mxfp4 mmvq
* add q2_k mmvq
* add q3_k mmvq
* add q4_k and q5_k mmvq
* add q6_k mmvq
* handle 4x4 quants per mmvq thread
* enable MUL_MAT_ID mmvq support
* enable subgroup optimizations for mul_mat_vec_id shaders
* device tuning
* request prealloc_y sync after quantization
* fix indentation
* fix llvmpipe test failures
* fix mul_mat_id mmvq condition
* fix unused variable warning
1 parent 59d8d4e commit 47a268e

12 files changed

+680
-286
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 190 additions & 54 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44

55
#include "types.glsl"
66

7-
#if defined(A_TYPE_PACKED16)
8-
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
9-
#endif
10-
#if defined(A_TYPE_PACKED32)
11-
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
12-
#endif
13-
147
#if defined(DATA_A_F32)
158
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
169
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
2222

2323
#if !RMS_NORM_ROPE_FUSION
2424
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
25+
#if defined(A_TYPE_PACKED16)
26+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
27+
#endif
28+
#if defined(A_TYPE_PACKED32)
29+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
30+
#endif
31+
2532
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
2633
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
2734
#endif

ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
1818
} p;
1919

2020
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
21+
#if defined(A_TYPE_PACKED16)
22+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
23+
#endif
24+
#if defined(A_TYPE_PACKED32)
25+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
26+
#endif
27+
2128
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
2229

2330
uint get_idx() {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
44

55
#include "mul_mat_vec_base.glsl"
6+
#include "dequant_funcs.glsl"
67

78
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
89

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
#include "mul_mat_vec_iface.glsl"
1515

16-
#include "dequant_funcs.glsl"
17-
1816
layout (push_constant) uniform parameter
1917
{
2018
uint ncols;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
66
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
77

8-
#ifndef MMQ
98
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
109
#if defined(A_TYPE_VEC4)
1110
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
1211
#endif
13-
#else
14-
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
12+
#if defined(A_TYPE_PACKED16)
13+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
14+
#endif
15+
#if defined(A_TYPE_PACKED32)
16+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
1517
#endif
1618

1719
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,60 +10,56 @@
1010

1111
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1212

13+
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
1314
#define K_PER_ITER 8
14-
15-
#include "mul_mmq_funcs.glsl"
15+
#elif defined(DATA_A_QUANT_K)
16+
#define K_PER_ITER 16
17+
#else
18+
#error unimplemented
19+
#endif
1620

1721
uint a_offset, b_offset, d_offset;
1822

19-
int32_t cache_b_qs[2];
23+
int32_t cache_b_qs[K_PER_ITER / 4];
2024
vec2 cache_b_ds;
2125

26+
#include "mul_mat_vecq_funcs.glsl"
27+
2228
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
2329
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
2430
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
2531

2632
// Preload data_b block
2733
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
28-
const uint b_qs_idx = tid % 4;
34+
const uint b_qs_idx = tid % (32 / K_PER_ITER);
2935
const uint b_block_idx_outer = b_block_idx / 4;
3036
const uint b_block_idx_inner = b_block_idx % 4;
3137
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
3238

3339
#if QUANT_R == 2
40+
// Assumes K_PER_ITER == 8
3441
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
3542
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
3643
#else
44+
#if K_PER_ITER == 8
3745
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
3846
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
47+
#elif K_PER_ITER == 16
48+
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
49+
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
50+
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
51+
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
52+
#else
53+
#error unimplemented
54+
#endif
3955
#endif
4056

4157
uint ibi = first_row*p.ncols;
4258
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
43-
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
59+
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
4460
ibi += p.ncols;
4561

46-
int32_t q_sum = 0;
47-
#if QUANT_R == 2
48-
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
49-
q_sum += dotPacked4x8EXT(data_a_qs.x,
50-
cache_b_qs[0]);
51-
q_sum += dotPacked4x8EXT(data_a_qs.y,
52-
cache_b_qs[1]);
53-
#else
54-
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
55-
q_sum += dotPacked4x8EXT(data_a_qs,
56-
cache_b_qs[0]);
57-
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
58-
q_sum += dotPacked4x8EXT(data_a_qs,
59-
cache_b_qs[1]);
60-
#endif
61-
62-
#if QUANT_AUXF == 1
63-
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
64-
#else
65-
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
66-
#endif
62+
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
6763
}
6864
}
6965
}
@@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
7268
const uint tid = gl_LocalInvocationID.x;
7369

7470
get_offsets(a_offset, b_offset, d_offset);
75-
a_offset /= QUANT_K;
71+
a_offset /= QUANT_K_Q8_1;
7672
b_offset /= QUANT_K_Q8_1;
7773

7874
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
10298
unroll_count = 2;
10399
unrolled_iters = num_iters & ~(unroll_count - 1);
104100

105-
#if K_PER_ITER == 2
106-
if ((p.ncols & 1) != 0 &&
107-
unrolled_iters == num_iters &&
108-
unrolled_iters > 0) {
109-
unrolled_iters -= unroll_count;
110-
}
111-
#endif
112-
113101
while (i < unrolled_iters) {
114102
// Manually partially unroll the loop
115103
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
128116
void main() {
129117
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
130118

119+
#ifdef NEEDS_INIT_IQ_SHMEM
120+
init_iq_shmem(gl_WorkGroupSize);
121+
#endif
122+
131123
// do NUM_ROWS at a time, unless there aren't enough remaining rows
132124
if (first_row + NUM_ROWS <= p.stride_d) {
133125
compute_outputs(first_row, NUM_ROWS);

0 commit comments

Comments
 (0)