@@ -428,6 +428,7 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
     clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
+    bool is_allocated = false;
 
     // for debugging
     bool debug_graph = false;
@@ -3305,12 +3306,30 @@ struct clip_model_loader {
     };
 
     static void warmup(clip_ctx & ctx_clip) {
+        // create a fake batch
+        const auto & hparams = ctx_clip.model.hparams;
+        clip_image_f32_batch batch;
+        clip_image_f32_ptr img(clip_image_f32_init());
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
+        } else {
+            img->nx = hparams.warmup_audio_size;
+            img->ny = hparams.n_mel_bins;
+            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
+        }
+        batch.entries.push_back(std::move(img));
+        warmup(ctx_clip, batch);
+    }
+
+    static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         support_info_graph info;
 
         if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
             // try to enable flash attention to see if it's supported
             ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && info.fattn_op) {
                 auto op = info.fattn_op;
                 LOG_WRN("%s: *****************************************************************\n", __func__);
@@ -3329,15 +3348,17 @@ struct clip_model_loader {
                 LOG_WRN("%s: please report this on github as an issue\n", __func__);
                 LOG_WRN("%s: *****************************************************************\n", __func__);
                 ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
-                alloc_compute_meta(ctx_clip);
+                alloc_compute_meta(ctx_clip, batch);
             }
         } else {
-            info = alloc_compute_meta(ctx_clip);
+            info = alloc_compute_meta(ctx_clip, batch);
             if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
                 LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
             }
         }
 
+        ctx_clip.is_allocated = true; // mark buffers as allocated
+
         LOG_INF("%s: flash attention is %s\n", __func__,
                 (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
 
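The AUTO branch above is a probe-and-fallback: reserve the graph once with flash attention enabled, check whether the backend accepted it, and re-reserve with it disabled otherwise. A minimal standalone sketch of that control flow follows; the names (probe_reserve, FlashAttn, resolve_flash_attn) are placeholders, not the clip.cpp API.

// Standalone sketch of the probe-and-fallback flow used by warmup() above.
// probe_reserve() stands in for alloc_compute_meta(); all names are hypothetical.
enum class FlashAttn { AUTO, ENABLED, DISABLED };

struct ProbeResult {
    bool fattn_supported;
};

// Pretend to reserve compute buffers and report what the backend accepted.
static ProbeResult probe_reserve(FlashAttn mode) {
    return { mode == FlashAttn::ENABLED };   // stub: assume the fast path works
}

static FlashAttn resolve_flash_attn(FlashAttn requested) {
    if (requested != FlashAttn::AUTO) {
        probe_reserve(requested);             // honor the explicit user setting
        return requested;
    }
    // AUTO: try with flash attention first ...
    if (probe_reserve(FlashAttn::ENABLED).fattn_supported) {
        return FlashAttn::ENABLED;
    }
    // ... and re-reserve without it when the probe reports it is unsupported.
    probe_reserve(FlashAttn::DISABLED);
    return FlashAttn::DISABLED;
}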
@@ -3369,24 +3390,9 @@ struct clip_model_loader {
         }
     }
 
-    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
-        const auto & hparams = ctx_clip.model.hparams;
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
-        // create a fake batch
-        clip_image_f32_batch batch;
-        clip_image_f32_ptr img(clip_image_f32_init());
-        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
-            img->nx = hparams.warmup_image_size;
-            img->ny = hparams.warmup_image_size;
-            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
-        } else {
-            img->nx = hparams.warmup_audio_size;
-            img->ny = hparams.n_mel_bins;
-            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
-        }
-        batch.entries.push_back(std::move(img));
-
         ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
         ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
 
@@ -4630,6 +4636,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false; // only support batch size of 1
     }
 
+    // if buffers are not allocated, we need to do a warmup run to allocate them
+    if (!ctx->is_allocated) {
+        clip_model_loader::warmup(*ctx, *imgs_c_ptr);
+    }
+
     // build the inference graph
     ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
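Net effect of this last hunk: callers no longer have to warm the context up front; the first clip_image_batch_encode call reserves the compute buffers for the real batch it is given, and later calls skip the step. A self-contained sketch of that lazy-allocation guard, using placeholder names (Context, reserve_buffers, encode) rather than the actual clip.cpp types:

// Minimal illustration of the is_allocated guard introduced above.
// Context/reserve_buffers/encode are hypothetical, not the clip.cpp API.
#include <cstdio>

struct Context {
    bool is_allocated = false;   // set once compute buffers have been reserved
};

static void reserve_buffers(Context & ctx, int batch_size) {
    std::printf("reserving compute buffers for a batch of %d\n", batch_size);
    ctx.is_allocated = true;     // later calls skip this step
}

static bool encode(Context & ctx, int batch_size) {
    if (!ctx.is_allocated) {
        reserve_buffers(ctx, batch_size);   // first call pays the allocation cost
    }
    // ... run the actual compute graph here ...
    return true;
}

int main() {
    Context ctx;
    encode(ctx, 1);   // allocates lazily on first use
    encode(ctx, 1);   // reuses the already-reserved buffers
    return 0;
}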