 #include "ggml-common.h"

+#include <algorithm>
 #include <array>
 #include <cassert>
 #include <cfloat>
 #include <cstdio>
 #include <string>
+#include <unordered_map>
 #include <vector>

 #if defined(GGML_USE_HIP)
@@ -980,6 +982,154 @@ struct ggml_cuda_graph {
 #endif
 };

+struct ggml_cuda_concurrent_event {
+    std::vector<cudaEvent_t> join_events;
+    cudaEvent_t fork_event = nullptr;
+
+    int n_streams = 0;
+    std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+    const ggml_tensor * join_node = nullptr;
+
+    ggml_cuda_concurrent_event() = default;
+
+    ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
+    ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
+
+    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+        join_events.resize(n_streams);
+
+        for (size_t i = 0; i < join_events.size(); ++i) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
+        }
+
+        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
+    }
+
+    ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
+        : join_events(std::move(other.join_events))
+        , fork_event(other.fork_event)
+        , n_streams(other.n_streams)
+        , stream_mapping(std::move(other.stream_mapping))
+        , join_node(other.join_node) {
+        other.fork_event = nullptr;
+    }
+
+    // 1. check that no two branches write to overlapping memory ranges (except the join node)
+    // 2. check that all srcs are either within the same branch or outside the nodes covered by this ggml_cuda_concurrent_event
+    // we assume all nodes are allocated in the same buffer
+    bool is_valid() const {
+        std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
+        write_ranges.resize(n_streams);
+
+        // get join_node's memory range to exclude from overlap checking.
+        // multiple nodes can use join_node's buffer; we synchronize on the join node.
+        const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
+        const int64_t join_start = (int64_t) join_t->data;
+        const int64_t join_end   = join_start + ggml_nbytes(join_t);
+
+        for (const auto & [tensor, stream] : stream_mapping) {
+            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+            const int64_t t_start = (int64_t) t->data;
+            const int64_t t_end   = t_start + ggml_nbytes(t);
+
+            // skip tensors that overlap with join_node's buffer.
+            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+                continue;
+            }
+
+            // concurrent streams begin from 1
+            write_ranges[stream - 1].emplace_back(t_start, t_end);
+        }
+
+        for (int i = 0; i < n_streams; ++i) {
+            // sorts first by start and then by end of the write range
+            std::sort(write_ranges[i].begin(), write_ranges[i].end());
+        }
+
+        bool writes_overlap = false;
+        bool dependent_srcs = false;
+        for (const auto & [tensor, stream] : stream_mapping) {
+            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+            const int64_t t_start = (int64_t) t->data;
+            const int64_t t_end   = t_start + ggml_nbytes(t);
+
+            // skip tensors that overlap with join_node's buffer
+            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+                continue;
+            }
+
+            // check whether this tensor's write range overlaps with another stream's
+            std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
+            for (int i = 0; i < n_streams; ++i) {
+                if (i == stream - 1) {
+                    continue;
+                }
+                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
+
+                // std::lower_bound returns the first range that compares >= data_range
+                // (lexicographically), so it->first >= data_range.first is guaranteed and
+                // the two ranges overlap iff the found range starts before data_range ends.
+                if (it != write_ranges[i].end() && it->first < data_range.second) {
+                    GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
+                    writes_overlap = true;
+                    break;
+                }
+
+                // a range that starts before data_range can also overlap it; check the
+                // closest such range, i.e. the one just before the lower bound
+                if (it != write_ranges[i].begin() && (it - 1)->second > data_range.first) {
+                    GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
+                    writes_overlap = true;
+                    break;
+                }
+            }
+
+            // check that all srcs are either in the same branch or not mapped to any branch
+            for (int i = 0; i < GGML_MAX_SRC; ++i) {
+                if (!tensor->src[i]) {
+                    continue;
+                }
+
+                auto it = stream_mapping.find(tensor->src[i]);
+
+                if (it == stream_mapping.end()) {
+                    continue;
+                }
+
+                if (it->second != stream) {
+                    dependent_srcs = true;
+                    break;
+                }
+            }
+
+            if (dependent_srcs || writes_overlap) {
+                break;
+            }
+        }
+
+        return !writes_overlap && !dependent_srcs;
+    }
+
+    ~ggml_cuda_concurrent_event() {
+        if (fork_event != nullptr) {
+            CUDA_CHECK(cudaEventDestroy(fork_event));
+        }
+        for (cudaEvent_t e : join_events) {
+            if (e != nullptr) {
+                CUDA_CHECK(cudaEventDestroy(e));
+            }
+        }
+    }
+};
+
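The overlap test in `is_valid()` is a standard sorted-interval technique: sort each stream's write ranges by `(start, end)`, then for a candidate range probe with `std::lower_bound` and compare against both the first range at-or-after it and the closest range before it. A minimal standalone sketch of the same check, using hypothetical byte ranges rather than ggml tensors:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// returns true if the half-open range r = [start, end) overlaps any range in
// `ranges`, which must be sorted by (start, end)
static bool overlaps_any(const std::vector<std::pair<int64_t, int64_t>> & ranges,
                         std::pair<int64_t, int64_t> r) {
    auto it = std::lower_bound(ranges.begin(), ranges.end(), r);
    // first range starting at or after r: overlaps iff it starts before r ends
    if (it != ranges.end() && it->first < r.second) {
        return true;
    }
    // closest range starting before r: overlaps iff it ends after r begins
    if (it != ranges.begin() && (it - 1)->second > r.first) {
        return true;
    }
    return false;
}

int main() {
    std::vector<std::pair<int64_t, int64_t>> stream1 = {{0, 64}, {128, 256}};
    std::sort(stream1.begin(), stream1.end());
    printf("%d\n", overlaps_any(stream1, {64, 128}));  // 0: adjacent, no overlap
    printf("%d\n", overlaps_any(stream1, {200, 300})); // 1: overlaps [128, 256)
}
```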
+struct ggml_cuda_stream_context {
+    std::vector<const ggml_tensor *> original_nodes;
+    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+    void reset() {
+        original_nodes.clear();
+        concurrent_events.clear();
+    }
+};
+
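`fork_event` and `join_events` map onto the usual CUDA fork/join pattern: side streams wait on an event recorded in the main stream, run their branches, then the main stream waits on one event per side stream before executing the join node. A hedged sketch of that pattern under the assumptions of this header (the `run_concurrent` helper and the branch-launch step are illustrative, not the actual backend scheduler; it assumes stream 0 is the main stream and branches run on streams `1..n_streams`, matching the `stream - 1` indexing above):

```cpp
// hypothetical helper: bracket the concurrent branches of one fork/join section
static void run_concurrent(ggml_backend_cuda_context & ctx, int device,
                           ggml_cuda_concurrent_event & ev) {
    cudaStream_t main_stream = ctx.stream(device, 0);

    // fork: each side stream waits until the main stream reaches the fork point
    CUDA_CHECK(cudaEventRecord(ev.fork_event, main_stream));
    for (int s = 1; s <= ev.n_streams; ++s) {
        CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(device, s), ev.fork_event, 0));
    }

    // ... launch each branch's nodes on the stream assigned to them by
    //     ev.stream_mapping (e.g. by setting ctx.curr_stream_no) ...

    // join: the main stream waits until every side stream has finished its
    // branch, after which the join node can safely run on the main stream
    for (int s = 1; s <= ev.n_streams; ++s) {
        CUDA_CHECK(cudaEventRecord(ev.join_events[s - 1], ctx.stream(device, s)));
        CUDA_CHECK(cudaStreamWaitEvent(main_stream, ev.join_events[s - 1], 0));
    }
}
```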
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {

     std::unique_ptr<ggml_cuda_graph> cuda_graph;

+    int curr_stream_no = 0;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }

+    ggml_cuda_stream_context concurrent_stream_context;
+
     ~ggml_backend_cuda_context();

     cudaStream_t stream(int device, int stream) {
@@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
         return streams[device][stream];
     }

-    cudaStream_t stream() {
-        return stream(device, 0);
-    }
+    cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }

     cublasHandle_t cublas_handle(int device) {
         if (cublas_handles[device] == nullptr) {
@@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
     }

     // pool
-    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];

-    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);

     ggml_cuda_pool & pool(int device) {
-        if (pools[device] == nullptr) {
-            pools[device] = new_pool_for_device(device);
+        if (pools[device][curr_stream_no] == nullptr) {
+            pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
         }
-        return *pools[device];
+        return *pools[device][curr_stream_no];
     }

     ggml_cuda_pool & pool() {
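The pool array gains a stream dimension because a pool's scratch allocations are only ordered with respect to one stream; sharing a single pool across concurrent streams could hand the same scratch buffer to two branches at once. With `curr_stream_no`, both `stream()` and `pool()` resolve per branch. A hypothetical dispatch illustrating the intended call pattern (`evaluate_node` stands in for the backend's op dispatch and is not a real function):

```cpp
// route this node to its branch's stream; ids come from
// ggml_cuda_concurrent_event::stream_mapping (0 = main stream)
ctx.curr_stream_no = ev.stream_mapping.at(node);
// ctx.stream() and ctx.pool() now both resolve to that stream's
// cudaStream_t and its dedicated memory pool
evaluate_node(ctx, node);   // hypothetical op dispatch
ctx.curr_stream_no = 0;     // restore the main stream afterwards
```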