@@ -2713,75 +2713,231 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 4-bit (de)-quantization
 
-void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
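+// Pack the 6-bit sub-block scale (d) and min (m) for sub-block j into the 12-byte
+// q4_K scales[] array: for j < 4 the full 6 bits go into the low bits of bytes j
+// (scale) and j + 4 (min); for j >= 4 the low 4 bits of each go into byte j + 4 and
+// the high 2 bits spill into the top bits of bytes j - 4 and j. This is the inverse
+// of get_scale_min_k4() used on the dequantization side.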
+static inline void set_scale_min_k4(int j, uint8_t * GGML_RESTRICT q, uint8_t d, uint8_t m) {
+    assert(d < 64 && m < 64);
+    if (j < 4) {
+        q[j] = (q[j] & 0xC0) | (d & 0x3F);
+        q[j + 4] = (q[j + 4] & 0xC0) | (m & 0x3F);
+    } else {
+        const int j2 = j - 4;
+        q[j2] = (q[j2] & 0x3F) | ((d & 0x30) << 2);
+        q[j + 4] = (d & 0x0F) | ((m & 0x0F) << 4);
+        q[j] = (q[j] & 0x3F) | ((m & 0x30) << 2);
+    }
+}
+
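+// Quantize one row to Q4_K with an iterative (Lloyd-Max style) refinement: 4-bit codes
+// are assigned with the current parameters, then the 6-bit sub-block scales/mins and
+// the fp16 super-block factors are re-fitted by least squares. Values dequantize as
+// x ~= d*sc*q - dmin*m, matching dequantize_row_q4_K.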
+void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) {
     assert(k % QK_K == 0);
-    const int nb = k / QK_K;
+
+    const int max_iter = 10;
+    const float epsilon = 1e-6f;
 
-    uint8_t L[QK_K];
-    uint8_t Laux[32];
-    float weights[32];
-    float mins[QK_K/32];
-    float scales[QK_K/32];
+    const int nb = k / QK_K;
+    const int num_subblocks = QK_K / 32;
 
     for (int i = 0; i < nb; i++) {
-        float max_scale = 0; // as we are deducting the min, scales are always positive
-        float max_min = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
-            float sum_x2 = 0;
-            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
-            float av_x = sqrtf(sum_x2/32);
-            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
-            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
-            float scale = scales[j];
-            if (scale > max_scale) {
-                max_scale = scale;
+        memset(y[i].scales, 0, K_SCALE_SIZE);
+
+        float scales[num_subblocks];
+        float mins[num_subblocks];
+
+        // Initialization: compute initial scales and mins per sub-block
+        for (int j = 0; j < num_subblocks; j++) {
+            float xmin = x[i*QK_K + j*32];
+            float xmax = x[i*QK_K + j*32];
+
+            for (int l = 1; l < 32; l++) {
+                const float v = x[i*QK_K + j*32 + l];
+                xmin = v < xmin ? v : xmin;
+                xmax = v > xmax ? v : xmax;
+            }
+
+            // note: kimi-k2-thinking QAT-specific initialisation
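+            // With scale = max(|x|)/7 and min = -7*scale, codes 0..14 span a grid that
+            // is symmetric about zero, [-7*scale, +7*scale]; this is also why the
+            // assignment step below clips codes to 14 rather than 15.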
+            scales[j] = MAX(fabsf(xmin), fabsf(xmax)) / 7.0f;
+            mins[j] = -7.0f * scales[j];
+
+            if (scales[j] == 0.0f) scales[j] = 1.0f;
+        }
+
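+        // The 6-bit sub-block scales/mins are stored relative to the fp16 super-block
+        // factors d and dmin, so seed those with max(scales)/63 and max(|mins|)/63,
+        // putting the largest sub-block at the top of the 6-bit range.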
+        // Initialize super-block scales
+        float d = 0.0f;
+        float dmin_abs = 0.0f;
+        for (int j = 0; j < num_subblocks; j++) {
+            d = scales[j] > d ? scales[j] : d;
+            const float mins_abs = fabsf(mins[j]);
+            dmin_abs = mins_abs > dmin_abs ? mins_abs : dmin_abs;
+        }
+        d = d / 63.0f;
+        float dmin = dmin_abs / 63.0f;
+        if (d == 0.0f) d = 1.0f;
+        if (dmin == 0.0f) dmin = 1.0f;
+
+        // Quantize initial sub-block scales and mins
+        uint8_t sc[num_subblocks];
+        uint8_t m[num_subblocks];
+        for (int j = 0; j < num_subblocks; j++) {
+            sc[j] = (uint8_t)(nearest_int(scales[j] / d));
+            sc[j] = sc[j] > 63 ? 63 : sc[j];
+
+            const int m_int = nearest_int(mins[j] / dmin);
+            m[j] = (uint8_t)(m_int < 0 ? -m_int : m_int);
+            m[j] = m[j] > 63 ? 63 : m[j];
+
+            set_scale_min_k4(j, y[i].scales, sc[j], m[j]);
+        }
+
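+        // Dequantization subtracts dmin*m (with m >= 0), so when the sub-block minima
+        // are mostly positive the effective offset -dmin*m must be positive as well,
+        // i.e. dmin has to be negative; flip its sign in that case.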
+        // Adjust dmin sign based on typical min values
+        float avg_min = 0.0f;
+        for (int j = 0; j < num_subblocks; j++) avg_min += mins[j];
+        avg_min /= num_subblocks;
+        if (avg_min > 0.0f) dmin = -dmin;
+
+        // Temporary storage for 4-bit codes
+        uint8_t q[QK_K];
+
+        // Lloyd-Max iteration
+        for (int iter = 0; iter < max_iter; iter++) {
+            const float d_old = d;
+            const float dmin_old = dmin;
+
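+            // With the current parameters each value reconstructs as
+            //   x_hat = (d*sc[j])*q - dmin*m[j] = scale*q + offset,
+            // so the best code for x is round((x - offset)/scale).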
+            // Step 1: Assignment - quantize to 4-bit codes
+            for (int j = 0; j < num_subblocks; j++) {
+                const float scale = d * sc[j];
+                const float offset = -dmin * m[j];
+
+                if (scale == 0.0f) {
+                    for (int l = 0; l < 32; ++l) {
+                        q[j*32 + l] = 0;
+                    }
+                    continue;
+                }
+
+                for (int l = 0; l < 32; l++) {
+                    const float v = x[i*QK_K + j*32 + l];
+                    const int q_int = nearest_int((v - offset) / scale);
+
+                    // note: kimi-k2-thinking QAT-specific clipping
+                    q[j*32 + l] = (uint8_t)(q_int < 0 ? 0 : (q_int > 14 ? 14 : q_int));
+                }
             }
-            float min = mins[j];
-            if (min > max_min) {
-                max_min = min;
+
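+            // Fit x ~= a*q + b over the 32 values of each sub-block by ordinary least
+            // squares, then re-quantize slope and intercept against the fixed
+            // super-block factors: sc[j] = round(a/d), m[j] = round(-b/dmin).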
+            // Step 2: Update sub-block scales and mins (2D least squares per sub-block)
+            for (int j = 0; j < num_subblocks; j++) {
+                float sum_x = 0.0f;
+                float sum_q = 0.0f;
+                float sum_xq = 0.0f;
+                float sum_qq = 0.0f;
+
+                for (int l = 0; l < 32; l++) {
+                    const float xv = x[i*QK_K + j*32 + l];
+                    const float qv = (float)q[j*32 + l];
+                    sum_x += xv;
+                    sum_q += qv;
+                    sum_xq += xv * qv;
+                    sum_qq += qv * qv;
+                }
+
+                const float n = 32.0f;
+                const float det = n * sum_qq - sum_q * sum_q;
+
+                if (det > 0.0f) {
+                    const float a = (n * sum_xq - sum_x * sum_q) / det;
+                    const float b = (sum_x - a * sum_q) / n;
+
+                    if (a > 0.0f && d > 0.0f) {
+                        const int sc_new = nearest_int(a / d);
+                        sc[j] = (uint8_t)(sc_new < 0 ? 0 : (sc_new > 63 ? 63 : sc_new));
+                    }
+
+                    if (dmin != 0.0f) {
+                        const int m_new = nearest_int(-b / dmin);
+                        m[j] = (uint8_t)(m_new < 0 ? 0 : (m_new > 63 ? 63 : m_new));
+                    }
+
+                    set_scale_min_k4(j, y[i].scales, sc[j], m[j]);
+                }
             }
-        }
-
-        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
-        float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
-        for (int j = 0; j < QK_K/32; ++j) {
-            uint8_t ls = nearest_int(inv_scale*scales[j]);
-            uint8_t lm = nearest_int(inv_min*mins[j]);
-            ls = MIN(63, ls);
-            lm = MIN(63, lm);
-            if (j < 4) {
-                y[i].scales[j] = ls;
-                y[i].scales[j+4] = lm;
-            } else {
-                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
-                y[i].scales[j-4] |= ((ls >> 4) << 6);
-                y[i].scales[j-0] |= ((lm >> 4) << 6);
+
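+            // Holding the 6-bit sc[j] and m[j] fixed, refit d and dmin by least squares
+            // on x ~= d*(sc[j]*q) - dmin*m[j] over the whole super-block. The normal
+            // equations are
+            //   A*d - B*dmin = X_d
+            //   B*d - C*dmin = X_m
+            // with the accumulators below; the 2x2 system is solved by Cramer's rule.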
+            // Step 3: Update super-block scales (2D least squares across all sub-blocks)
+            float A = 0.0f;   // Σ(sc*q)²
+            float B = 0.0f;   // Σ(m*sc*q)
+            float C = 0.0f;   // Σm²
+            float X_d = 0.0f; // Σ(x*sc*q)
+            float X_m = 0.0f; // Σ(x*m)
+
+            for (int j = 0; j < num_subblocks; j++) {
+                float sum_sq = 0.0f;
+                float sum_q = 0.0f;
+                float sum_xq = 0.0f;
+                float sum_x = 0.0f;
+
+                for (int l = 0; l < 32; l++) {
+                    const float xv = x[i*QK_K + j*32 + l];
+                    const float qv = (float)q[j*32 + l];
+                    sum_sq += qv * qv;
+                    sum_q += qv;
+                    sum_xq += xv * qv;
+                    sum_x += xv;
+                }
+
+                const float sc_f = (float)sc[j];
+                const float m_f = (float)m[j];
+
+                A += sc_f * sc_f * sum_sq;
+                B += m_f * sc_f * sum_q;
+                C += m_f * m_f * 32.0f;
+                X_d += sc_f * sum_xq;
+                X_m += m_f * sum_x;
+            }
+
+            const float det = A * C - B * B;
+
+            if (det > 0.0f) {
+                const float d_new = (C * X_d - B * X_m) / det;
+                const float dmin_new = (B * X_d - A * X_m) / det;
+
+                if (d_new > 0.0f) {
+                    d = d_new;
+                }
+                if (dmin_new != 0.0f) {
+                    dmin = dmin_new;
+                }
+            }
+
+            // Check convergence
+            const float delta_d = fabsf(d - d_old);
+            const float delta_dmin = fabsf(dmin - dmin_old);
+
+            if (delta_d < epsilon && delta_dmin < epsilon) {
+                break;
             }
         }
-        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
-        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        // Final assignment with converged parameters
+        for (int j = 0; j < num_subblocks; j++) {
+            const float scale = d * sc[j];
+            const float offset = -dmin * m[j];
+
+            for (int l = 0; l < 32; l++) {
+                const float v = x[i*QK_K + j*32 + l];
+                const int q_int = scale != 0.0f ? nearest_int((v - offset) / scale) : 0;
 
-        uint8_t sc, m;
-        for (int j = 0; j < QK_K/32; ++j) {
-            get_scale_min_k4(j, y[i].scales, &sc, &m);
-            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
-            if (!d) continue;
-            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(15, l));
-                L[32*j + ii] = l;
+                // note: kimi-k2-thinking QAT-specific clipping
+                q[j*32 + l] = (uint8_t)(q_int < 0 ? 0 : (q_int > 14 ? 14 : q_int));
             }
         }
-
-        uint8_t * q = y[i].qs;
-        for (int j = 0; j < QK_K; j += 64) {
-            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
-            q += 32;
+
+        // Store final super-block scales
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].dmin = GGML_FP32_TO_FP16(dmin);
+
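+        // Each group of 64 codes shares 32 bytes: codes 0..31 go into the low nibbles
+        // and codes 32..63 into the high nibbles of the same bytes.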
+        // Pack 4-bit quantized values (layout expected by dequant)
+        uint8_t *qs = y[i].qs;
+        for (int base = 0, out = 0; base < QK_K; base += 64, out += 32) {
+            for (int l = 0; l < 32; ++l) {
+                qs[out + l] = (q[base + l] & 0x0F) | ((q[base + 32 + l] & 0x0F) << 4);
+            }
         }
-
-        x += QK_K;
     }
 }
 