Skip to content

Commit 9fee3f5

Browse files
committed
vulkan: Allow non-pow2 n_experts in topk_moe
1 parent 2fa51c1 commit 9fee3f5

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,8 @@ struct vk_device_struct {
757757

758758
vk_pipeline pipeline_flash_attn_split_k_reduce;
759759

760-
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
760+
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
761+
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
761762

762763
std::vector<vk_pipeline_ref> all_pipelines;
763764

@@ -1149,6 +1150,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
11491150

11501151
struct vk_op_topk_moe_push_constants {
11511152
uint32_t n_rows;
1153+
uint32_t n_experts_push;
11521154
uint32_t n_expert_used;
11531155
float clamp_min;
11541156
float clamp_max;
@@ -4204,10 +4206,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
42044206
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
42054207
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
42064208

4207-
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4208-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
4209-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
4210-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
4209+
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
4210+
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4211+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
4212+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
4213+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
4214+
}
42114215
}
42124216

42134217
for (auto &c : compiles) {
@@ -8554,7 +8558,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85548558
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
85558559
GGML_ASSERT(idx < num_topk_moe_pipelines);
85568560
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8557-
return ctx->device->pipeline_topk_moe[idx][mode];
8561+
// use n_experts from push constant if it's not equal to the power of two spec constant
8562+
bool use_push = dst->ne[0] != (1u << idx);
8563+
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
85588564
}
85598565

85608566
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -10158,6 +10164,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
1015810164

1015910165
vk_op_topk_moe_push_constants pc {};
1016010166
pc.n_rows = n_rows;
10167+
pc.n_experts_push = n_experts;
1016110168
pc.n_expert_used = n_expert_used;
1016210169
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
1016310170
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@@ -12832,8 +12839,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
1283212839
}
1283312840

1283412841
const int n_expert = softmax->ne[0];
12835-
// n_expert must be a power of 2
12836-
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
12842+
if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
1283712843
return false;
1283812844
}
1283912845

ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
layout (push_constant) uniform parameter
1111
{
1212
uint n_rows;
13+
uint n_experts_push;
1314
uint n_expert_used;
1415
float clamp_min;
1516
float clamp_max;
@@ -18,11 +19,16 @@ layout (push_constant) uniform parameter
1819
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
1920

2021
layout(constant_id = 0) const uint WARP_SIZE = 32;
21-
layout(constant_id = 1) const uint n_experts = 512;
22+
layout(constant_id = 1) const uint n_experts_spec = 512;
2223
layout(constant_id = 2) const bool with_norm = true;
2324
layout(constant_id = 3) const bool late_softmax = false;
25+
layout(constant_id = 4) const bool nexperts_use_push = false;
2426

25-
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
27+
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
28+
29+
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
30+
31+
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
2632

2733
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
2834
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
@@ -94,7 +100,7 @@ void main() {
94100
}
95101

96102
if (!late_softmax) {
97-
softmax_warp_inplace(wt, n_experts, lane, false);
103+
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
98104
}
99105

100106
// at this point, each thread holds a portion of softmax,

tests/test-backend-ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7927,8 +7927,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
79277927

79287928
for (bool with_norm : {false, true}) {
79297929
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
7930+
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
79307931
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
7932+
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
7933+
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
79317934
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
7935+
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
79327936
}
79337937

79347938
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));

0 commit comments

Comments
 (0)