@@ -3852,7 +3852,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "")  # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
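+            # Note: the split below assumes gate occupies the first half of the fused
+            # gate_up_proj's last dimension and up the second half (concatenated, not interleaved)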
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
@@ -4004,6 +4040,201 @@ def set_vocab(self):
         super().set_vocab()


+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
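+            # e.g. with the defaults above: sqrt(2304) * 16 = 48 * 16 = 768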
+            image_size = int(num_pos ** 0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.deepstack_layers:
+            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
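+            # expected names look like visual.deepstack_merger_list.{idx}.{norm|linear_fc1|linear_fc2}.{weight,bias}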
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            idx = int(prefix)
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2; map them to
+                # V_MMPROJ indices 0 and 2, matching the layout used for Qwen2VL
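+                # e.g. visual.merger.linear_fc1.weight -> mm.0.weight and
+                # visual.merger.linear_fc2.weight -> mm.2.weight (assuming the usual "mm.{bid}" V_MMPROJ naming)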
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            if ".qkv." in name:
+                if data_torch.ndim == 2:
+                    c3, _ = data_torch.shape
+                else:
+                    c3 = data_torch.shape[0]
+                if c3 % 3 != 0:
+                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c:c * 2]
+                wv = data_torch[c * 2:]
+                base = name.replace("qkv", "{placeholder}")
+                return [
+                    (self.map_tensor_name(base.format(placeholder="q")), wq),
+                    (self.map_tensor_name(base.format(placeholder="k")), wk),
+                    (self.map_tensor_name(base.format(placeholder="v")), wv),
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multimodal Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in transformers v5; newer configs use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
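+            # e.g. a 3-element section [t, h, w] becomes [t, h, w, 0]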
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multimodal Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in transformers v5; newer configs use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2