Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event {
int n_streams = 0;
std::unordered_map<const ggml_tensor *, int> stream_mapping;

// Original order of nodes in this concurrent region (before interleaving)
// Used to restore grouping for fusion within streams
std::vector<const ggml_tensor *> original_order;

const ggml_tensor * join_node;

ggml_cuda_concurrent_event() = default;
Expand All @@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event {
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
, original_order(std::move(other.original_order))
, join_node(other.join_node) {
other.fork_event = nullptr;
}
Expand Down Expand Up @@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event {
};

struct ggml_cuda_stream_context {
std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;

void reset() {
original_nodes.clear();
concurrent_events.clear();
}
};
Expand Down
68 changes: 57 additions & 11 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}
if (should_launch_concurrent_events) {
//Restore the original graph to enable fusion within the streams
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
// Restore original node order within each concurrent region to enable fusion within streams

std::unordered_map<const ggml_tensor *, int> node_to_idx;
node_to_idx.reserve(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; ++i) {
node_to_idx[cgraph->nodes[i]] = i;
}

for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
// Find positions of all nodes from this event in the current graph
std::vector<int> positions;
positions.reserve(event.original_order.size());

bool all_found = true;
for (const ggml_tensor * orig_node : event.original_order) {
auto it = node_to_idx.find(orig_node);
if (it != node_to_idx.end()) {
positions.push_back(it->second);
} else {
all_found = false;
break;
}
}

if (!all_found || positions.size() != event.original_order.size()) {
continue;
}

// Sort positions to get contiguous range
std::vector<int> sorted_positions = positions;
std::sort(sorted_positions.begin(), sorted_positions.end());

bool is_contiguous = true;
for (size_t i = 1; i < sorted_positions.size(); ++i) {
if (sorted_positions[i] != sorted_positions[i-1] + 1) {
is_contiguous = false;
break;
}
}

if (!is_contiguous) {
continue;
}

// Restore original order at the sorted positions
int start_pos = sorted_positions[0];
for (size_t i = 0; i < event.original_order.size(); ++i) {
cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
}
}
}

for (int i = 0; i < cgraph->n_nodes; i++) {
Expand Down Expand Up @@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
// store {fork_idx, join_idx}
std::vector<std::pair<int, int>> concurrent_node_ranges;

// save the original nodes
std::vector<const ggml_tensor *> original_nodes;
original_nodes.reserve(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; ++i) {
original_nodes.push_back(cgraph->nodes[i]);
}
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);

for (const auto & [root_node, count] : fan_out) {
if (count >= min_fan_out && count <= max_fan_out) {
const int root_node_idx = node_indices[root_node];
Expand Down Expand Up @@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
continue;
}

// Save the original order of nodes in this region before interleaving
// This is used later to restore grouping for fusion within streams
concurrent_event.original_order.reserve(total_branch_nodes);
for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
concurrent_event.original_order.push_back(cgraph->nodes[i]);
}

std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
concurrent_events.emplace(root_node, std::move(concurrent_event));
Expand Down
Loading