
Commit 54d83bb

vulkan: remove a couple unnecessary switches (#17419)
1 parent 4949ac0 commit 54d83bb

1 file changed: +17 -263

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 17 additions & 263 deletions
@@ -11381,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
     }
 }

-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
+static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);

 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
 static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
     ggml_tensor * node = cgraph->nodes[node_idx];
-    if (ggml_is_empty(node) || !node->buffer) {
+    if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
         return false;
     }

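The new early-out delegates the list of no-op tensor ops (RESHAPE, VIEW, PERMUTE, TRANSPOSE, NONE) to ggml_op_is_empty instead of spelling them out in the switch removed below. A minimal sketch of what such a helper looks like, assuming it mirrors the cases the removed switch returned false for (the exact body in ggml's internal headers may differ):

// Sketch, not the verbatim ggml implementation: true for ops that only
// change tensor metadata and enqueue no GPU work.
static inline bool ggml_op_is_empty(enum ggml_op op) {
    switch (op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;
        default:
            return false;
    }
}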
@@ -11399,132 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     ggml_tensor * src2 = node->src[2];
     ggml_tensor * src3 = node->src[3];

-    switch (node->op) {
-    // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-        return false;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(node)) {
-        case GGML_UNARY_OP_EXP:
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_ERF:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_NEG:
-        case GGML_UNARY_OP_TANH:
-        case GGML_UNARY_OP_SIGMOID:
-        case GGML_UNARY_OP_HARDSIGMOID:
-        case GGML_UNARY_OP_HARDSWISH:
-        case GGML_UNARY_OP_ABS:
-        case GGML_UNARY_OP_SOFTPLUS:
-        case GGML_UNARY_OP_STEP:
-        case GGML_UNARY_OP_ROUND:
-        case GGML_UNARY_OP_CEIL:
-        case GGML_UNARY_OP_FLOOR:
-        case GGML_UNARY_OP_TRUNC:
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_GLU:
-        switch (ggml_get_glu_op(node)) {
-        case GGML_GLU_OP_GEGLU:
-        case GGML_GLU_OP_REGLU:
-        case GGML_GLU_OP_SWIGLU:
-        case GGML_GLU_OP_SWIGLU_OAI:
-        case GGML_GLU_OP_GEGLU_ERF:
-        case GGML_GLU_OP_GEGLU_QUICK:
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_ADD:
-        {
-            int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
-            if (next_node_idx < cgraph->n_nodes &&
-                cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
-                cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
-                ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
-                ctx->device->add_rms_fusion) {
-                uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
-                ctx->do_add_rms_partials_offset_calculation = true;
-                if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
-                    ctx->do_add_rms_partials = true;
-                }
+    if (node->op == GGML_OP_ADD) {
+        int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
+        if (next_node_idx < cgraph->n_nodes &&
+            cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
+            ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+            ctx->do_add_rms_partials_offset_calculation = true;
+            if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
+                ctx->do_add_rms_partials = true;
             }
-        } break;
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_ACC:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD1:
-    case GGML_OP_ARANGE:
-    case GGML_OP_FILL:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_ROLL:
-    case GGML_OP_CPY:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_SILU_BACK:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_RMS_NORM_BACK:
-    case GGML_OP_L2_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_SOFT_MAX_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-    case GGML_OP_ARGMAX:
-    case GGML_OP_COUNT_EQUAL:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_CONV_TRANSPOSE_1D:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_CONV_2D:
-    case GGML_OP_CONV_TRANSPOSE_2D:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_RWKV_WKV6:
-    case GGML_OP_RWKV_WKV7:
-    case GGML_OP_SSM_SCAN:
-    case GGML_OP_SSM_CONV:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_FLASH_ATTN_EXT:
-    case GGML_OP_OPT_STEP_ADAMW:
-    case GGML_OP_OPT_STEP_SGD:
-        break;
-    default:
-        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
-        GGML_ABORT("fatal error");
+        }
     }

     vk_context compute_ctx;
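With the switch gone, the only per-op work left before a compute context is created is the probe for the add+rms_norm fusion. Illustratively, the pattern it detects corresponds to a graph fragment like the following, built with the public ggml API (a, b, and eps are placeholder names):

// An ADD whose result feeds a single-row RMS_NORM is a candidate for the
// fused add+rms_norm path, when the device enables add_rms_fusion.
struct ggml_tensor * sum    = ggml_add(ctx, a, b);
struct ggml_tensor * normed = ggml_rms_norm(ctx, sum, eps);  // src[0] == sum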
@@ -11961,145 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

         ctx->compute_ctx.reset();

-        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
-        if (!ok) {
-            if (node->op == GGML_OP_UNARY) {
-                std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-            } else if (node->op == GGML_OP_GLU) {
-                std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
-            } else {
-                std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
-            }
-        }
-
+        ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
     }
     return true;
 }

-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
+static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
     GGML_UNUSED(cgraph);
-    ggml_backend_buffer * buf = nullptr;
-
-    switch (tensor->op) {
-    case GGML_OP_ADD:
-    case GGML_OP_ACC:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD1:
-    case GGML_OP_ARANGE:
-    case GGML_OP_FILL:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_ROLL:
-    case GGML_OP_CPY:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_SILU_BACK:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_RMS_NORM_BACK:
-    case GGML_OP_L2_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_SOFT_MAX_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-    case GGML_OP_ARGMAX:
-    case GGML_OP_COUNT_EQUAL:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_CONV_TRANSPOSE_1D:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_CONV_2D:
-    case GGML_OP_CONV_TRANSPOSE_2D:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_RWKV_WKV6:
-    case GGML_OP_RWKV_WKV7:
-    case GGML_OP_SSM_SCAN:
-    case GGML_OP_SSM_CONV:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_OPT_STEP_ADAMW:
-    case GGML_OP_OPT_STEP_SGD:
-        buf = tensor->buffer;
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(tensor)) {
-        case GGML_UNARY_OP_EXP:
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_ERF:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_NEG:
-        case GGML_UNARY_OP_TANH:
-        case GGML_UNARY_OP_SIGMOID:
-        case GGML_UNARY_OP_HARDSIGMOID:
-        case GGML_UNARY_OP_HARDSWISH:
-        case GGML_UNARY_OP_ABS:
-        case GGML_UNARY_OP_SOFTPLUS:
-        case GGML_UNARY_OP_STEP:
-        case GGML_UNARY_OP_ROUND:
-        case GGML_UNARY_OP_CEIL:
-        case GGML_UNARY_OP_FLOOR:
-        case GGML_UNARY_OP_TRUNC:
-            buf = tensor->buffer;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_GLU:
-        switch (ggml_get_glu_op(tensor)) {
-        case GGML_GLU_OP_GEGLU:
-        case GGML_GLU_OP_REGLU:
-        case GGML_GLU_OP_SWIGLU:
-        case GGML_GLU_OP_SWIGLU_OAI:
-        case GGML_GLU_OP_GEGLU_ERF:
-        case GGML_GLU_OP_GEGLU_QUICK:
-            buf = tensor->buffer;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-    case GGML_OP_FLASH_ATTN_EXT:
-        buf = tensor->buffer;
-
-        break;
-    default:
-        return false;
-    }
-
-    if (buf == nullptr) {
-        return false;
-    }
+    GGML_UNUSED(tensor);

     VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");

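Both removed switches enumerated essentially the op coverage the backend already advertises through its supports_op hook, which presumably made the defensive bool returns dead code: an unsupported op never reaches these functions. A hedged sketch of the up-front check as a caller would perform it through the generic backend API (backend and node are placeholders; ggml_backend_supports_op is the public entry point):

// Illustrative pre-check: the scheduler routes unsupported ops to another
// backend (e.g. CPU) instead of failing inside the Vulkan graph code.
if (!ggml_backend_supports_op(backend, node)) {
    // fall back to a different backend for this node
}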
@@ -12143,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         subctx->out_memcpys.clear();
         subctx->memsets.clear();
     }
-
-    return true;
 }

 // Clean up after graph processing is done
