@@ -3725,7 +3725,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
     // flag used to determine whether it is an integrated_gpu
     // TODO
-    const bool integrated = false; // ggml_cuda_info().devices[cuda_ctx->device].integrated;
+    [[maybe_unused]] const bool integrated = false; // ggml_cuda_info().devices[cuda_ctx->device].integrated;
 
     // printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
     while (!graph_evaluated_or_captured) {
@@ -3763,8 +3763,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 assert(node->src[j]->buffer);
             }
         }
-#else
-        GGML_UNUSED(integrated);
 #endif // NDEBUG
 
         bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
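The two hunks above are complementary: the C++17 [[maybe_unused]] attribute added to the declaration of integrated makes the removed #else branch with GGML_UNUSED(integrated); unnecessary. Both idioms suppress the unused-variable warning for builds in which integrated is never read (the removed call sat in the #else of #ifndef NDEBUG, i.e. release builds). A minimal standalone sketch of the two idioms, assuming GGML_UNUSED is the usual cast-to-void macro from ggml.h:

    // Idiom replaced by this patch: discard the value explicitly after the fact.
    #define GGML_UNUSED(x) (void)(x)
    static void old_style() {
        const bool integrated = false;
        GGML_UNUSED(integrated);            // keeps -Wunused-variable quiet
    }

    // Idiom introduced by this patch: annotate the declaration itself.
    static void new_style() {
        [[maybe_unused]] const bool integrated = false;  // compiler accepts the variable going unused
    }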
@@ -3816,15 +3814,19 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    bool use_cuda_graph = !disable_cuda_graphs_due_to_env && cuda_ctx->use_cuda_graph;
+
     // Objects required for CUDA Graph
     if (cuda_ctx->cuda_graph == nullptr) {
         cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
     }
 
-    bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
+    if (use_cuda_graph && cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
@@ -3833,13 +3835,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
     }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+    if (use_cuda_graph && (
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch ||
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates ||
+            cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture)) {
         use_cuda_graph = false;
     }
 
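Taken together, the two hunks in ggml_backend_cuda_graph_compute move the environment-variable check up to the point where use_cuda_graph is initialized and fold in the new per-context cuda_ctx->use_cuda_graph switch; the later checks can then only clear the flag, never set it. A condensed, paraphrased view of the effective logic after the change (not verbatim backend code):

    // Start from the env var and the per-context setting ...
    bool use_cuda_graph = !disable_cuda_graphs_due_to_env && cuda_ctx->use_cuda_graph;

    // ... then only ever turn it off: an old GPU architecture is detected lazily on the
    // first capture attempt, the other flags record problems from earlier graph runs.
    if (use_cuda_graph && cuda_ctx->cuda_graph->graph == nullptr) {
        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
        }
    }
    if (use_cuda_graph && (cuda_ctx->cuda_graph->disable_due_to_gpu_arch ||
                           cuda_ctx->cuda_graph->disable_due_to_too_many_updates ||
                           cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture)) {
        use_cuda_graph = false;
    }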
@@ -4287,6 +4286,11 @@ struct cuda_params {
     int fusion = GGML_CUDA_FUSION;
     int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
     int mmq_id_thresh = 32;
+#ifdef USE_CUDA_GRAPH
+    bool use_cuda_graph = true;
+#else
+    bool use_cuda_graph = false;
+#endif
 };
 
 static std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
@@ -4333,6 +4337,11 @@ static cuda_params ggml_cuda_parse_params(const char * params_string) {
         else if (parsed[0] == "mmq-id-size") {
             is_good = read_value(parsed[1], params.mmq_id_thresh);
         }
+#ifdef USE_CUDA_GRAPH
+        else if (parsed[0] == "graphs") {
+            is_good = read_value(parsed[1], params.use_cuda_graph);
+        }
+#endif
     }
     if (!is_good) {
         GGML_CUDA_LOG_WARN("%s: invalid parameter %s (%d) -> ignored\n", __func__, value.c_str(), (int)parsed.size());
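With the parser branch above, CUDA graphs can now be toggled per backend instance through the init-time parameter string instead of only via the GGML_CUDA_DISABLE_GRAPHS environment variable; when ggml is built without USE_CUDA_GRAPH the branch is compiled out and the setting stays false. A hedged usage sketch, assuming the second argument of ggml_backend_cuda_init shown in the next hunk is this parameter string and that entries follow the key=value form the surrounding parser implies (the exact delimiter handling lives in string_split, outside this diff):

    // Hypothetical call site: disable CUDA graphs for this backend instance only.
    // "graphs=0" would be parsed into cuda_params::use_cuda_graph via read_value().
    ggml_backend_t backend = ggml_backend_cuda_init(/*device =*/ 0, "graphs=0");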
@@ -4373,6 +4382,12 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
         GGML_CUDA_LOG_INFO("=========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
         ctx->mmq_id_thresh = params.mmq_id_thresh;
     }
+#ifdef USE_CUDA_GRAPH
+    if (params.use_cuda_graph != ctx->use_cuda_graph) {
+        GGML_CUDA_LOG_INFO("=========================== %s: setting use_cuda_graph to %d\n", __func__, params.use_cuda_graph);
+        ctx->use_cuda_graph = params.use_cuda_graph;
+    }
+#endif
     }
 
     return cuda_backend;