Skip to content

Commit 36ad3cb

Browse files
committed
Special q4_K for Kimi-K2-Thinking - ik_llama.cpp patch
ik_llama.cpp patched with ggml-org/llama.cpp#17064 (comment)
1 parent b6cc848 commit 36ad3cb

File tree

1 file changed

+213
-57
lines changed

1 file changed

+213
-57
lines changed

ggml/src/ggml-quants.c

Lines changed: 213 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,75 +2713,231 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
27132713

27142714
// ====================== 4-bit (de)-quantization
27152715

2716-
void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
2716+
static inline void set_scale_min_k4(int j, uint8_t * GGML_RESTRICT q, uint8_t d, uint8_t m) {
2717+
assert(d < 64 && m < 64);
2718+
if (j < 4) {
2719+
q[j] = (q[j] & 0xC0) | (d & 0x3F);
2720+
q[j + 4] = (q[j + 4] & 0xC0) | (m & 0x3F);
2721+
} else {
2722+
const int j2 = j - 4;
2723+
q[j2] = (q[j2] & 0x3F) | ((d & 0x30) << 2);
2724+
q[j + 4] = (d & 0x0F) | ((m & 0x0F) << 4);
2725+
q[j] = (q[j] & 0x3F) | ((m & 0x30) << 2);
2726+
}
2727+
}
2728+
2729+
void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k) {
27172730
assert(k % QK_K == 0);
2718-
const int nb = k / QK_K;
2731+
2732+
const int max_iter = 10;
2733+
const float epsilon = 1e-6f;
27192734

2720-
uint8_t L[QK_K];
2721-
uint8_t Laux[32];
2722-
float weights[32];
2723-
float mins[QK_K/32];
2724-
float scales[QK_K/32];
2735+
const int nb = k / QK_K;
2736+
const int num_subblocks = QK_K / 32;
27252737

27262738
for (int i = 0; i < nb; i++) {
2727-
float max_scale = 0; // as we are deducting the min, scales are always positive
2728-
float max_min = 0;
2729-
for (int j = 0; j < QK_K/32; ++j) {
2730-
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
2731-
float sum_x2 = 0;
2732-
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
2733-
float av_x = sqrtf(sum_x2/32);
2734-
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2735-
scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
2736-
float scale = scales[j];
2737-
if (scale > max_scale) {
2738-
max_scale = scale;
2739+
memset(y[i].scales, 0, K_SCALE_SIZE);
2740+
2741+
float scales[num_subblocks];
2742+
float mins[num_subblocks];
2743+
2744+
// Initialization: compute initial scales and mins per sub-block
2745+
for (int j = 0; j < num_subblocks; j++) {
2746+
float xmin = x[i*QK_K + j*32];
2747+
float xmax = x[i*QK_K + j*32];
2748+
2749+
for (int l = 1; l < 32; l++) {
2750+
const float v = x[i*QK_K + j*32 + l];
2751+
xmin = v < xmin ? v : xmin;
2752+
xmax = v > xmax ? v : xmax;
2753+
}
2754+
2755+
// note: kimi-k2-thinking QAT-specific initialisation
2756+
scales[j] = MAX(fabsf(xmin), fabsf(xmax)) / 7.0f;
2757+
mins[j] = -7.0f * scales[j];
2758+
2759+
if (scales[j] == 0.0f) scales[j] = 1.0f;
2760+
}
2761+
2762+
// Initialize super-block scales
2763+
float d = 0.0f;
2764+
float dmin_abs = 0.0f;
2765+
for (int j = 0; j < num_subblocks; j++) {
2766+
d = scales[j] > d ? scales[j] : d;
2767+
const float mins_abs = fabsf(mins[j]);
2768+
dmin_abs = mins_abs > dmin_abs ? mins_abs : dmin_abs;
2769+
}
2770+
d = d / 63.0f;
2771+
float dmin = dmin_abs / 63.0f;
2772+
if (d == 0.0f) d = 1.0f;
2773+
if (dmin == 0.0f) dmin = 1.0f;
2774+
2775+
// Quantize initial sub-block scales and mins
2776+
uint8_t sc[num_subblocks];
2777+
uint8_t m[num_subblocks];
2778+
for (int j = 0; j < num_subblocks; j++) {
2779+
sc[j] = (uint8_t)(nearest_int(scales[j] / d));
2780+
sc[j] = sc[j] > 63 ? 63 : sc[j];
2781+
2782+
const int m_int = nearest_int(mins[j] / dmin);
2783+
m[j] = (uint8_t)(m_int < 0 ? -m_int : m_int);
2784+
m[j] = m[j] > 63 ? 63 : m[j];
2785+
2786+
set_scale_min_k4(j, y[i].scales, sc[j], m[j]);
2787+
}
2788+
2789+
// Adjust dmin sign based on typical min values
2790+
float avg_min = 0.0f;
2791+
for (int j = 0; j < num_subblocks; j++) avg_min += mins[j];
2792+
avg_min /= num_subblocks;
2793+
if (avg_min > 0.0f) dmin = -dmin;
2794+
2795+
// Temporary storage for 4-bit codes
2796+
uint8_t q[QK_K];
2797+
2798+
// Lloyd-Max iteration
2799+
for (int iter = 0; iter < max_iter; iter++) {
2800+
const float d_old = d;
2801+
const float dmin_old = dmin;
2802+
2803+
// Step 1: Assignment - quantize to 4-bit codes
2804+
for (int j = 0; j < num_subblocks; j++) {
2805+
const float scale = d * sc[j];
2806+
const float offset = -dmin * m[j];
2807+
2808+
if (scale == 0.0f) {
2809+
for (int l = 0; l < 32; ++l) {
2810+
q[j*32 + l] = 0;
2811+
}
2812+
continue;
2813+
}
2814+
2815+
for (int l = 0; l < 32; l++) {
2816+
const float v = x[i*QK_K + j*32 + l];
2817+
const int q_int = nearest_int((v - offset) / scale);
2818+
2819+
// note: kimi-k2-thinking QAT-specific clipping
2820+
q[j*32 + l] = (uint8_t)(q_int < 0 ? 0 : (q_int > 14 ? 14 : q_int));
2821+
}
27392822
}
2740-
float min = mins[j];
2741-
if (min > max_min) {
2742-
max_min = min;
2823+
2824+
// Step 2: Update sub-block scales and mins (2D least squares per sub-block)
2825+
for (int j = 0; j < num_subblocks; j++) {
2826+
float sum_x = 0.0f;
2827+
float sum_q = 0.0f;
2828+
float sum_xq = 0.0f;
2829+
float sum_qq = 0.0f;
2830+
2831+
for (int l = 0; l < 32; l++) {
2832+
const float xv = x[i*QK_K + j*32 + l];
2833+
const float qv = (float)q[j*32 + l];
2834+
sum_x += xv;
2835+
sum_q += qv;
2836+
sum_xq += xv * qv;
2837+
sum_qq += qv * qv;
2838+
}
2839+
2840+
const float n = 32.0f;
2841+
const float det = n * sum_qq - sum_q * sum_q;
2842+
2843+
if (det > 0.0f) {
2844+
const float a = (n * sum_xq - sum_x * sum_q) / det;
2845+
const float b = (sum_x - a * sum_q) / n;
2846+
2847+
if (a > 0.0f && d > 0.0f) {
2848+
const int sc_new = nearest_int(a / d);
2849+
sc[j] = (uint8_t)(sc_new < 0 ? 0 : (sc_new > 63 ? 63 : sc_new));
2850+
}
2851+
2852+
if (dmin != 0.0f) {
2853+
const int m_new = nearest_int(-b / dmin);
2854+
m[j] = (uint8_t)(m_new < 0 ? 0 : (m_new > 63 ? 63 : m_new));
2855+
}
2856+
2857+
set_scale_min_k4(j, y[i].scales, sc[j], m[j]);
2858+
}
27432859
}
2744-
}
2745-
2746-
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2747-
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2748-
for (int j = 0; j < QK_K/32; ++j) {
2749-
uint8_t ls = nearest_int(inv_scale*scales[j]);
2750-
uint8_t lm = nearest_int(inv_min*mins[j]);
2751-
ls = MIN(63, ls);
2752-
lm = MIN(63, lm);
2753-
if (j < 4) {
2754-
y[i].scales[j] = ls;
2755-
y[i].scales[j+4] = lm;
2756-
} else {
2757-
y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
2758-
y[i].scales[j-4] |= ((ls >> 4) << 6);
2759-
y[i].scales[j-0] |= ((lm >> 4) << 6);
2860+
2861+
// Step 3: Update super-block scales (2D least squares across all sub-blocks)
2862+
float A = 0.0f; // Σ(sc*q)²
2863+
float B = 0.0f; // Σ(m*sc*q)
2864+
float C = 0.0f; // Σm²
2865+
float X_d = 0.0f; // Σ(x*sc*q)
2866+
float X_m = 0.0f; // Σ(x*m)
2867+
2868+
for (int j = 0; j < num_subblocks; j++) {
2869+
float sum_sq = 0.0f;
2870+
float sum_q = 0.0f;
2871+
float sum_xq = 0.0f;
2872+
float sum_x = 0.0f;
2873+
2874+
for (int l = 0; l < 32; l++) {
2875+
const float xv = x[i*QK_K + j*32 + l];
2876+
const float qv = (float)q[j*32 + l];
2877+
sum_sq += qv * qv;
2878+
sum_q += qv;
2879+
sum_xq += xv * qv;
2880+
sum_x += xv;
2881+
}
2882+
2883+
const float sc_f = (float)sc[j];
2884+
const float m_f = (float)m[j];
2885+
2886+
A += sc_f * sc_f * sum_sq;
2887+
B += m_f * sc_f * sum_q;
2888+
C += m_f * m_f * 32.0f;
2889+
X_d += sc_f * sum_xq;
2890+
X_m += m_f * sum_x;
2891+
}
2892+
2893+
const float det = A * C - B * B;
2894+
2895+
if (det > 0.0f) {
2896+
const float d_new = (C * X_d - B * X_m) / det;
2897+
const float dmin_new = (B * X_d - A * X_m) / det;
2898+
2899+
if (d_new > 0.0f) {
2900+
d = d_new;
2901+
}
2902+
if (dmin_new != 0.0f) {
2903+
dmin = dmin_new;
2904+
}
2905+
}
2906+
2907+
// Check convergence
2908+
const float delta_d = fabsf(d - d_old);
2909+
const float delta_dmin = fabsf(dmin - dmin_old);
2910+
2911+
if (delta_d < epsilon && delta_dmin < epsilon) {
2912+
break;
27602913
}
27612914
}
2762-
y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2763-
y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2915+
2916+
// Final assignment with converged parameters
2917+
for (int j = 0; j < num_subblocks; j++) {
2918+
const float scale = d * sc[j];
2919+
const float offset = -dmin * m[j];
2920+
2921+
for (int l = 0; l < 32; l++) {
2922+
const float v = x[i*QK_K + j*32 + l];
2923+
const int q_int = scale != 0.0f ? nearest_int((v - offset) / scale) : 0;
27642924

2765-
uint8_t sc, m;
2766-
for (int j = 0; j < QK_K/32; ++j) {
2767-
get_scale_min_k4(j, y[i].scales, &sc, &m);
2768-
const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
2769-
if (!d) continue;
2770-
const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
2771-
for (int ii = 0; ii < 32; ++ii) {
2772-
int l = nearest_int((x[32*j + ii] + dm)/d);
2773-
l = MAX(0, MIN(15, l));
2774-
L[32*j + ii] = l;
2925+
// note: kimi-k2-thinking QAT-specific clipping
2926+
q[j*32 + l] = (uint8_t)(q_int < 0 ? 0 : (q_int > 14 ? 14 : q_int));
27752927
}
27762928
}
2777-
2778-
uint8_t * q = y[i].qs;
2779-
for (int j = 0; j < QK_K; j += 64) {
2780-
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
2781-
q += 32;
2929+
2930+
// Store final super-block scales
2931+
y[i].d = GGML_FP32_TO_FP16(d);
2932+
y[i].dmin = GGML_FP32_TO_FP16(dmin);
2933+
2934+
// Pack 4-bit quantized values (layout expected by dequant)
2935+
uint8_t *qs = y[i].qs;
2936+
for (int base = 0, out = 0; base < QK_K; base += 64, out += 32) {
2937+
for (int l = 0; l < 32; ++l) {
2938+
qs[out + l] = (q[base + l] & 0x0F) | ((q[base + 32 + l] & 0x0F) << 4);
2939+
}
27822940
}
2783-
2784-
x += QK_K;
27852941
}
27862942
}
27872943

0 commit comments

Comments
 (0)