Skip to content

Commit ed32089

Browse files
authored
ggml-cuda: reorder only relevant nodes (ggml-org#17639)
1 parent 7b6d745 commit ed32089

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event {
989989
int n_streams = 0;
990990
std::unordered_map<const ggml_tensor *, int> stream_mapping;
991991

992+
// Original order of nodes in this concurrent region (before interleaving)
993+
// Used to restore grouping for fusion within streams
994+
std::vector<const ggml_tensor *> original_order;
995+
992996
const ggml_tensor * join_node;
993997

994998
ggml_cuda_concurrent_event() = default;
@@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event {
10111015
, fork_event(other.fork_event)
10121016
, n_streams(other.n_streams)
10131017
, stream_mapping(std::move(other.stream_mapping))
1018+
, original_order(std::move(other.original_order))
10141019
, join_node(other.join_node) {
10151020
other.fork_event = nullptr;
10161021
}
@@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event {
11211126
};
11221127

11231128
struct ggml_cuda_stream_context {
1124-
std::vector<const ggml_tensor *> original_nodes;
11251129
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
11261130

11271131
void reset() {
1128-
original_nodes.clear();
11291132
concurrent_events.clear();
11301133
}
11311134
};

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 57 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32383238
}
32393239
}
32403240
if (should_launch_concurrent_events) {
3241-
//Restore the original graph to enable fusion within the streams
3242-
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
3243-
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
3241+
// Restore original node order within each concurrent region to enable fusion within streams
3242+
3243+
std::unordered_map<const ggml_tensor *, int> node_to_idx;
3244+
node_to_idx.reserve(cgraph->n_nodes);
3245+
for (int i = 0; i < cgraph->n_nodes; ++i) {
3246+
node_to_idx[cgraph->nodes[i]] = i;
3247+
}
3248+
3249+
for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
3250+
// Find positions of all nodes from this event in the current graph
3251+
std::vector<int> positions;
3252+
positions.reserve(event.original_order.size());
3253+
3254+
bool all_found = true;
3255+
for (const ggml_tensor * orig_node : event.original_order) {
3256+
auto it = node_to_idx.find(orig_node);
3257+
if (it != node_to_idx.end()) {
3258+
positions.push_back(it->second);
3259+
} else {
3260+
all_found = false;
3261+
break;
3262+
}
3263+
}
3264+
3265+
if (!all_found || positions.size() != event.original_order.size()) {
3266+
continue;
3267+
}
3268+
3269+
// Sort positions to get contiguous range
3270+
std::vector<int> sorted_positions = positions;
3271+
std::sort(sorted_positions.begin(), sorted_positions.end());
3272+
3273+
bool is_contiguous = true;
3274+
for (size_t i = 1; i < sorted_positions.size(); ++i) {
3275+
if (sorted_positions[i] != sorted_positions[i-1] + 1) {
3276+
is_contiguous = false;
3277+
break;
3278+
}
3279+
}
3280+
3281+
if (!is_contiguous) {
3282+
continue;
3283+
}
3284+
3285+
// Restore original order at the sorted positions
3286+
int start_pos = sorted_positions[0];
3287+
for (size_t i = 0; i < event.original_order.size(); ++i) {
3288+
cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
3289+
}
3290+
}
32443291
}
32453292

32463293
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
38053852
// store {fork_idx, join_idx}
38063853
std::vector<std::pair<int, int>> concurrent_node_ranges;
38073854

3808-
// save the original nodes
3809-
std::vector<const ggml_tensor *> original_nodes;
3810-
original_nodes.reserve(cgraph->n_nodes);
3811-
for (int i = 0; i < cgraph->n_nodes; ++i) {
3812-
original_nodes.push_back(cgraph->nodes[i]);
3813-
}
3814-
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
3815-
38163855
for (const auto & [root_node, count] : fan_out) {
38173856
if (count >= min_fan_out && count <= max_fan_out) {
38183857
const int root_node_idx = node_indices[root_node];
@@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
39173956
continue;
39183957
}
39193958

3959+
// Save the original order of nodes in this region before interleaving
3960+
// This is used later to restore grouping for fusion within streams
3961+
concurrent_event.original_order.reserve(total_branch_nodes);
3962+
for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
3963+
concurrent_event.original_order.push_back(cgraph->nodes[i]);
3964+
}
3965+
39203966
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
39213967
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
39223968
concurrent_events.emplace(root_node, std::move(concurrent_event));

0 commit comments

Comments (0)