diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 611341deb0a..992ec0495fe 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event {
     int n_streams = 0;
     std::unordered_map<const ggml_tensor *, int> stream_mapping;
 
+    // Original order of nodes in this concurrent region (before interleaving)
+    // Used to restore grouping for fusion within streams
+    std::vector<const ggml_tensor *> original_order;
+
     const ggml_tensor * join_node;
 
     ggml_cuda_concurrent_event() = default;
@@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event {
         , fork_event(other.fork_event)
         , n_streams(other.n_streams)
         , stream_mapping(std::move(other.stream_mapping))
+        , original_order(std::move(other.original_order))
         , join_node(other.join_node) {
         other.fork_event = nullptr;
     }
@@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event {
 };
 
 struct ggml_cuda_stream_context {
-    std::vector<const ggml_tensor *> original_nodes;
     std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
 
     void reset() {
-        original_nodes.clear();
         concurrent_events.clear();
     }
 };
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index fa7e1e13a71..0da57e6715e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             }
         }
         if (should_launch_concurrent_events) {
-            //Restore the original graph to enable fusion within the streams
-            cgraph->nodes   = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
-            cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+            // Restore original node order within each concurrent region to enable fusion within streams
+
+            std::unordered_map<const ggml_tensor *, int> node_to_idx;
+            node_to_idx.reserve(cgraph->n_nodes);
+            for (int i = 0; i < cgraph->n_nodes; ++i) {
+                node_to_idx[cgraph->nodes[i]] = i;
+            }
+
+            for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
+                // Find positions of all nodes from this event in the current graph
+                std::vector<int> positions;
+                positions.reserve(event.original_order.size());
+
+                bool all_found = true;
+                for (const ggml_tensor * orig_node : event.original_order) {
+                    auto it = node_to_idx.find(orig_node);
+                    if (it != node_to_idx.end()) {
+                        positions.push_back(it->second);
+                    } else {
+                        all_found = false;
+                        break;
+                    }
+                }
+
+                if (!all_found || positions.size() != event.original_order.size()) {
+                    continue;
+                }
+
+                // Sort positions to get contiguous range
+                std::vector<int> sorted_positions = positions;
+                std::sort(sorted_positions.begin(), sorted_positions.end());
+
+                bool is_contiguous = true;
+                for (size_t i = 1; i < sorted_positions.size(); ++i) {
+                    if (sorted_positions[i] != sorted_positions[i-1] + 1) {
+                        is_contiguous = false;
+                        break;
+                    }
+                }
+
+                if (!is_contiguous) {
+                    continue;
+                }
+
+                // Restore original order at the sorted positions
+                int start_pos = sorted_positions[0];
+                for (size_t i = 0; i < event.original_order.size(); ++i) {
+                    cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
+                }
+            }
         }
 
         for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     // store {fork_idx, join_idx}
     std::vector<std::pair<int, int>> concurrent_node_ranges;
 
-    // save the original nodes
-    std::vector<const ggml_tensor *> original_nodes;
-    original_nodes.reserve(cgraph->n_nodes);
-    for (int i = 0; i < cgraph->n_nodes; ++i) {
-        original_nodes.push_back(cgraph->nodes[i]);
-    }
-    cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
-
     for (const auto & [root_node, count] : fan_out) {
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
@@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
                 continue;
             }
 
+            // Save the original order of nodes in this region before interleaving
+            // This is used later to restore grouping for fusion within streams
+            concurrent_event.original_order.reserve(total_branch_nodes);
+            for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
+                concurrent_event.original_order.push_back(cgraph->nodes[i]);
+            }
+
            std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
            GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
            concurrent_events.emplace(root_node, std::move(concurrent_event));
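
Note on the restore step above: the pre-interleaving order is written back only when all nodes of a concurrent region are still present and still occupy one contiguous run in the current graph; otherwise the region is left as-is. A minimal standalone sketch of that check, not part of the patch, with plain `int` handles standing in for `ggml_tensor *` and a made-up helper name `restore_region_order`:

```cpp
#include <algorithm>
#include <cstdio>
#include <unordered_map>
#include <vector>

// Reorders `graph` so that the nodes listed in `original_order` appear in that
// order, but only if they currently occupy a contiguous range. Returns true if
// the reorder was applied.
static bool restore_region_order(std::vector<int> & graph, const std::vector<int> & original_order) {
    std::unordered_map<int, int> node_to_idx;
    node_to_idx.reserve(graph.size());
    for (int i = 0; i < (int) graph.size(); ++i) {
        node_to_idx[graph[i]] = i;
    }

    // Collect current positions of the region's nodes; bail out if any is missing.
    std::vector<int> positions;
    positions.reserve(original_order.size());
    for (int node : original_order) {
        auto it = node_to_idx.find(node);
        if (it == node_to_idx.end()) {
            return false;
        }
        positions.push_back(it->second);
    }
    if (positions.empty()) {
        return false;
    }

    // The region must still be one contiguous run, otherwise writing the old
    // order back would clobber unrelated nodes.
    std::sort(positions.begin(), positions.end());
    for (size_t i = 1; i < positions.size(); ++i) {
        if (positions[i] != positions[i - 1] + 1) {
            return false;
        }
    }

    // Write the original order back over the contiguous range.
    const int start = positions.front();
    for (size_t i = 0; i < original_order.size(); ++i) {
        graph[start + (int) i] = original_order[i];
    }
    return true;
}

int main() {
    // Nodes 2..5 were interleaved as 2,4,3,5; restore the pre-interleaving order.
    std::vector<int> graph          = { 1, 2, 4, 3, 5, 6 };
    std::vector<int> original_order = { 2, 3, 4, 5 };

    if (restore_region_order(graph, original_order)) {
        for (int n : graph) {
            printf("%d ", n);  // prints: 1 2 3 4 5 6
        }
        printf("\n");
    }
    return 0;
}
```

If fusion or other rewrites have already moved nodes so the region is no longer contiguous, the sketch, like the patch, simply leaves the graph untouched.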