@@ -54,7 +54,8 @@ def __init__(self,
                  f_beg: int = None,
                  f_end: int = None,
                  is_sub_module: bool = True,
-                 altpack_qkv: bool = False):
+                 altpack_qkv: bool = False,
+                 normalize_unq: bool = False):
         super().__init__(model, key)

         self.is_sub_module = is_sub_module
@@ -89,6 +90,7 @@ def __init__(self,
         self.altpack_qkv = altpack_qkv

         self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128
+        self.normalize_unq = normalize_unq


     @torch.inference_mode
@@ -125,6 +127,8 @@ def load(self,

         elif isinstance(w, nn.Parameter):
             assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
+            if self.normalize_unq:
+                w = self.normalize(w)
             if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
             if not self.model.config.load_in_q4 or not ".layers." in self.key:
                 self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
@@ -138,6 +142,8 @@ def load(self,

         elif isinstance(w, tuple):
             assert self.has_bias, self.key + " has bias tensor but bias is not expected"
+            if self.normalize_unq:
+                w = self.normalize(w[0]), w[1]
             ww = w[0]
             wb = w[1]
             if self.padding > 0:
@@ -154,6 +160,10 @@ def load(self,
             self.fp16_bias = wb


+    def normalize(self, w: torch.Tensor):
+        return nn.functional.normalize(w)
+
+
     def matrix_shape(self):

         return self.in_features, self.out_features
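For context, a minimal sketch (toy shapes, not part of the patch) of what the new `normalize()` helper does when `normalize_unq` is set: `nn.functional.normalize` defaults to p=2 along dim=1, so each row of the unquantized (out_features, in_features) weight is rescaled to unit L2 norm before any padding is applied.

```python
import torch
import torch.nn as nn

# Toy stand-in for an unquantized fp16 weight of shape (out_features, in_features);
# the shape here is illustrative only.
w = nn.Parameter(torch.randn(4, 8, dtype=torch.float16))

# Same call the patch's normalize() makes: L2-normalize each row (dim=1 by default).
w_norm = nn.functional.normalize(w)

print(w_norm.norm(dim=1))  # each row now has norm ~1.0
```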