@@ -2007,118 +2007,73 @@ struct clip_graph {
20072007
20082008 return gf;
20092009 }
2010- // -------------------------------------------------------------------------
2011- // Helper: Space-to-Depth + Grid Stitching (Matches reshape_hd_patches_2x2merge)
2012- // -------------------------------------------------------------------------
2010+
20132011 static struct ggml_tensor * ggml_phi3v_hd_merge (
20142012 struct ggml_context * ctx,
2015- struct ggml_tensor * image_features, // Input: {1024, 24, 24, N}
2013+ struct ggml_tensor * image_features,
20162014 int h_crop,
20172015 int w_crop
20182016 ) {
2019- // N = total patches (num_images * h_crop * w_crop)
2020- // For inference, we typically process 1 image batch at a time, so N = h_crop * w_crop
2021- int N = image_features->ne [3 ];
2022- const int C = 1024 ;
2023- const int H = 24 ;
2024- const int H2 = 12 ;
2017+ int n_images = image_features->ne [3 ];
2018+ const int n_channels = 1024 ;
2019+ const int size = 24 ;
2020+ const int size_half = 12 ;
20252021
20262022 struct ggml_tensor * t = image_features;
2027-
2028- t = ggml_reshape_4d (ctx, t, C, 2 , H2, H * N);
2029-
2030- t = ggml_reshape_4d (ctx, t, C * 2 , H2, 2 , H2 * N);
2031-
2023+ t = ggml_reshape_4d (ctx, t, n_channels, 2 , size_half, size * n_images);
2024+ t = ggml_reshape_4d (ctx, t, n_channels * 2 , size_half, 2 , size_half * n_images);
20322025 t = ggml_permute (ctx, t, 0 , 2 , 1 , 3 );
20332026 t = ggml_cont (ctx, t);
20342027
20352028 if (h_crop == 1 && w_crop == 1 ) {
20362029 return t;
20372030 }
20382031
2039- const int C_NEW = C * 4 ;
2040-
2041- t = ggml_reshape_4d (ctx, t, C_NEW * H2, H2, w_crop, h_crop);
2042-
2032+ const int n_channels_new = n_channels * 4 ;
2033+ t = ggml_reshape_4d (ctx, t, n_channels_new * size_half, size_half, w_crop, h_crop);
20432034 t = ggml_permute (ctx, t, 0 , 2 , 1 , 3 );
20442035 t = ggml_cont (ctx, t);
20452036
20462037 t = ggml_reshape_4d (ctx, t,
2047- C_NEW ,
2048- H2 * w_crop,
2049- H2 * h_crop,
2038+ n_channels_new ,
2039+ size_half * w_crop,
2040+ size_half * h_crop,
20502041 1
20512042 );
20522043
20532044 return t;
20542045 }
20552046
20562047 ggml_cgraph * build_phi3v () {
2057- // 1. Prepare Input (Patches)
2058- // ---------------------------------------------------------------------
2059- ggml_tensor * inp = build_inp (); // [n_embd, 576, num_crops]
2060-
2061- // Calculate grid (e.g. 2x2 grid -> 4 crops + 1 global = 5 crops)
2062- int n_patches_per_crop = 24 * 24 ; // 576
2063- int num_crops = inp->ne [2 ]; // Passed from the batch encoder
2064-
2065- // Reshape to [n_embd, 576, num_crops]
2048+ ggml_tensor * inp = build_inp ();
2049+ int n_patches_per_crop = 24 * 24 ;
2050+ int num_crops = inp->ne [2 ];
20662051 inp = ggml_reshape_3d (ctx0, inp, n_embd, n_patches_per_crop, num_crops);
20672052
2068- // 1. Prepend CLS Token
20692053 ggml_tensor * cls = model.class_embedding ;
20702054 cls = ggml_reshape_3d (ctx0, cls, n_embd, 1 , 1 );
20712055 cls = ggml_repeat (ctx0, cls, ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, n_embd, 1 , num_crops));
20722056
2073- // Concat: [CLS, Patch0 ... Patch575]
20742057 inp = ggml_concat (ctx0, cls, inp, 1 );
2075-
2076- // 2. Add Full Position Embeddings (Do not slice!)
20772058 inp = ggml_add (ctx0, inp, model.position_embeddings );
20782059
2079- // ---------------------------------------------------------------------
2080- // 2. Run Vision Transformer (ViT)
2081- // ---------------------------------------------------------------------
2082-
2083- // Flatten for the transformer: [n_embd, 577 * num_crops]
20842060 inp = ggml_reshape_2d (ctx0, inp, n_embd, (n_patches_per_crop + 1 ) * num_crops);
20852061
2086- // Run the layers on the full sequence (577 tokens per crop)
20872062 ggml_tensor * cur = build_vit (inp, (n_patches_per_crop + 1 ) * num_crops,
20882063 NORM_TYPE_NORMAL, hparams.ffn_op , nullptr , nullptr );
20892064
2090- // Reshape back to separate the crops: [n_embd, 577, num_crops]
20912065
20922066 cur = ggml_reshape_3d (ctx0, cur, n_embd, n_patches_per_crop + 1 , num_crops);
20932067
2094- // Slice: Keep indices 1..577 (The 576 spatial patches)
2095- // We skip Index 0 (CLS).
2096- // Dimensions: n_embd, 576, num_crops
2097- // Offset: 1 * nb[1] (Skip one column of embeddings)
2098-
20992068 cur = ggml_view_3d (ctx0, cur,
21002069 n_embd, n_patches_per_crop, num_crops,
21012070 cur->nb [1 ], cur->nb [2 ],
2102- cur->nb [1 ]); // <--- Offset starts at token 1
2071+ cur->nb [1 ]);
21032072
2104- // Make it contiguous memory for the HD Merge step
2105- // [n_embd, 24, 24, num_crops]
21062073 cur = ggml_reshape_4d (ctx0, cur, n_embd, 24 , 24 , num_crops);
21072074 cur = ggml_cont (ctx0, cur);
21082075
2109- // ---------------------------------------------------------------------
2110- // 3. HD Merge (Space-to-Depth)
2111- // ---------------------------------------------------------------------
2112- // Now 'cur' contains only the 24x24 spatial tokens, perfect for merging.
2113- // We use 1x1 here because we are processing the batch in the graph,
2114- // but we treat crops as independent until the CPU stitching step.
2115-
21162076 cur = ggml_phi3v_hd_merge (ctx0, cur, 1 , 1 );
2117-
2118- // ---------------------------------------------------------------------
2119- // 4. MLP Projection
2120- // ---------------------------------------------------------------------
2121- // Flatten to [4096, 144 * num_crops]
21222077 cur = ggml_reshape_2d (ctx0, cur, 4096 , 144 * num_crops);
21232078
21242079 ggml_tensor * final_emb = ggml_mul_mat (ctx0, model.mm_0_w , cur);
@@ -5407,19 +5362,13 @@ static void clip_phi3_setup(clip_ctx * ctx) {
54075362
54085363 // Helper to build MLP graph for a single vector
54095364 auto build_mlp = [&](ggml_tensor* input) {
5410- // Layer 1 (4096 -> 3072)
54115365 ggml_tensor* cur = ggml_mul_mat (ctx0, ctx->model .mm_0_w , input);
5412- // Bias 0
54135366 if (ctx->model .mm_0_b ) {
54145367 ggml_tensor* b = ctx->model .mm_0_b ;
54155368 cur = ggml_add (ctx0, cur, b);
54165369 }
5417- // GELU
54185370 cur = ggml_gelu (ctx0, cur);
5419-
5420- // Layer 2 (3072 -> 3072)
54215371 cur = ggml_mul_mat (ctx0, ctx->model .mm_2_w , cur);
5422- // Bias 2
54235372 if (ctx->model .mm_2_b ) {
54245373 ggml_tensor* b = ctx->model .mm_2_b ;
54255374 cur = ggml_add (ctx0, cur, b);
@@ -5441,13 +5390,11 @@ static void clip_phi3_setup(clip_ctx * ctx) {
54415390 ggml_build_forward_expand (gf, res_sub);
54425391 }
54435392
5444- // Compute
54455393 ggml_backend_sched_reset (ctx->sched .get ());
54465394 ggml_backend_sched_alloc_graph (ctx->sched .get (), gf);
54475395 ggml_backend_sched_graph_compute (ctx->sched .get (), gf);
54485396
5449- // Save results
5450- int dim = clip_n_mmproj_embd (ctx); // 3072
5397+ int dim = clip_n_mmproj_embd (ctx);
54515398
54525399 if (res_glb) {
54535400 ctx->model .phi3_proj_glb_GN .resize (dim);
@@ -5466,12 +5413,8 @@ static void clip_phi3_setup(clip_ctx * ctx) {
54665413bool clip_image_batch_encode_phi3 (struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec) {
54675414 if (!ctx || !imgs || !vec) return false ;
54685415
5469- // 1. SETUP: Project Separators (4096 -> 3072)
5470- // ----------------------------------------------------------------
5471- // This runs the MLP on the special tokens ONCE.
54725416 clip_phi3_setup (ctx);
54735417
5474- // Sanity Check: If setup failed, we cannot proceed because
54755418 if (ctx->model .phi3_proj_sub_GN .empty () || ctx->model .phi3_proj_glb_GN .empty ()) {
54765419 fprintf (stderr, " %s: Error - Phi-3 separators not initialized.\n " , __func__);
54775420 return false ;
@@ -5480,14 +5423,10 @@ bool clip_image_batch_encode_phi3(struct clip_ctx * ctx, int n_threads, const st
54805423 const auto & entries = imgs->entries ;
54815424 int n_crops = entries.size ();
54825425
5483- // Phi-3 Vision requires at least 1 crop (Global only) or Locals + Global
54845426 if (n_crops < 1 ) return false ;
54855427
5486- // Dimension of the embedding (e.g., 3072 for ViT-L/14 CLIP, or 4096 depending on projection)
54875428 int dim = clip_n_mmproj_embd (ctx);
54885429
5489- // 1. Identify Grid Dimensions
5490- // ----------------------------------------------------------------
54915430 int w_crop = imgs->grid_x ;
54925431 int h_crop = imgs->grid_y ;
54935432 int n_sub_images = w_crop * h_crop;
@@ -5500,24 +5439,17 @@ bool clip_image_batch_encode_phi3(struct clip_ctx * ctx, int n_threads, const st
55005439 const int grid_side = 12 ; // 12x12 tokens per crop (after 2x2 pooling of 24x24 CLIP output)
55015440 const int sub_crop_tokens = grid_side * grid_side; // 144 tokens (raw image embeddings)
55025441
5503- // Temporary buffer to hold the output of a single crop encoding (144 * dim)
55045442 std::vector<float > crop_output (sub_crop_tokens * dim);
55055443
5506- // Buffer to store ALL raw local crops before stitching
5507- // We calculate size: Number of crops * 144 tokens * embedding dimension
55085444 std::vector<float > all_local_crops;
55095445 if (n_sub_images > 0 ) {
55105446 all_local_crops.resize (n_sub_images * sub_crop_tokens * dim);
55115447 }
55125448
55135449 float * dest = vec;
55145450
5515- // 2. Encode and Store Local Crops
5516- // ----------------------------------------------------------------
5517- // We process the first 'n_sub_images' entries as the High-Res crops
55185451 for (int i = 0 ; i < n_sub_images; ++i) {
5519- // Encode individual crop (results in 144 vectors of size 'dim')
5520- // Note: clip_image_encode typically handles normalization and the ViT forward pass
5452+
55215453 bool ok = clip_image_encode (ctx, n_threads, entries[i].get (), crop_output.data ());
55225454 if (!ok) return false ;
55235455
@@ -5527,74 +5459,48 @@ bool clip_image_batch_encode_phi3(struct clip_ctx * ctx, int n_threads, const st
55275459 sub_crop_tokens * dim * sizeof (float ));
55285460 }
55295461
5530- // 3. Stitch Local Crops into 'vec' with Newlines
5531- // ----------------------------------------------------------------
55325462 if (n_sub_images > 0 ) {
5533- // We iterate over the logical rows of the *combined* high-res image.
5534- // Total rows = (Number of Vertical Crops) * (12 tokens per crop height)
55355463 for (int row_global = 0 ; row_global < h_crop * grid_side; ++row_global) {
5536-
5537- // Calculate which vertical crop index we are in (e.g., Crop Row 0 or 1)
55385464 int crop_y = row_global / grid_side;
5539- // Calculate the internal row index within that specific crop (0 to 11)
55405465 int internal_y = row_global % grid_side;
5541-
5542- // Iterate through the crops horizontally for this specific row
55435466 for (int crop_x = 0 ; crop_x < w_crop; ++crop_x) {
5544- // Calculate the linear index of the crop in 'all_local_crops'
55455467 int crop_idx = crop_y * w_crop + crop_x;
55465468
5547- // Calculate pointer to the start of the specific row inside that crop
5548- // Offset = [Start of Crop] + [Row offset within Crop]
55495469 float * src = all_local_crops.data () +
55505470 (crop_idx * sub_crop_tokens * dim) +
55515471 (internal_y * grid_side * dim);
55525472
5553- // Copy the row (12 tokens) to the destination
55545473 memcpy (dest, src, grid_side * dim * sizeof (float ));
55555474 dest += (grid_side * dim);
55565475 }
55575476
5558- // ADD NEWLINE (sub_GN) after the full stitched row is complete
5559- // This happens once per Global Row (e.g., 24 times for a 2x2 grid)
55605477 if (!ctx->model .phi3_proj_sub_GN .empty ()) {
55615478 memcpy (dest, ctx->model .phi3_proj_sub_GN .data (), dim * sizeof (float ));
55625479 } else {
5563- // Fallback for safety, though phi3_proj_sub_GN should be populated during load
55645480 memset (dest, 0 , dim * sizeof (float ));
55655481 }
55665482 dest += dim;
55675483 }
55685484 }
55695485
5570- // 4. Inject Global Separator (glb_GN)
5571- // ----------------------------------------------------------------
5572- // This token separates the High-Res stitched canvas from the Global Low-Res view
55735486 if (!ctx->model .phi3_proj_glb_GN .empty ()) {
55745487 memcpy (dest, ctx->model .phi3_proj_glb_GN .data (), dim * sizeof (float ));
55755488 } else {
55765489 memset (dest, 0 , dim * sizeof (float ));
55775490 }
55785491 dest += dim;
55795492
5580- // 5. Process Global Crop (Last Entry)
5581- // ----------------------------------------------------------------
5582- // The Global crop is always the last one in the batch.
5583- // It is treated as a single 12x12 grid with newlines after every row.
55845493 {
55855494 bool ok = clip_image_encode (ctx, n_threads, entries.back ().get (), crop_output.data ());
55865495 if (!ok) return false ;
55875496
55885497 float * src = crop_output.data ();
55895498
5590- // Iterate over the 12 rows of the Global crop
55915499 for (int r = 0 ; r < grid_side; ++r) {
5592- // Copy the row of tokens (12 tokens)
55935500 memcpy (dest, src, grid_side * dim * sizeof (float ));
55945501 dest += (grid_side * dim);
55955502 src += (grid_side * dim);
55965503
5597- // Add Newline (sub_GN) after every row
55985504 if (!ctx->model .phi3_proj_sub_GN .empty ()) {
55995505 memcpy (dest, ctx->model .phi3_proj_sub_GN .data (), dim * sizeof (float ));
56005506 } else {
0 commit comments