@@ -11381,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
1138111381 }
1138211382}
1138311383
11384- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
11384+ static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
1138511385
1138611386// Returns true if node has enqueued work into the queue, false otherwise
1138711387// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
1138811388static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
1138911389 ggml_tensor * node = cgraph->nodes[node_idx];
11390- if (ggml_is_empty(node) || !node->buffer) {
11390+ if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
1139111391 return false;
1139211392 }
1139311393
@@ -11399,132 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1139911399 ggml_tensor * src2 = node->src[2];
1140011400 ggml_tensor * src3 = node->src[3];
1140111401
11402- switch (node->op) {
11403- // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
11404- case GGML_OP_RESHAPE:
11405- case GGML_OP_VIEW:
11406- case GGML_OP_PERMUTE:
11407- case GGML_OP_TRANSPOSE:
11408- case GGML_OP_NONE:
11409- return false;
11410- case GGML_OP_UNARY:
11411- switch (ggml_get_unary_op(node)) {
11412- case GGML_UNARY_OP_EXP:
11413- case GGML_UNARY_OP_SILU:
11414- case GGML_UNARY_OP_GELU:
11415- case GGML_UNARY_OP_GELU_ERF:
11416- case GGML_UNARY_OP_GELU_QUICK:
11417- case GGML_UNARY_OP_RELU:
11418- case GGML_UNARY_OP_NEG:
11419- case GGML_UNARY_OP_TANH:
11420- case GGML_UNARY_OP_SIGMOID:
11421- case GGML_UNARY_OP_HARDSIGMOID:
11422- case GGML_UNARY_OP_HARDSWISH:
11423- case GGML_UNARY_OP_ABS:
11424- case GGML_UNARY_OP_SOFTPLUS:
11425- case GGML_UNARY_OP_STEP:
11426- case GGML_UNARY_OP_ROUND:
11427- case GGML_UNARY_OP_CEIL:
11428- case GGML_UNARY_OP_FLOOR:
11429- case GGML_UNARY_OP_TRUNC:
11430- break;
11431- default:
11432- return false;
11433- }
11434- break;
11435- case GGML_OP_GLU:
11436- switch (ggml_get_glu_op(node)) {
11437- case GGML_GLU_OP_GEGLU:
11438- case GGML_GLU_OP_REGLU:
11439- case GGML_GLU_OP_SWIGLU:
11440- case GGML_GLU_OP_SWIGLU_OAI:
11441- case GGML_GLU_OP_GEGLU_ERF:
11442- case GGML_GLU_OP_GEGLU_QUICK:
11443- break;
11444- default:
11445- return false;
11446- }
11447- break;
11448- case GGML_OP_ADD:
11449- {
11450- int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
11451- if (next_node_idx < cgraph->n_nodes &&
11452- cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
11453- cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
11454- ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
11455- ctx->device->add_rms_fusion) {
11456- uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
11457- ctx->do_add_rms_partials_offset_calculation = true;
11458- if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
11459- ctx->do_add_rms_partials = true;
11460- }
11402+ if (node->op == GGML_OP_ADD) {
11403+ int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
11404+ if (next_node_idx < cgraph->n_nodes &&
11405+ cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
11406+ cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
11407+ ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
11408+ ctx->device->add_rms_fusion) {
11409+ uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
11410+ ctx->do_add_rms_partials_offset_calculation = true;
11411+ if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
11412+ ctx->do_add_rms_partials = true;
1146111413 }
11462- } break;
11463- case GGML_OP_REPEAT:
11464- case GGML_OP_REPEAT_BACK:
11465- case GGML_OP_GET_ROWS:
11466- case GGML_OP_ADD_ID:
11467- case GGML_OP_ACC:
11468- case GGML_OP_SUB:
11469- case GGML_OP_MUL:
11470- case GGML_OP_DIV:
11471- case GGML_OP_ADD1:
11472- case GGML_OP_ARANGE:
11473- case GGML_OP_FILL:
11474- case GGML_OP_CONCAT:
11475- case GGML_OP_UPSCALE:
11476- case GGML_OP_SCALE:
11477- case GGML_OP_SQR:
11478- case GGML_OP_SQRT:
11479- case GGML_OP_SIN:
11480- case GGML_OP_COS:
11481- case GGML_OP_LOG:
11482- case GGML_OP_CLAMP:
11483- case GGML_OP_PAD:
11484- case GGML_OP_ROLL:
11485- case GGML_OP_CPY:
11486- case GGML_OP_SET_ROWS:
11487- case GGML_OP_CONT:
11488- case GGML_OP_DUP:
11489- case GGML_OP_SILU_BACK:
11490- case GGML_OP_NORM:
11491- case GGML_OP_GROUP_NORM:
11492- case GGML_OP_RMS_NORM:
11493- case GGML_OP_RMS_NORM_BACK:
11494- case GGML_OP_L2_NORM:
11495- case GGML_OP_DIAG_MASK_INF:
11496- case GGML_OP_SOFT_MAX:
11497- case GGML_OP_SOFT_MAX_BACK:
11498- case GGML_OP_ROPE:
11499- case GGML_OP_ROPE_BACK:
11500- case GGML_OP_MUL_MAT:
11501- case GGML_OP_MUL_MAT_ID:
11502- case GGML_OP_ARGSORT:
11503- case GGML_OP_SUM:
11504- case GGML_OP_SUM_ROWS:
11505- case GGML_OP_MEAN:
11506- case GGML_OP_ARGMAX:
11507- case GGML_OP_COUNT_EQUAL:
11508- case GGML_OP_IM2COL:
11509- case GGML_OP_IM2COL_3D:
11510- case GGML_OP_TIMESTEP_EMBEDDING:
11511- case GGML_OP_CONV_TRANSPOSE_1D:
11512- case GGML_OP_POOL_2D:
11513- case GGML_OP_CONV_2D:
11514- case GGML_OP_CONV_TRANSPOSE_2D:
11515- case GGML_OP_CONV_2D_DW:
11516- case GGML_OP_RWKV_WKV6:
11517- case GGML_OP_RWKV_WKV7:
11518- case GGML_OP_SSM_SCAN:
11519- case GGML_OP_SSM_CONV:
11520- case GGML_OP_LEAKY_RELU:
11521- case GGML_OP_FLASH_ATTN_EXT:
11522- case GGML_OP_OPT_STEP_ADAMW:
11523- case GGML_OP_OPT_STEP_SGD:
11524- break;
11525- default:
11526- std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
11527- GGML_ABORT("fatal error");
11414+ }
1152811415 }
1152911416
1153011417 vk_context compute_ctx;
@@ -11961,145 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1196111848
1196211849 ctx->compute_ctx.reset();
1196311850
11964- bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
11965- if (!ok) {
11966- if (node->op == GGML_OP_UNARY) {
11967- std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
11968- } else if (node->op == GGML_OP_GLU) {
11969- std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
11970- } else {
11971- std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
11972- }
11973- }
11974-
11851+ ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
1197511852 }
1197611853 return true;
1197711854}
1197811855
11979- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
11856+ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
1198011857 GGML_UNUSED(cgraph);
11981- ggml_backend_buffer * buf = nullptr;
11982-
11983- switch (tensor->op) {
11984- case GGML_OP_ADD:
11985- case GGML_OP_ACC:
11986- case GGML_OP_GET_ROWS:
11987- case GGML_OP_SUB:
11988- case GGML_OP_MUL:
11989- case GGML_OP_DIV:
11990- case GGML_OP_ADD1:
11991- case GGML_OP_ARANGE:
11992- case GGML_OP_FILL:
11993- case GGML_OP_ADD_ID:
11994- case GGML_OP_CONCAT:
11995- case GGML_OP_UPSCALE:
11996- case GGML_OP_SCALE:
11997- case GGML_OP_SQR:
11998- case GGML_OP_SQRT:
11999- case GGML_OP_SIN:
12000- case GGML_OP_COS:
12001- case GGML_OP_LOG:
12002- case GGML_OP_CLAMP:
12003- case GGML_OP_PAD:
12004- case GGML_OP_ROLL:
12005- case GGML_OP_CPY:
12006- case GGML_OP_SET_ROWS:
12007- case GGML_OP_CONT:
12008- case GGML_OP_DUP:
12009- case GGML_OP_SILU_BACK:
12010- case GGML_OP_NORM:
12011- case GGML_OP_GROUP_NORM:
12012- case GGML_OP_RMS_NORM:
12013- case GGML_OP_RMS_NORM_BACK:
12014- case GGML_OP_L2_NORM:
12015- case GGML_OP_DIAG_MASK_INF:
12016- case GGML_OP_SOFT_MAX:
12017- case GGML_OP_SOFT_MAX_BACK:
12018- case GGML_OP_ROPE:
12019- case GGML_OP_ROPE_BACK:
12020- case GGML_OP_RESHAPE:
12021- case GGML_OP_VIEW:
12022- case GGML_OP_PERMUTE:
12023- case GGML_OP_TRANSPOSE:
12024- case GGML_OP_NONE:
12025- case GGML_OP_ARGSORT:
12026- case GGML_OP_SUM:
12027- case GGML_OP_SUM_ROWS:
12028- case GGML_OP_MEAN:
12029- case GGML_OP_ARGMAX:
12030- case GGML_OP_COUNT_EQUAL:
12031- case GGML_OP_IM2COL:
12032- case GGML_OP_IM2COL_3D:
12033- case GGML_OP_TIMESTEP_EMBEDDING:
12034- case GGML_OP_CONV_TRANSPOSE_1D:
12035- case GGML_OP_POOL_2D:
12036- case GGML_OP_CONV_2D:
12037- case GGML_OP_CONV_TRANSPOSE_2D:
12038- case GGML_OP_CONV_2D_DW:
12039- case GGML_OP_RWKV_WKV6:
12040- case GGML_OP_RWKV_WKV7:
12041- case GGML_OP_SSM_SCAN:
12042- case GGML_OP_SSM_CONV:
12043- case GGML_OP_LEAKY_RELU:
12044- case GGML_OP_REPEAT:
12045- case GGML_OP_REPEAT_BACK:
12046- case GGML_OP_OPT_STEP_ADAMW:
12047- case GGML_OP_OPT_STEP_SGD:
12048- buf = tensor->buffer;
12049- break;
12050- case GGML_OP_UNARY:
12051- switch (ggml_get_unary_op(tensor)) {
12052- case GGML_UNARY_OP_EXP:
12053- case GGML_UNARY_OP_SILU:
12054- case GGML_UNARY_OP_GELU:
12055- case GGML_UNARY_OP_GELU_ERF:
12056- case GGML_UNARY_OP_GELU_QUICK:
12057- case GGML_UNARY_OP_RELU:
12058- case GGML_UNARY_OP_NEG:
12059- case GGML_UNARY_OP_TANH:
12060- case GGML_UNARY_OP_SIGMOID:
12061- case GGML_UNARY_OP_HARDSIGMOID:
12062- case GGML_UNARY_OP_HARDSWISH:
12063- case GGML_UNARY_OP_ABS:
12064- case GGML_UNARY_OP_SOFTPLUS:
12065- case GGML_UNARY_OP_STEP:
12066- case GGML_UNARY_OP_ROUND:
12067- case GGML_UNARY_OP_CEIL:
12068- case GGML_UNARY_OP_FLOOR:
12069- case GGML_UNARY_OP_TRUNC:
12070- buf = tensor->buffer;
12071- break;
12072- default:
12073- return false;
12074- }
12075- break;
12076- case GGML_OP_GLU:
12077- switch (ggml_get_glu_op(tensor)) {
12078- case GGML_GLU_OP_GEGLU:
12079- case GGML_GLU_OP_REGLU:
12080- case GGML_GLU_OP_SWIGLU:
12081- case GGML_GLU_OP_SWIGLU_OAI:
12082- case GGML_GLU_OP_GEGLU_ERF:
12083- case GGML_GLU_OP_GEGLU_QUICK:
12084- buf = tensor->buffer;
12085- break;
12086- default:
12087- return false;
12088- }
12089- break;
12090- case GGML_OP_MUL_MAT:
12091- case GGML_OP_MUL_MAT_ID:
12092- case GGML_OP_FLASH_ATTN_EXT:
12093- buf = tensor->buffer;
12094-
12095- break;
12096- default:
12097- return false;
12098- }
12099-
12100- if (buf == nullptr) {
12101- return false;
12102- }
11858+ GGML_UNUSED(tensor);
1210311859
1210411860 VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
1210511861
@@ -12143,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1214311899 subctx->out_memcpys.clear();
1214411900 subctx->memsets.clear();
1214511901 }
12146-
12147- return true;
1214811902}
1214911903
1215011904// Clean up after graph processing is done