Skip to content

Commit 2b87602

Browse files
ddh0 and Nexesenex
authored and committed
support GLM-4.5 MoE models ggml-org#15026
initial PR commit; add GGUF constants; initial GLM-4.5 integration; fix typo `LLM_ATCH_GLM4_MOE` --> `LLM_ARCH_GLM4_MOE`; add glm4_moe tensor mapping; add `attn_k_norm` and `attn_q_norm` tensors for GLM-4.5; more consistent organization; more consistent organization (cont.); Merge branch 'ggml-org:master' into glm45; Merge branch 'ggml-org:master' into glm45
1 parent bfbddcd commit 2b87602

File tree

12 files changed

+371
-150
lines changed

12 files changed

+371
-150
lines changed

common/chat.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16461646
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
16471647
);
16481648

1649-
if (auto res = builder.try_find_regex(open_regex)) {
1649+
while (auto res = builder.try_find_regex(open_regex)) {
16501650
const auto & block_start = res->groups[1];
16511651
std::string block_end = block_start.empty() ? "" : "```";
16521652

@@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16681668
builder.consume_literal(block_end);
16691669
builder.consume_spaces();
16701670
}
1671-
builder.add_content(builder.consume_rest());
16721671
} else {
16731672
throw common_chat_msg_partial_exception("failed to parse tool call");
16741673
}
@@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16931692
builder.consume_spaces();
16941693
}
16951694
}
1696-
builder.add_content(builder.consume_rest());
16971695
}
1698-
} else {
1699-
builder.add_content(builder.consume_rest());
17001696
}
1697+
1698+
builder.add_content(builder.consume_rest());
17011699
}
17021700

17031701
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 233 additions & 94 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#version 450
22

3+
#extension GL_EXT_control_flow_attributes : enable
4+
35
#ifdef USE_COLLECTIVES
46
# extension GL_KHR_shader_subgroup_shuffle : enable
57
#endif
68

79
#include "types.comp"
810

9-
// Make spec constant
10-
#define SHMEM_PAD 0
11-
1211
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
1312
layout(binding = 0) readonly buffer A {
1413
A_TYPE knl_data[];
@@ -56,6 +55,12 @@ layout(push_constant) uniform parameter {
5655
uint32_t nb1;
5756
uint32_t nb2;
5857
uint32_t nb3;
58+
59+
// fastdiv helper values
60+
uint32_t KWmp; uint32_t KWL;
61+
uint32_t KWKHmp; uint32_t KWKHL;
62+
uint32_t OWmp; uint32_t OWL;
63+
uint32_t OWOHmp; uint32_t OWOHL;
5964
}
6065

6166
p;
@@ -68,6 +73,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
6873
// Thread-tile sizes
6974
layout(constant_id = 4) const uint TS_K = 8;
7075
layout(constant_id = 5) const uint use_collectives = 1;
76+
layout(constant_id = 6) const uint SHMEM_PAD = 4;
7177

7278
uint32_t tid = gl_LocalInvocationID.x;
7379
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
@@ -131,6 +137,14 @@ uint32_t Br = tid / BS_NPQ;
131137
uint32_t Bc = tid % BS_NPQ;
132138
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
133139

140+
// see init_fastdiv_values in ggml-vulkan.cpp
141+
uint fastdiv(uint n, uint mp, uint L) {
142+
uint msbs, lsbs;
143+
// msbs = mulhi(n, mp)
144+
umulExtended(n, mp, msbs, lsbs);
145+
return (msbs + n) >> L;
146+
}
147+
134148
void main() {
135149
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
136150
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -151,9 +165,9 @@ void main() {
151165
uint32_t cached_KW_idx;
152166
if (use_collectives == 1) {
153167
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
154-
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
168+
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
155169
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
156-
cached_KH_idx = cached_CRS_remainder / p.KW;
170+
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
157171
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
158172

159173
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@@ -162,16 +176,16 @@ void main() {
162176
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
163177
} else {
164178
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
165-
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
179+
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
166180
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
167-
KH_idx_a = CRS_remainder / p.KW;
181+
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
168182
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
169183
}
170184
#else
171185
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
172-
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
186+
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
173187
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
174-
KH_idx_a = CRS_remainder / p.KW;
188+
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
175189
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
176190
#endif
177191

@@ -188,13 +202,13 @@ void main() {
188202
Ash[B_ly * Ash_stride + B_lx] = val;
189203
}
190204
/* Load input to B_block: (BS_CRS x BS_NPQ) */
191-
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
205+
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
192206
uint32_t B_ly = r_offset + Br; /* Row index of B block */
193207
uint32_t B_lx = Bc;
194208
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
195-
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
209+
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
196210
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
197-
uint32_t OH_idx = NPQ_remainder / p.OW;
211+
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
198212
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
199213

200214
uint32_t CRS_idx_b;
@@ -209,16 +223,16 @@ void main() {
209223
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
210224
} else {
211225
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
212-
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
226+
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
213227
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
214-
KH_idx_b = CRS_remainder / p.KW;
228+
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
215229
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
216230
}
217231
#else
218232
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
219-
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
233+
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
220234
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
221-
KH_idx_b = CRS_remainder / p.KW;
235+
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
222236
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
223237
#endif
224238

@@ -233,32 +247,36 @@ void main() {
233247
Bsh[B_ly * Bsh_stride + B_lx] = val;
234248
}
235249
barrier();
236-
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
237-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
238-
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
239-
}
240-
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
241-
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
242-
}
243-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
250+
if (T_y * TS_K < K) {
251+
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
252+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253+
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
254+
}
244255
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
245-
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
256+
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
257+
}
258+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
259+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
260+
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
261+
}
246262
}
247263
}
248264
}
249265
barrier();
250266
}
251267
/* Save C* */
252-
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253-
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
254-
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
255-
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
256-
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
257-
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
258-
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
259-
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
260-
if (K_idx < K && NPQ_idx < NPQ) {
261-
dst_data[dst_idx] = regC[T_ly][T_lx];
268+
if (T_y * TS_K < K) {
269+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
270+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
271+
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
272+
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
273+
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
274+
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
275+
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
276+
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
277+
if (K_idx < K && NPQ_idx < NPQ) {
278+
dst_data[dst_idx] = regC[T_ly][T_lx];
279+
}
262280
}
263281
}
264282
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
2626
uint ne12;
2727
uint b_offset;
2828
uint d_offset;
29+
uint nb03;
30+
uint nb13;
31+
uint nb23;
2932
} p;
3033

3134
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -34,14 +37,15 @@ void main() {
3437
const uint tid = gl_LocalInvocationID.x;
3538
const uint row_x = gl_GlobalInvocationID.y;
3639
const uint channel = gl_GlobalInvocationID.z;
40+
const uint i3 = gl_WorkGroupID.x;
3741
const uint channel_x = channel / p.channel_x_divisor;
3842
const uint channel_y = channel % p.ne12;
3943

4044
const uint nrows_y = p.ncols_x;
4145
const uint nrows_dst = p.nrows_x;
4246
const uint row_dst = row_x;
4347

44-
const uint idst = channel*nrows_dst + row_dst;
48+
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
4549

4650
FLOAT_TYPE temp = 0.0f;
4751

@@ -58,8 +62,8 @@ void main() {
5862

5963
const uint row_y = col_x;
6064

61-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
62-
const uint iy = channel_y*p.channel_stride_y + row_y;
65+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
66+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
6367

6468
const vec4 av4 = vec4(data_a_v4[ix / 4]);
6569
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -74,8 +78,8 @@ void main() {
7478

7579
const uint row_y = col_x;
7680

77-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
78-
const uint iy = channel_y*p.channel_stride_y + row_y;
81+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
82+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
7983

8084
const vec4 av4 = vec4(data_a_v4[ix / 4]);
8185
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -91,8 +95,8 @@ void main() {
9195

9296
const uint row_y = col_x;
9397

94-
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
95-
const uint iy = channel_y*p.channel_stride_y + row_y;
98+
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
99+
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
96100

97101
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
98102

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,11 @@ void process_shaders() {
669669

670670
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
671671

672-
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
673-
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
672+
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
673+
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
674+
675+
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
676+
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
674677

675678
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
676679
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

gguf-py/gguf/constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ class MODEL_ARCH(IntEnum):
358358
DEEPSEEK2 = auto()
359359
CHATGLM = auto()
360360
GLM4 = auto()
361+
GLM4_MOE = auto()
361362
BITNET = auto()
362363
BITNET_25 = auto()
363364
T5 = auto()
@@ -680,6 +681,7 @@ class MODEL_TENSOR(IntEnum):
680681
MODEL_ARCH.DEEPSEEK2: "deepseek2",
681682
MODEL_ARCH.CHATGLM: "chatglm",
682683
MODEL_ARCH.GLM4: "glm4",
684+
MODEL_ARCH.GLM4_MOE: "glm4_moe",
683685
MODEL_ARCH.BITNET: "bitnet",
684686
MODEL_ARCH.BITNET_25: "bitnet-25",
685687
MODEL_ARCH.T5: "t5",
@@ -2127,6 +2129,29 @@ class MODEL_TENSOR(IntEnum):
21272129
MODEL_TENSOR.ATTN_POST_NORM,
21282130
MODEL_TENSOR.FFN_POST_NORM,
21292131
],
2132+
MODEL_ARCH.GLM4_MOE: [
2133+
MODEL_TENSOR.TOKEN_EMBD,
2134+
MODEL_TENSOR.OUTPUT_NORM,
2135+
MODEL_TENSOR.OUTPUT,
2136+
MODEL_TENSOR.ATTN_NORM,
2137+
MODEL_TENSOR.ATTN_K_NORM, # not always present
2138+
MODEL_TENSOR.ATTN_Q_NORM, # not always present
2139+
MODEL_TENSOR.ATTN_Q,
2140+
MODEL_TENSOR.ATTN_K,
2141+
MODEL_TENSOR.ATTN_V,
2142+
MODEL_TENSOR.ATTN_OUT,
2143+
MODEL_TENSOR.FFN_NORM,
2144+
MODEL_TENSOR.FFN_GATE,
2145+
MODEL_TENSOR.FFN_DOWN,
2146+
MODEL_TENSOR.FFN_UP,
2147+
MODEL_TENSOR.FFN_GATE_EXP,
2148+
MODEL_TENSOR.FFN_DOWN_EXP,
2149+
MODEL_TENSOR.FFN_UP_EXP,
2150+
MODEL_TENSOR.FFN_GATE_SHEXP,
2151+
MODEL_TENSOR.FFN_DOWN_SHEXP,
2152+
MODEL_TENSOR.FFN_UP_SHEXP,
2153+
MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers
2154+
],
21302155
MODEL_ARCH.BITNET: [
21312156
MODEL_TENSOR.ATTN_Q,
21322157
MODEL_TENSOR.ATTN_K,

src/llama-arch.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,33 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
13911391
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
13921392
},
13931393
},
1394+
{
1395+
LLM_ARCH_GLM4_MOE,
1396+
{
1397+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1398+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1399+
{ LLM_TENSOR_OUTPUT, "output" },
1400+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1401+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1402+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1403+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1404+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1405+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1406+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1407+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1408+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1409+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1410+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1411+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
1412+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1413+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1414+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1415+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1416+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1417+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1418+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1419+
},
1420+
},
13941421
{
13951422
LLM_ARCH_BITNET,
13961423
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum llm_arch {
6666
LLM_ARCH_DEEPSEEK2,
6767
LLM_ARCH_CHATGLM,
6868
LLM_ARCH_GLM4,
69+
LLM_ARCH_GLM4_MOE,
6970
LLM_ARCH_BITNET,
7071
LLM_ARCH_T5,
7172
LLM_ARCH_T5ENCODER,

src/llama-graph.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -768,8 +768,10 @@ ggml_tensor * llm_graph_context::build_ffn(
768768

769769
if (down) {
770770
cur = build_lora_mm(down, cur);
771-
if (arch == LLM_ARCH_GLM4) {
772-
// GLM4 seems to have numerical issues with half-precision accumulators
771+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
772+
// GLM4 FFNs seem to have numerical issues with half-precision accumulators
773+
// -- ref: https://github.com/ggml-org/llama.cpp/pull/13101
774+
// (GLM4_MOE uses some GLM4 FFNs, so we need to match it too)
773775
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
774776
}
775777
}
@@ -1524,8 +1526,10 @@ ggml_tensor * llm_graph_context::build_attn(
15241526

15251527
if (wo) {
15261528
cur = build_lora_mm(wo, cur);
1527-
if (arch == LLM_ARCH_GLM4) {
1528-
// GLM4 seems to have numerical issues with half-precision accumulators
1529+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1530+
// GLM4 FFNs seem to have numerical issues with half-precision accumulators
1531+
// -- ref: https://github.com/ggml-org/llama.cpp/pull/13101
1532+
// (GLM4_MOE uses some GLM4 FFNs, so we need to match it too)
15291533
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
15301534
}
15311535
}

0 commit comments

Comments
 (0)