@@ -339,19 +339,21 @@ def make_q_matrix(w: dict,
339339 if "q_group_map" not in w :
340340 w ["q_group_map" ] = make_group_map (w ["q_groups" ], w ["q_weight" ].shape [0 ])
341341
342- return ext_c .make_q_matrix (w ["q_weight" ],
343- w .get ("q_perm" , none_tensor ),
344- w .get ("q_invperm" , none_tensor ),
345- w ["q_scale" ],
346- w ["q_scale_max" ],
347- w ["q_groups" ],
348- w ["q_group_map" ],
349- none_tensor ,
350- none_tensor ,
351- none_tensor ,
352- w .get ("bias" , none_tensor ),
353- temp_dq ,
354- max_dq_rows )
342+ return ext_c .make_q_matrix (
343+ w ["q_weight" ],
344+ w .get ("q_perm" , none_tensor ),
345+ w .get ("q_invperm" , none_tensor ),
346+ w ["q_scale" ],
347+ w ["q_scale_max" ],
348+ w ["q_groups" ],
349+ w ["q_group_map" ],
350+ none_tensor ,
351+ none_tensor ,
352+ none_tensor ,
353+ w .get ("bias" , none_tensor ),
354+ temp_dq ,
355+ max_dq_rows
356+ )
355357
356358 # GPTQ
357359
@@ -370,36 +372,38 @@ def make_q_matrix(w: dict,
370372 w ["q_perm" ] = torch .empty ((w ["qweight" ].shape [0 ] * 8 ,), dtype = torch .short , device = w ["qweight" ].device )
371373 w ["q_invperm" ] = torch .empty_like (w ["q_perm" ])
372374
373- return ext_c .make_q_matrix (w ["qweight" ],
374- w ["q_perm" ],
375- w ["q_invperm" ],
376- none_tensor ,
377- none_tensor ,
378- none_tensor ,
379- none_tensor ,
380- w ["qzeros" ],
381- w ["scales" ],
382- w ["g_idx" ].cpu (),
383- w .get ("bias" , none_tensor ),
384- temp_dq ,
385- max_dq_rows )
375+ return ext_c .make_q_matrix (
376+ w ["qweight" ],
377+ w ["q_perm" ],
378+ w ["q_invperm" ],
379+ none_tensor ,
380+ none_tensor ,
381+ none_tensor ,
382+ none_tensor ,
383+ w ["qzeros" ],
384+ w ["scales" ],
385+ w ["g_idx" ].cpu (),
386+ w .get ("bias" , none_tensor ),
387+ temp_dq ,
388+ max_dq_rows
389+ )
386390
387391 # GPTQ without g_idx
388392
389393 else :
390394
391- return ext_c .make_q_matrix (w [ "qweight" ],
392- none_tensor ,
393- none_tensor ,
394- none_tensor ,
395- none_tensor ,
396- none_tensor ,
397- none_tensor ,
398- w [ "qzeros" ] ,
399- w ["scales " ],
400- none_tensor ,
401- w . get ( "bias" , none_tensor ) ,
402- temp_dq ,
403- max_dq_rows )
404-
405-
395+ return ext_c .make_q_matrix (
396+ w [ "qweight" ] ,
397+ none_tensor ,
398+ none_tensor ,
399+ none_tensor ,
400+ none_tensor ,
401+ none_tensor ,
402+ none_tensor ,
403+ w ["qzeros " ],
404+ w [ "scales" ] ,
405+ none_tensor ,
406+ w . get ( "bias" , none_tensor ) ,
407+ temp_dq ,
408+ max_dq_rows
409+ )
0 commit comments