T = TypeVar('T')
no_default = object()

-def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
+def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
+
+    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]

    if isinstance(keys, str): keys = [keys]

@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
        if expected_type == int and isinstance(x, float) and x == int(x):
            x = int(x)

-        if isinstance(x, expected_type):
-            return cast(T, x)
-        else:
-            raise TypeError(f"Value for {key} is not of expected type {expected_type}")
+        for t in expected_types:
+            if isinstance(x, t):
+                return cast(T, x)
+        raise TypeError(f"Value for {key} is not of expected type {expected_type}")

    if default != no_default: return default
    raise ValueError(f"Missing any of the following keys: {keys}")
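
For reference, a minimal usage sketch of the updated helper with expected_type given as a list (the dict and values below are made up for illustration, not part of the diff):

# Single expected type, unchanged behavior:
vocab_size = read({"vocab_size": 32000}, int, "vocab_size")    # -> 32000

# List of expected types; the value is accepted if it matches any of them:
cfg = {"eos_token_id": [128001, 128009]}
eos = read(cfg, [int, list], "eos_token_id", None)             # -> [128001, 128009]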
@@ -105,7 +107,10 @@ class ExLlamaV2Config:
    attn_logit_softcapping: float | None
    sliding_window: int
    norm_head: int | None
-
+    l3_rope_factor: float | None
+    l3_rope_low_freq_factor: float | None
+    l3_rope_high_freq_factor: float | None
+    l3_rope_original_max_position_embeddings: int | None
    checkpoint_fused_mlp: bool
    checkpoint_offset_qzeros: bool
@@ -191,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
        # Vocab params

        self.bos_token_id = read(read_config, int, "bos_token_id", None)  # 1
-        self.eos_token_id = read(read_config, int, "eos_token_id", None)  # 2
+        self.eos_token_id = read(read_config, [int, list], "eos_token_id", None)  # 2
        self.pad_token_id = read(read_config, int, "pad_token_id", None)  # 0
        self.vocab_size = read(read_config, int, "vocab_size")

+        if isinstance(self.eos_token_id, list):
+            self.eos_token_id = self.eos_token_id[0]  # TODO: Figure out a way to maybe use all the EOS tokens somehow
+
        # Standard params

        self.initializer_range = read(read_config, float, ["initializer_range"])
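
The net effect on a config with a list-valued eos_token_id, sketched with made-up values:

# config.json: { "eos_token_id": [128001, 128008, 128009], ... }
# after prepare(): self.eos_token_id == 128001
# (the remaining EOS ids are ignored for now; see the TODO above)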
@@ -287,6 +295,13 @@ def prepare(self, no_tensors: bool = False):
                self.alt_rope_method = "su"
            # if scaling_type == "yarn":
            #     self.scale_alpha_value = factor
+            rope_type = rs.get("rope_type", None)
+            if rope_type == "llama3":
+                self.alt_rope_method = "llama3"
+                self.l3_rope_factor = rs["factor"]
+                self.l3_rope_low_freq_factor = rs["low_freq_factor"]
+                self.l3_rope_high_freq_factor = rs["high_freq_factor"]
+                self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

        # Checkpoint format (for GPTQ models)

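The new branch consumes a Llama-3.1-style rope_scaling block from config.json; a sketch of the shape it expects (the field values shown are the ones Meta ships for Llama 3.1, included here for illustration only):

rs = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
# With this dict, the branch above sets alt_rope_method = "llama3" and copies the
# four values into l3_rope_factor, l3_rope_low_freq_factor, l3_rope_high_freq_factor
# and l3_rope_original_max_position_embeddings.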