@@ -218,8 +218,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
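+                # gguf.utility.SafetensorsLocal replaces safetensors.safe_open here;
+                # it exposes LocalTensor entries whose mmap_bytes()/dtype/shape are
+                # used below to read tensor data without an upfront torch copy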
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
@@ -228,18 +227,18 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
 
                 for name in model_part.keys():
                     if is_safetensors:
+                        data: gguf.utility.LocalTensor = model_part[name]
                         if self.lazy:
-                            data = model_part.get_slice(name)
-                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
+                            data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                         else:
-                            data = model_part.get_tensor(name)
-                            data_gen = lambda data=data: data  # noqa: E731
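+                            # eager path: reinterpret the mmapped bytes as the
+                            # tensor's dtype and shape (a view over the file mapping,
+                            # so nothing is copied until the tensor is used)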
+                            dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                            data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                     else:
-                        data = model_part[name]
+                        data_torch: Tensor = model_part[name]
                         if self.lazy:
-                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                            data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                         else:
-                            data_gen = lambda data=data: data  # noqa: E731
+                            data_gen = lambda data=data_torch: data  # noqa: E731
                     tensors[name] = data_gen
 
                 # verify tensor name presence and identify potentially missing files
@@ -278,15 +277,14 @@ def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
             # The scale is inverted
             return data / scale.float()
 
-        def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+        def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
             scale = scale.float()
 
-            if (weight_block_size := quant_config.get("weight_block_size")):
-                # TODO: make sure it's a list of integers
-                for i, size in enumerate(weight_block_size):
+            if block_size is not None:
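+                # expand one scale per block to one scale per element, e.g. a
+                # (2, 3) scale with block_size (128, 128) becomes (256, 384)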
+                for i, size in enumerate(block_size):
                     scale = scale.repeat_interleave(size, i)
-                # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
-                scale = scale[tuple(slice(0, size) for size in weight.shape)]
+            # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+            scale = scale[tuple(slice(0, size) for size in weight.shape)]
 
             return weight.float() * scale
 
@@ -333,6 +331,40 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
 
             return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
 
+        def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
+            assert w.dtype == torch.int32
+            shape = tuple(shape_tensor.tolist())
+            assert len(shape) == 2
+            mask = (1 << num_bits) - 1
+
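+            # one int32 word packs 32 // num_bits values; shifts selects each of
+            # them, e.g. num_bits=4 gives shifts = [0, 4, 8, ..., 28] (8 per word)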
+            shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
+            if self.lazy:
+                shifts = LazyTorchTensor.from_eager(shifts)
+
+            if zero_point is None:
+                offset = 1 << (num_bits - 1)
+            else:
+                assert len(zero_point.shape) == 2
+                offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
+                offset = offset.reshape(-1, zero_point.shape[1])
+                # trim padding, and prepare for broadcast
+                # NOTE: the zero-point is packed along dim 0
+                offset = offset[:shape[0], :].unsqueeze(-1)
+
+            # extract values
+            # NOTE: the weights are packed along dim 1
+            unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
+            unpacked = unpacked.reshape(shape[0], -1)
+
+            # trim padding
+            unpacked = unpacked[:, :shape[1]]
+
+            # prepare for broadcast of the scale
+            unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
+            unpacked = unpacked - offset
+
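+            # scale has one entry per group; unsqueeze(-1) broadcasts it over the
+            # group_size dimension before collapsing back to the original 2D shape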
+            return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
+
         if quant_method == "bitnet":
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale"):
@@ -342,12 +374,13 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
                     self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
                     tensors_to_remove.append(name)
         elif quant_method == "fp8":
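+            # the block shape is read once here and bound as a lambda default
+            # below, so each closure captures its own value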
+            block_size = quant_config.get("weight_block_size")
             for name in self.model_tensors.keys():
                 if name.endswith(".weight_scale_inv"):
                     weight_name = name.removesuffix("_scale_inv")
                     w = self.model_tensors[weight_name]
                     s = self.model_tensors[name]
-                    self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
         elif quant_method == "gptq":
             for name in self.model_tensors.keys():
@@ -371,6 +404,49 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
371404 ".scales" ,
372405 )
373406 ]
407+ elif quant_method == "compressed-tensors" :
408+ quant_format = quant_config ["format" ]
409+ groups = quant_config ["config_groups" ]
410+ if len (groups ) > 1 :
411+ raise NotImplementedError ("Can't handle multiple config groups for compressed-tensors yet" )
412+ weight_config = tuple (groups .values ())[0 ]["weights" ]
413+
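+            # float-quantized, int-quantized and naive-quantized all store the
+            # weights directly, with a separate scale tensor to expand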
+            if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
+                block_size = weight_config.get("block_structure", None)
+                strategy = weight_config.get("strategy")
+                assert strategy == "channel" or strategy == "block"
+                assert weight_config.get("group_size") is None  # didn't find a model using this yet
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
+                        tensors_to_remove.append(name)
+            elif quant_format == "pack-quantized":
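+                # pack-quantized stores num_bits-wide integers packed into int32
+                # words, with group-wise scales and an optional packed zero-point;
+                # dequant_packed above unpacks and rescales them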
+                assert weight_config.get("strategy") == "group"
+                assert weight_config.get("type", "int") == "int"
+                num_bits = weight_config.get("num_bits")
+                group_size = weight_config.get("group_size")
+                assert isinstance(num_bits, int)
+                assert isinstance(group_size, int)
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_packed"):
+                        base_name = name.removesuffix("_packed")
+                        w = self.model_tensors[name]
+                        scale = self.model_tensors[base_name + "_scale"]
+                        shape = self.model_tensors[base_name + "_shape"]
+                        zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
+                        new_tensors[base_name] = (
+                            lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
+                                w(), scale(), shape(), zero_point(), num_bits, group_size,
+                            )
+                        )
+                        tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
+                        if (base_name + "_zero_point") in self.model_tensors:
+                            tensors_to_remove.append(base_name + "_zero_point")
+            else:
+                raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
         else:
             raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
 
@@ -10002,6 +10078,16 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
         lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
         return cast(torch.Tensor, lazy)
 
+    @classmethod
+    def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
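+        # build a meta-only placeholder now; load_tensor runs (and touches the
+        # mmap) only when the lazy tensor is actually evaluated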
+        def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
+            dtype = cls._dtype_str_map[tensor.dtype]
+            return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
+        dtype = cls._dtype_str_map[t.dtype]
+        shape = t.shape
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
         dtype = cls._dtype_str_map[remote_tensor.dtype]