Skip to content

Commit c7af376

Browse files
am17an, IMbackK, JohannesGaessler
authored
CUDA: add stream-based concurrency (#16991)
* CUDA: add stream-based concurrency * HIP: fix hipStreamWaitEvent define and nodiscard warnings * ggml-cuda: fix fusion inside stream * ggml-cuda: fix bug w.r.t first stream launch * ggml-cuda: format * ggml-cuda: improve assert message * ggml-cuda: use lambda instead of duplicating code * ggml-cuda: add some more comments * ggml-cuda: add more detailed comments about concurrency * ggml-cuda: rename + remove unused var * ggml-cuda: fix condition for stream launch * ggml-cuda: address review comments, add destructor * common.cuh: add is_valid for concurrent events * common.cuh: make comment better * update comment Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * update comment Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * common.cuh: fix lower_bound condition + remove join_node data from write_ranges * ggml-cuda: fix overlap condition + shadowing parameter --------- Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz> Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 00425e2 commit c7af376

File tree

3 files changed

+469
-14
lines changed

3 files changed

+469
-14
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 162 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
#include "ggml-common.h"
2222

2323
#include <algorithm>
#include <array>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
2931

3032
#if defined(GGML_USE_HIP)
@@ -980,6 +982,154 @@ struct ggml_cuda_graph {
980982
#endif
981983
};
982984

985+
// Bookkeeping for running independent graph branches on separate CUDA streams.
// fork_event is recorded on the main stream at the fork point and waited on by the
// concurrent streams; each concurrent stream records its join_events[i] when its
// branch finishes, and the main stream waits on all of them at join_node.
struct ggml_cuda_concurrent_event {
    std::vector<cudaEvent_t> join_events;  // one per concurrent stream, recorded at the end of its branch
    cudaEvent_t fork_event = nullptr;      // recorded on the main stream at the fork point

    int n_streams = 0;  // number of concurrent branches
    // maps each tensor in a branch to the number of the stream it runs on
    // (concurrent streams are numbered starting from 1; 0 is the main stream)
    std::unordered_map<const ggml_tensor *, int> stream_mapping;

    // the node at which all branches merge back into the main stream;
    // initialized to nullptr so a default-constructed/moved-from object
    // never carries an indeterminate pointer
    const ggml_tensor * join_node = nullptr;

    ggml_cuda_concurrent_event() = default;

    // non-copyable: exclusively owns the CUDA events
    ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
    ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;

    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
        join_events.resize(n_streams);

        for (size_t i = 0; i < join_events.size(); ++i) {
            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
        }

        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
    }

    ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
        : join_events(std::move(other.join_events))
        , fork_event(other.fork_event)
        , n_streams(other.n_streams)
        , stream_mapping(std::move(other.stream_mapping))
        , join_node(other.join_node) {
        // the moved-from state of a std::vector is unspecified by the standard;
        // clear it explicitly so other's destructor cannot destroy events we now own
        other.join_events.clear();
        other.fork_event = nullptr;
    }

    // swap-based move assignment: other's destructor releases our previous events
    ggml_cuda_concurrent_event & operator=(ggml_cuda_concurrent_event && other) noexcept {
        if (this != &other) {
            std::swap(join_events, other.join_events);
            std::swap(fork_event, other.fork_event);
            std::swap(stream_mapping, other.stream_mapping);
            n_streams = other.n_streams;
            join_node = other.join_node;
        }
        return *this;
    }

    // 1. check if any branches write to overlapping memory ranges (except the join node)
    // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
    // we assume all nodes have the same buffer
    bool is_valid() const {
        GGML_ASSERT(join_node != nullptr);

        std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
        write_ranges.resize(n_streams);

        // get join_node's memory range to exclude from overlap checking.
        // multiple nodes can use join_node's buffer; we synchronize on the join node.
        const ggml_tensor * join_t   = join_node->view_src ? join_node->view_src : join_node;
        const int64_t     join_start = (int64_t) join_t->data;
        const int64_t     join_end   = join_start + ggml_nbytes(join_t);

        for (const auto & [tensor, stream] : stream_mapping) {
            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
            const int64_t t_start = (int64_t) t->data;
            const int64_t t_end   = t_start + ggml_nbytes(t);

            // skip tensors that overlap with join_node's buffer.
            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
                continue;
            }

            // concurrent streams begin from 1
            write_ranges[stream - 1].emplace_back(t_start, t_end);
        }

        for (int i = 0; i < n_streams; ++i) {
            // sorts first by start then by end of write range
            std::sort(write_ranges[i].begin(), write_ranges[i].end());
        }

        bool writes_overlap = false;
        bool dependent_srcs = false;
        for (const auto & [tensor, stream] : stream_mapping) {
            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
            const int64_t t_start = (int64_t) t->data;
            const int64_t t_end   = t_start + ggml_nbytes(t);

            // skip tensors that overlap with join_node's buffer
            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
                continue;
            }

            // check if this buffer's write data overlaps with another stream's.
            // checking only the lower_bound element suffices because every tensor pair
            // is tested from both directions: the overlap of a pair is always visible
            // from the side of the tensor with the smaller start address.
            std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
            for (int i = 0; i < n_streams; ++i) {
                if (i == stream - 1) {
                    continue;
                }
                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);

                if (it != write_ranges[i].end()) {
                    const std::pair<int64_t, int64_t> & other = *it;

                    // std::lower_bound returns the first element where other >= data_range (lexicographically).
                    // This guarantees other.first >= data_range.first.
                    // Therefore, overlap occurs iff other.first < data_range.second
                    // (i.e., the other range starts before this range ends).
                    if (other.first < data_range.second) {
                        GGML_LOG_DEBUG("Writes overlap for %s\n", tensor->name);
                        writes_overlap = true;
                        break;
                    }
                }
            }

            // check if all srcs are either in branch or don't have a branch
            for (int i = 0; i < GGML_MAX_SRC; ++i) {
                if (!tensor->src[i]) {
                    continue;
                }

                auto it = stream_mapping.find(tensor->src[i]);

                if (it == stream_mapping.end()) {
                    continue;
                }

                if (it->second != stream) {
                    dependent_srcs = true;
                    break;
                }
            }

            if (dependent_srcs || writes_overlap) {
                break;
            }
        }

        return !writes_overlap && !dependent_srcs;
    }

    ~ggml_cuda_concurrent_event() {
        if (fork_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(fork_event));
        }
        for (cudaEvent_t e : join_events) {
            if (e != nullptr) {
                CUDA_CHECK(cudaEventDestroy(e));
            }
        }
    }
};
1122+
1123+
// Per-backend-context state for stream-based concurrency:
// the recorded graph nodes plus the fork/join bookkeeping, keyed by tensor.
struct ggml_cuda_stream_context {
    // nodes of the graph as originally recorded
    std::vector<const ggml_tensor *> original_nodes;
    // fork/join event state per concurrency region, keyed by its fork tensor
    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;

    // Drop all recorded state so the context can be rebuilt from scratch.
    void reset() {
        concurrent_events.clear();
        original_nodes.clear();
    }
};
1132+
9831133
struct ggml_backend_cuda_context {
9841134
int device;
9851135
std::string name;
@@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {
9901140

9911141
std::unique_ptr<ggml_cuda_graph> cuda_graph;
9921142

1143+
int curr_stream_no = 0;
1144+
9931145
explicit ggml_backend_cuda_context(int device) :
9941146
device(device),
9951147
name(GGML_CUDA_NAME + std::to_string(device)) {
9961148
}
9971149

1150+
ggml_cuda_stream_context concurrent_stream_context;
1151+
9981152
~ggml_backend_cuda_context();
9991153

10001154
cudaStream_t stream(int device, int stream) {
@@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
10051159
return streams[device][stream];
10061160
}
10071161

1008-
cudaStream_t stream() {
1009-
return stream(device, 0);
1010-
}
1162+
cudaStream_t stream() { return stream(device, curr_stream_no); }
1163+
1164+
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
10111165

10121166
cublasHandle_t cublas_handle(int device) {
10131167
if (cublas_handles[device] == nullptr) {
@@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
10231177
}
10241178

10251179
// pool
1026-
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
1180+
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
10271181

1028-
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
1182+
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
10291183

10301184
ggml_cuda_pool & pool(int device) {
1031-
if (pools[device] == nullptr) {
1032-
pools[device] = new_pool_for_device(device);
1185+
if (pools[device][curr_stream_no] == nullptr) {
1186+
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
10331187
}
1034-
return *pools[device];
1188+
return *pools[device][curr_stream_no];
10351189
}
10361190

10371191
ggml_cuda_pool & pool() {

0 commit comments

Comments
 (0)