@@ -653,11 +653,12 @@ ggml_tensor * llm_build_context::llm_build_ffn(
         auto split_u = u->splits[id];
         auto split_g = g->splits[id];
         auto split_d = d->splits[id];
-        GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d));
+        GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
         if (!split_u) continue;
         auto cur = input;
         if (ffn_norm && ffn_norm->extra) {
             auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
+            GGML_ASSERT(norm->splits[id]);
             cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_inp_normed", il_cb);
         }
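The corrected assertion encodes the invariant the rest of the hunk relies on: for a given device, either all three FFN slices (up, gate, down) are present or all are absent, and a device with no slice is skipped via continue. The pre-fix form was missing the negation on split_d, so a device holding only a down slice would have slipped through. A minimal standalone sketch of the check, with a stand-in struct assumed to mirror the two ggml_split_tensor_t fields the diff uses (not the fork's real definition):

    // Stand-in for ggml_split_tensor_t: one slice pointer per device,
    // null where the tensor has no slice on that device (assumed shape).
    struct split_view {
        int    n_device;
        void * splits[16];
    };

    // Either all three slices exist on device `id` or none do. A device
    // with no slice is then skipped by the caller: if (!split_u) continue;
    bool slices_consistent(const split_view & u, const split_view & g,
                           const split_view & d, int id) {
        const bool have_u = u.splits[id] != nullptr;
        const bool have_g = g.splits[id] != nullptr;
        const bool have_d = d.splits[id] != nullptr;
        return (have_u && have_g && have_d) || (!have_u && !have_g && !have_d);
    }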
@@ -1088,6 +1089,7 @@ llm_expert_gating_func_type gating_op,
     auto cur = input;
     if (ffn_norm) {
         auto the_ffn_norm = ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[lctx.model.main_gpu] : ffn_norm;
+        GGML_ASSERT(the_ffn_norm);
         cur = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
         cb(cur, "ffn_inp_normed", il);
     }
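The norm-handling hunks funnel through the same nested ternary before the new assert. Spelled out step by step, with stand-in types (assumed shapes, not the fork's real definitions):

    // Stand-ins for the fields the diff touches.
    struct ggml_tensor_like  { void * extra; };
    struct split_tensor_like { int n_device; ggml_tensor_like * splits[16]; };

    // Pick the norm weights for one device: a split tensor contributes its
    // per-device slice, an unsplit tensor is used as-is, a missing tensor
    // stays missing.
    ggml_tensor_like * resolve_ffn_norm(ggml_tensor_like * ffn_norm, int device_id) {
        if (ffn_norm == nullptr)        return nullptr;   // layer has no FFN norm
        if (ffn_norm->extra == nullptr) return ffn_norm;  // not split: use as-is
        auto * split = static_cast<split_tensor_like *>(ffn_norm->extra);
        return split->splits[device_id];                  // may be null; caller asserts
    }

The added GGML_ASSERT(the_ffn_norm) then guards the only case left that can produce a null result: a split tensor with no slice on the chosen device.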
@@ -1109,17 +1111,18 @@ llm_expert_gating_func_type gating_op,
             gating_op, cb, il, graph);
     cb(routed_out, "routed_out", il);
     ggml_build_forward_expand(graph, routed_out);
-    //printf("Using non-split llm_build_moe_ffn for layer %d. n_before = %d, n_now = %d\n", il, n_before, graph->n_nodes);
 
     if (up_shexp && gate_shexp && down_shexp) {
         if (split_up_shexp) {
-            //printf("Using split ffn for shared experts in layer %d\n", il);
-            std::vector<ggml_tensor *> results(split_up_shexp->n_device);
+            std::vector<ggml_tensor *> results; results.reserve(split_up_shexp->n_device);
             GGML_ASSERT(!split_up_b_shexp || split_up_b_shexp->n_device == split_up_shexp->n_device);
             GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device);
             GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device);
             for (int id = 0; id < split_up_shexp->n_device; ++id) {
                 int il_cb = 1000*id + il;
+                GGML_ASSERT((split_up_shexp->splits[id] && split_gate_shexp->splits[id] && split_down_shexp->splits[id]) ||
+                            (!split_up_shexp->splits[id] && !split_gate_shexp->splits[id] && !split_down_shexp->splits[id]));
+                if (!split_up_shexp->splits[id]) continue;
                 auto the_ffn_norm = ffn_norm ? ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[id] : ffn_norm : nullptr;
                 auto shared_out = llm_build_ffn(ctx, lctx, the_ffn_norm, input,
                         split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
@@ -1130,17 +1133,19 @@ llm_expert_gating_func_type gating_op,
                 if (shared_out->ne[1] > 32) {
                     shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
                 }
-                results[id] = shared_out;
+                results.push_back(shared_out);
             }
-            cur = ggml_add(ctx, results[0], results[1]);
-            if (cur->ne[1] > 32) {
-                // Force a graph split
+            GGML_ASSERT(!results.empty());
+            if (results.size() == 1) {
+                cur = results.front();
+            } else {
+                cur = ggml_add(ctx, results[0], results[1]);
                 cur->op_params[0] = 0xff;
-            }
-            cb(cur, "ffn_shared_combined", il);
-            for (int id = 2; id < int(results.size()); ++id) {
-                cur = ggml_add(ctx, cur, results[id]);
                 cb(cur, "ffn_shared_combined", il);
+                for (int id = 2; id < int(results.size()); ++id) {
+                    cur = ggml_add(ctx, cur, results[id]);
+                    cb(cur, "ffn_shared_combined", il);
+                }
             }
             if (routed_out->ne[1] > 32) {
                 auto routed_out_f16 = ggml_cast(ctx, routed_out, GGML_TYPE_F16);
@@ -1150,7 +1155,6 @@ llm_expert_gating_func_type gating_op,
             }
             cb(cur, "ffn_out", il);
         } else {
-            //printf("Using non-split ffn for shared experts in layer %d\n", il);
             auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
                     up_shexp, up_b_shexp, nullptr,
                     gate_shexp, gate_b_shexp, nullptr,
@@ -1170,14 +1174,17 @@ llm_expert_gating_func_type gating_op,
     }
     GGML_ASSERT(split_up_exps && split_gate_exps && split_down_exps);
     GGML_ASSERT(split_up_exps->n_device == split_gate_exps->n_device && split_up_exps->n_device == split_down_exps->n_device);
-    std::vector<ggml_tensor *> results(split_up_exps->n_device);
+    std::vector<ggml_tensor *> results; results.reserve(split_up_exps->n_device);
     GGML_ASSERT((!split_up_shexp && !split_gate_shexp && !split_down_shexp) ||
                 ( split_up_shexp && split_gate_shexp && split_down_shexp));
     auto split_gate_inp = (ggml_split_tensor_t *)gate_inp->extra;
     GGML_ASSERT(split_gate_inp && split_gate_inp->n_device == split_up_exps->n_device);
     auto split_exp_probs_b = exp_probs_b ? (ggml_split_tensor_t *)exp_probs_b->extra : nullptr;
     GGML_ASSERT(!split_exp_probs_b || split_exp_probs_b->n_device == split_up_exps->n_device);
     for (int id = 0; id < split_up_exps->n_device; ++id) {
+        GGML_ASSERT((split_up_exps->splits[id] && split_gate_exps->splits[id] && split_down_exps->splits[id]) ||
+                    (!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
+        if (!split_up_exps->splits[id]) continue;
         int il_cb = 1000*(id + 1) + il;
         auto cur = input;
         if (ffn_norm) {
@@ -1220,8 +1227,9 @@ llm_expert_gating_func_type gating_op,
             cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
             cb(cur, "ffn_out_f16", il_cb);
         }
-        results[id] = cur;
+        results.push_back(cur);
     }
+    GGML_ASSERT(!results.empty());
     if (results.size() == 1) return results.front();
 
     auto cur = ggml_add(ctx, results[0], results[1]);
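Shared experts, routed experts, and (further down) attention now all follow the same collection idiom: reserve() plus push_back() so skipped devices leave no null holes, an emptiness assert, a fast path when only one device produced a result, and a left fold of ggml_add over the rest. Per the "Force a graph split" comment deleted in an earlier hunk, the 0xff written into op_params[0] of the first add forces a graph split between backends. A generic standalone sketch of the fold, with placeholder types (in the real code T is ggml_tensor * and combine emits a graph node):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Left-fold a non-empty vector of per-device partial results with an
    // arbitrary combine op, mirroring the size()==1 fast path and the
    // results[0]/results[1] seed used in the diff.
    template <typename T, typename Combine>
    T combine_partials(const std::vector<T> & results, Combine combine) {
        assert(!results.empty());            // GGML_ASSERT(!results.empty())
        if (results.size() == 1) {
            return results.front();          // single device: nothing to add
        }
        T cur = combine(results[0], results[1]);
        for (std::size_t id = 2; id < results.size(); ++id) {
            cur = combine(cur, results[id]);
        }
        return cur;
    }

    // Usage with plain values: combine_partials(std::vector<int>{1, 2, 3},
    //                                           [](int a, int b) { return a + b; });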
@@ -1660,10 +1668,15 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml
            }
            cb(o.back(), "output", id);
        }
-       if (o.size() == 1) cur = o.front();
-       cur = ggml_concat(ctx, o[0], o[1], 0);
-       for (int id = 2; id < int(o.size()); ++id) {
-           cur = ggml_concat(ctx, cur, o[id], 0);
+       GGML_ASSERT(!o.empty());
+       if (o.size() == 1) {
+           cur = o.front();
+       }
+       else {
+           cur = ggml_concat(ctx, o[0], o[1], 0);
+           for (int id = 2; id < int(o.size()); ++id) {
+               cur = ggml_concat(ctx, cur, o[id], 0);
+           }
        }
    } else {
        if (output_norm) {
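In build_output the per-device pieces are combined with ggml_concat along dimension 0 rather than ggml_add, consistent with each device computing a disjoint slice of the output (an inference from the concat axis; the diff itself doesn't say). The rewrite also fixes the pre-fix fall-through, where cur = o.front() was immediately overwritten by ggml_concat(ctx, o[0], o[1], 0) even when only one piece existed. A sketch of the fixed control flow on plain vectors:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Concatenate per-device output slices end to end: assert non-empty,
    // take a single piece as-is, otherwise append the remaining pieces in
    // device order.
    std::vector<float> concat_outputs(const std::vector<std::vector<float>> & o) {
        assert(!o.empty());                  // GGML_ASSERT(!o.empty())
        std::vector<float> cur = o.front();
        for (std::size_t id = 1; id < o.size(); ++id) {
            cur.insert(cur.end(), o[id].begin(), o[id].end());
        }
        return cur;
    }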
@@ -9455,6 +9468,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
             ggml_build_forward_expand(gf, cur);
             attn.push_back(cur);
         }
+        GGML_ASSERT(!attn.empty());
         if (attn.size() == 1) return attn.front();
         auto cur = ggml_add(ctx0, attn[0], attn[1]);
         cb(cur, "combine_attn", il);
@@ -9463,10 +9477,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
             cur = ggml_add(ctx0, cur, attn[id]);
             cb(cur, "combine_attn", il);
         }
-        // TODO: for more than 2 GPUs, do we need to add another forced graph split?
-        //if (attn.size() > 2) {
-        //    cur->op_params[0] = 0xff;
-        //}
         return cur;
     }
 }