@@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
     }
     if (should_launch_concurrent_events) {
-        // Restore the original graph to enable fusion within the streams
-        cgraph->nodes   = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
-        cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+        // Restore original node order within each concurrent region to enable fusion within streams
+
+        std::unordered_map<const ggml_tensor *, int> node_to_idx;
+        node_to_idx.reserve(cgraph->n_nodes);
+        for (int i = 0; i < cgraph->n_nodes; ++i) {
+            node_to_idx[cgraph->nodes[i]] = i;
+        }
+
+        for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
+            // Find positions of all nodes from this event in the current graph
+            std::vector<int> positions;
+            positions.reserve(event.original_order.size());
+
+            bool all_found = true;
+            for (const ggml_tensor * orig_node : event.original_order) {
+                auto it = node_to_idx.find(orig_node);
+                if (it != node_to_idx.end()) {
+                    positions.push_back(it->second);
+                } else {
+                    all_found = false;
+                    break;
+                }
+            }
+
+            if (!all_found || positions.size() != event.original_order.size()) {
+                continue;
+            }
+
+            // Sort positions to get contiguous range
+            std::vector<int> sorted_positions = positions;
+            std::sort(sorted_positions.begin(), sorted_positions.end());
+
+            bool is_contiguous = true;
+            for (size_t i = 1; i < sorted_positions.size(); ++i) {
+                if (sorted_positions[i] != sorted_positions[i-1] + 1) {
+                    is_contiguous = false;
+                    break;
+                }
+            }
+
+            if (!is_contiguous) {
+                continue;
+            }
+
+            // Restore original order at the sorted positions
+            int start_pos = sorted_positions[0];
+            for (size_t i = 0; i < event.original_order.size(); ++i) {
+                cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
+            }
+        }
     }
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
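To make the restore pass above concrete, here is a minimal, self-contained sketch of the same idea: re-group one region's nodes in their original order, but only if they still occupy a contiguous range of the graph. Integer IDs stand in for `ggml_tensor` pointers, and `restore_region_order` is an illustrative name, not part of the patch.

```cpp
// Standalone sketch of the restore logic above (illustrative only): given the
// current node order and the original order of one concurrent region, put the
// region's nodes back in original order iff they are still contiguous.
#include <algorithm>
#include <cstdio>
#include <unordered_map>
#include <vector>

static bool restore_region_order(std::vector<int> & nodes, const std::vector<int> & original_order) {
    std::unordered_map<int, int> node_to_idx;
    for (int i = 0; i < (int) nodes.size(); ++i) {
        node_to_idx[nodes[i]] = i;
    }
    std::vector<int> positions;
    for (int n : original_order) {
        auto it = node_to_idx.find(n);
        if (it == node_to_idx.end()) {
            return false; // a node was fused or removed - leave the order as-is
        }
        positions.push_back(it->second);
    }
    std::sort(positions.begin(), positions.end());
    for (size_t i = 1; i < positions.size(); ++i) {
        if (positions[i] != positions[i - 1] + 1) {
            return false; // region no longer contiguous - cannot restore safely
        }
    }
    // overwrite the contiguous range with the original ordering
    for (size_t i = 0; i < original_order.size(); ++i) {
        nodes[positions[0] + i] = original_order[i];
    }
    return true;
}

int main() {
    // nodes 2..5 were interleaved across streams; restore their original grouping
    std::vector<int> nodes = {1, 2, 4, 3, 5, 6};
    const std::vector<int> original = {2, 3, 4, 5};
    restore_region_order(nodes, original);
    for (int n : nodes) printf("%d ", n); // prints: 1 2 3 4 5 6
    return 0;
}
```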
@@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     // store {fork_idx, join_idx}
     std::vector<std::pair<int, int>> concurrent_node_ranges;
 
-    // save the original nodes
-    std::vector<const ggml_tensor *> original_nodes;
-    original_nodes.reserve(cgraph->n_nodes);
-    for (int i = 0; i < cgraph->n_nodes; ++i) {
-        original_nodes.push_back(cgraph->nodes[i]);
-    }
-    cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
-
     for (const auto & [root_node, count] : fan_out) {
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
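The removal above drops the whole-graph snapshot (`stream_context().original_nodes`); ordering is instead tracked per fork/join region. A hypothetical sketch of the bookkeeping this implies, assuming the surrounding ggml-cuda types (the real `ggml_cuda_concurrent_event` carries more state than shown and may differ in layout):

```cpp
#include <vector>

struct ggml_tensor; // opaque here; defined in ggml

// Hypothetical sketch only: the actual struct also holds the fork/join
// events and per-stream state used to launch the region concurrently.
struct ggml_cuda_concurrent_event {
    // Nodes of this region in pre-interleaving order, captured in
    // ggml_backend_cuda_graph_optimize and consumed after scheduling
    // to restore grouping so that kernel fusion can still apply.
    std::vector<const ggml_tensor *> original_order;
};
```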
@@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
             continue;
         }
 
+        // Save the original order of nodes in this region before interleaving
+        // This is used later to restore grouping for fusion within streams
+        concurrent_event.original_order.reserve(total_branch_nodes);
+        for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
+            concurrent_event.original_order.push_back(cgraph->nodes[i]);
+        }
+
         std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
         GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
         concurrent_events.emplace(root_node, std::move(concurrent_event));
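The capture added above snapshots the half-open range of nodes strictly between the fork and join indices, once per region, before the branches are interleaved across streams. A toy equivalent of that capture (illustrative signature, not the patch's API):

```cpp
#include <vector>

struct ggml_tensor; // opaque here

// Toy capture of a fork/join region's node order (illustrative only).
static std::vector<const ggml_tensor *> capture_region_order(ggml_tensor * const * nodes, int fork_node_idx, int join_node_idx) {
    std::vector<const ggml_tensor *> original_order;
    original_order.reserve(join_node_idx - fork_node_idx - 1);
    // nodes strictly between fork and join belong to the concurrent branches
    for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
        original_order.push_back(nodes[i]);
    }
    return original_order;
}
```

Storing the order per event, rather than as one global snapshot, means a region that later fails the contiguity check is simply left interleaved, while intact regions regain their fusion-friendly grouping.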