@@ -60,7 +60,8 @@ def device(self) -> str:
6060 def load_multi (self ,
6161 key : str ,
6262 keys : list [str ],
63- measure : bool = False ) -> int | dict [str : torch .Tensor ]:
63+ measure : bool = False ,
64+ cpu : bool = False ) -> int | dict [str : torch .Tensor ]:
6465
6566 tensors = {}
6667 submap = {}
@@ -85,13 +86,14 @@ def load_multi(self,
8586 if measure :
8687 size += stfile .measure (key + "." + k )
8788 else :
88- tensors [k ] = stfile .get_tensor (key + "." + k , device = self .device ())
89+ tensors [k ] = stfile .get_tensor (key + "." + k , device = self .device () if not cpu else "cpu" )
8990
9091 return size if measure else tensors
9192
9293
9394 def load_weight (self ,
94- override_key : str | None = None ):
95+ override_key : str | None = None ,
96+ cpu : bool = False ):
9597
9698 if override_key is not None :
9799 keys = [override_key ]
@@ -105,14 +107,14 @@ def load_weight(self,
105107 # EXL2
106108
107109 if key + ".q_weight" in self .model .config .tensor_file_map :
108- qtensors = self .load_multi (key , ["q_weight" , "q_invperm" , "q_scale" , "q_scale_max" , "q_groups" , "q_perm" , "bias" ])
110+ qtensors = self .load_multi (key , ["q_weight" , "q_invperm" , "q_scale" , "q_scale_max" , "q_groups" , "q_perm" , "bias" ], cpu = cpu )
109111 qtensors ["q_perm" ] = torch .argsort (qtensors ["q_invperm" ]).to (torch .int )
110112 return qtensors
111113
112114 # GPTQ
113115
114116 if key + ".qweight" in self .model .config .tensor_file_map :
115- qtensors = self .load_multi (key , ["qweight" , "qzeros" , "scales" , "g_idx" , "bias" ])
117+ qtensors = self .load_multi (key , ["qweight" , "qzeros" , "scales" , "g_idx" , "bias" ], cpu = cpu )
116118 if "bias" in qtensors and torch .all (qtensors ["bias" ].eq (0 )):
117119 del qtensors ["bias" ]
118120 qtensors ["scales" ] = qtensors ["scales" ].half ()
@@ -122,14 +124,14 @@ def load_weight(self,
122124
123125 if key + ".weight" in self .model .config .tensor_file_map :
124126 if key + ".bias" in self .model .config .tensor_file_map :
125- tensors = self .load_multi (key , ["weight" , "bias" ])
127+ tensors = self .load_multi (key , ["weight" , "bias" ], cpu = cpu )
126128 tensor = tensors ["weight" ].half ()
127129 bias = tensors ["bias" ].half ()
128130 if self .model .config .arch .orig_weights_transposed and len (tensor .shape ) == 2 :
129131 tensor = tensor .T
130132 return nn .Parameter (tensor , requires_grad = False ), nn .Parameter (bias , requires_grad = False )
131133 else :
132- tensors = self .load_multi (key , ["weight" ])
134+ tensors = self .load_multi (key , ["weight" ], cpu = cpu )
133135 tensor = tensors ["weight" ].half ()
134136 # if self.model.config.arch.orig_weights_transposed:
135137 # tensor = tensor.T
0 commit comments