@@ -757,7 +757,8 @@ struct vk_device_struct {
757757
758758 vk_pipeline pipeline_flash_attn_split_k_reduce;
759759
760- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
760+ // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
761+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
761762
762763 std::vector<vk_pipeline_ref> all_pipelines;
763764
@@ -1149,6 +1150,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
11491150
11501151struct vk_op_topk_moe_push_constants {
11511152 uint32_t n_rows;
1153+ uint32_t n_experts_push;
11521154 uint32_t n_expert_used;
11531155 float clamp_min;
11541156 float clamp_max;
@@ -4204,10 +4206,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
42044206 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
42054207 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
42064208
4207- for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4208- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
4209- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
4210- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
4209+ for (uint32_t use_push = 0; use_push < 2; ++use_push) {
4210+ for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4211+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
4212+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
4213+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
4214+ }
42114215 }
42124216
42134217 for (auto &c : compiles) {
@@ -8554,7 +8558,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85548558 uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
85558559 GGML_ASSERT(idx < num_topk_moe_pipelines);
85568560 topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8557- return ctx->device->pipeline_topk_moe[idx][mode];
8561+ // use n_experts from push constant if it's not equal to the power of two spec constant
8562+ bool use_push = dst->ne[0] != (1u << idx);
8563+ return ctx->device->pipeline_topk_moe[idx][mode][use_push];
85588564 }
85598565
85608566 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -10158,6 +10164,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
1015810164
1015910165 vk_op_topk_moe_push_constants pc {};
1016010166 pc.n_rows = n_rows;
10167+ pc.n_experts_push = n_experts;
1016110168 pc.n_expert_used = n_expert_used;
1016210169 if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
1016310170 ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@@ -12832,8 +12839,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1283212839 }
1283312840
1283412841 const int n_expert = softmax->ne[0];
12835- // n_expert must be a power of 2
12836- if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
12842+ if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
1283712843 return false;
1283812844 }
1283912845
0 commit comments