@@ -56,7 +56,7 @@ def type_embedding_args():
5656 doc_trainable = "If the parameters in the embedding net are trainable"
5757
5858 return [
59- Argument ("neuron" , list , optional = True , default = [8 ], doc = doc_neuron ),
59+ Argument ("neuron" , List [ int ] , optional = True , default = [8 ], doc = doc_neuron ),
6060 Argument (
6161 "activation_function" ,
6262 str ,
@@ -77,9 +77,9 @@ def spin_args():
7777 doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"
7878
7979 return [
80- Argument ("use_spin" , list , doc = doc_use_spin ),
81- Argument ("spin_norm" , list , doc = doc_spin_norm ),
82- Argument ("virtual_len" , list , doc = doc_virtual_len ),
80+ Argument ("use_spin" , List [ bool ] , doc = doc_use_spin ),
81+ Argument ("spin_norm" , List [ float ] , doc = doc_spin_norm ),
82+ Argument ("virtual_len" , List [ float ] , doc = doc_virtual_len ),
8383 ]
8484
8585
@@ -159,10 +159,10 @@ def descrpt_local_frame_args():
159159 - axis_rule[i*6+5]: index of the axis atom defining the second axis. Note that the neighbors with the same class and type are sorted according to their relative distance."
160160
161161 return [
162- Argument ("sel_a" , list , optional = False , doc = doc_sel_a ),
163- Argument ("sel_r" , list , optional = False , doc = doc_sel_r ),
162+ Argument ("sel_a" , List [ int ] , optional = False , doc = doc_sel_a ),
163+ Argument ("sel_r" , List [ int ] , optional = False , doc = doc_sel_r ),
164164 Argument ("rcut" , float , optional = True , default = 6.0 , doc = doc_rcut ),
165- Argument ("axis_rule" , list , optional = False , doc = doc_axis_rule ),
165+ Argument ("axis_rule" , List [ int ] , optional = False , doc = doc_axis_rule ),
166166 ]
167167
168168
@@ -185,10 +185,12 @@ def descrpt_se_a_args():
185185 doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
186186
187187 return [
188- Argument ("sel" , [list , str ], optional = True , default = "auto" , doc = doc_sel ),
188+ Argument ("sel" , [List [ int ] , str ], optional = True , default = "auto" , doc = doc_sel ),
189189 Argument ("rcut" , float , optional = True , default = 6.0 , doc = doc_rcut ),
190190 Argument ("rcut_smth" , float , optional = True , default = 0.5 , doc = doc_rcut_smth ),
191- Argument ("neuron" , list , optional = True , default = [10 , 20 , 40 ], doc = doc_neuron ),
191+ Argument (
192+ "neuron" , List [int ], optional = True , default = [10 , 20 , 40 ], doc = doc_neuron
193+ ),
192194 Argument (
193195 "axis_neuron" ,
194196 int ,
@@ -212,7 +214,11 @@ def descrpt_se_a_args():
212214 Argument ("trainable" , bool , optional = True , default = True , doc = doc_trainable ),
213215 Argument ("seed" , [int , None ], optional = True , doc = doc_seed ),
214216 Argument (
215- "exclude_types" , list , optional = True , default = [], doc = doc_exclude_types
217+ "exclude_types" ,
218+ List [List [int ]],
219+ optional = True ,
220+ default = [],
221+ doc = doc_exclude_types ,
216222 ),
217223 Argument (
218224 "set_davg_zero" , bool , optional = True , default = False , doc = doc_set_davg_zero
@@ -236,10 +242,12 @@ def descrpt_se_t_args():
236242 doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
237243
238244 return [
239- Argument ("sel" , [list , str ], optional = True , default = "auto" , doc = doc_sel ),
245+ Argument ("sel" , [List [ int ] , str ], optional = True , default = "auto" , doc = doc_sel ),
240246 Argument ("rcut" , float , optional = True , default = 6.0 , doc = doc_rcut ),
241247 Argument ("rcut_smth" , float , optional = True , default = 0.5 , doc = doc_rcut_smth ),
242- Argument ("neuron" , list , optional = True , default = [10 , 20 , 40 ], doc = doc_neuron ),
248+ Argument (
249+ "neuron" , List [int ], optional = True , default = [10 , 20 , 40 ], doc = doc_neuron
250+ ),
243251 Argument (
244252 "activation_function" ,
245253 str ,
@@ -289,10 +297,12 @@ def descrpt_se_r_args():
289297 doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
290298
291299 return [
292- Argument ("sel" , [list , str ], optional = True , default = "auto" , doc = doc_sel ),
300+ Argument ("sel" , [List [ int ] , str ], optional = True , default = "auto" , doc = doc_sel ),
293301 Argument ("rcut" , float , optional = True , default = 6.0 , doc = doc_rcut ),
294302 Argument ("rcut_smth" , float , optional = True , default = 0.5 , doc = doc_rcut_smth ),
295- Argument ("neuron" , list , optional = True , default = [10 , 20 , 40 ], doc = doc_neuron ),
303+ Argument (
304+ "neuron" , List [int ], optional = True , default = [10 , 20 , 40 ], doc = doc_neuron
305+ ),
296306 Argument (
297307 "activation_function" ,
298308 str ,
@@ -308,7 +318,11 @@ def descrpt_se_r_args():
308318 Argument ("trainable" , bool , optional = True , default = True , doc = doc_trainable ),
309319 Argument ("seed" , [int , None ], optional = True , doc = doc_seed ),
310320 Argument (
311- "exclude_types" , list , optional = True , default = [], doc = doc_exclude_types
321+ "exclude_types" ,
322+ List [List [int ]],
323+ optional = True ,
324+ default = [],
325+ doc = doc_exclude_types ,
312326 ),
313327 Argument (
314328 "set_davg_zero" , bool , optional = True , default = False , doc = doc_set_davg_zero
@@ -356,10 +370,14 @@ def descrpt_se_atten_common_args():
356370 doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"
357371
358372 return [
359- Argument ("sel" , [int , list , str ], optional = True , default = "auto" , doc = doc_sel ),
373+ Argument (
374+ "sel" , [int , List [int ], str ], optional = True , default = "auto" , doc = doc_sel
375+ ),
360376 Argument ("rcut" , float , optional = True , default = 6.0 , doc = doc_rcut ),
361377 Argument ("rcut_smth" , float , optional = True , default = 0.5 , doc = doc_rcut_smth ),
362- Argument ("neuron" , list , optional = True , default = [10 , 20 , 40 ], doc = doc_neuron ),
378+ Argument (
379+ "neuron" , List [int ], optional = True , default = [10 , 20 , 40 ], doc = doc_neuron
380+ ),
363381 Argument (
364382 "axis_neuron" ,
365383 int ,
@@ -383,7 +401,11 @@ def descrpt_se_atten_common_args():
383401 Argument ("trainable" , bool , optional = True , default = True , doc = doc_trainable ),
384402 Argument ("seed" , [int , None ], optional = True , doc = doc_seed ),
385403 Argument (
386- "exclude_types" , list , optional = True , default = [], doc = doc_exclude_types
404+ "exclude_types" ,
405+ List [List [int ]],
406+ optional = True ,
407+ default = [],
408+ doc = doc_exclude_types ,
387409 ),
388410 Argument ("attn" , int , optional = True , default = 128 , doc = doc_attn ),
389411 Argument ("attn_layer" , int , optional = True , default = 2 , doc = doc_attn_layer ),
@@ -454,8 +476,10 @@ def descrpt_se_a_mask_args():
454476 doc_seed = "Random seed for parameter initialization"
455477
456478 return [
457- Argument ("sel" , [list , str ], optional = True , default = "auto" , doc = doc_sel ),
458- Argument ("neuron" , list , optional = True , default = [10 , 20 , 40 ], doc = doc_neuron ),
479+ Argument ("sel" , [List [int ], str ], optional = True , default = "auto" , doc = doc_sel ),
480+ Argument (
481+ "neuron" , List [int ], optional = True , default = [10 , 20 , 40 ], doc = doc_neuron
482+ ),
459483 Argument (
460484 "axis_neuron" ,
461485 int ,
@@ -476,7 +500,11 @@ def descrpt_se_a_mask_args():
476500 "type_one_side" , bool , optional = True , default = False , doc = doc_type_one_side
477501 ),
478502 Argument (
479- "exclude_types" , list , optional = True , default = [], doc = doc_exclude_types
503+ "exclude_types" ,
504+ List [List [int ]],
505+ optional = True ,
506+ default = [],
507+ doc = doc_exclude_types ,
480508 ),
481509 Argument ("precision" , str , optional = True , default = "default" , doc = doc_precision ),
482510 Argument ("trainable" , bool , optional = True , default = True , doc = doc_trainable ),
@@ -525,7 +553,7 @@ def fitting_ener():
525553 doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
526554 doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n \n \
527555 - bool: True if all parameters of the fitting net are trainable, False otherwise.\n \n \
528- - list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of tihs list should be equal to len(`neuron`)+1."
556+ - list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
529557 doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
530558 doc_seed = "Random seed for parameter initialization of the fitting net"
531559 doc_atom_ener = "Specify the atomic energy in vacuum for each type"
@@ -547,7 +575,7 @@ def fitting_ener():
547575 Argument ("numb_aparam" , int , optional = True , default = 0 , doc = doc_numb_aparam ),
548576 Argument (
549577 "neuron" ,
550- list ,
578+ List [ int ] ,
551579 optional = True ,
552580 default = [120 , 120 , 120 ],
553581 alias = ["n_neuron" ],
@@ -563,14 +591,24 @@ def fitting_ener():
563591 Argument ("precision" , str , optional = True , default = "default" , doc = doc_precision ),
564592 Argument ("resnet_dt" , bool , optional = True , default = True , doc = doc_resnet_dt ),
565593 Argument (
566- "trainable" , [list , bool ], optional = True , default = True , doc = doc_trainable
594+ "trainable" ,
595+ [List [bool ], bool ],
596+ optional = True ,
597+ default = True ,
598+ doc = doc_trainable ,
567599 ),
568600 Argument (
569601 "rcond" , [float , type (None )], optional = True , default = None , doc = doc_rcond
570602 ),
571603 Argument ("seed" , [int , None ], optional = True , doc = doc_seed ),
572- Argument ("atom_ener" , list , optional = True , default = [], doc = doc_atom_ener ),
573- Argument ("layer_name" , list , optional = True , doc = doc_layer_name ),
604+ Argument (
605+ "atom_ener" ,
606+ List [Optional [float ]],
607+ optional = True ,
608+ default = [],
609+ doc = doc_atom_ener ,
610+ ),
611+ Argument ("layer_name" , List [str ], optional = True , doc = doc_layer_name ),
574612 Argument (
575613 "use_aparam_as_mask" ,
576614 bool ,
@@ -602,7 +640,7 @@ def fitting_dos():
602640 Argument ("numb_fparam" , int , optional = True , default = 0 , doc = doc_numb_fparam ),
603641 Argument ("numb_aparam" , int , optional = True , default = 0 , doc = doc_numb_aparam ),
604642 Argument (
605- "neuron" , list , optional = True , default = [120 , 120 , 120 ], doc = doc_neuron
643+ "neuron" , List [ int ] , optional = True , default = [120 , 120 , 120 ], doc = doc_neuron
606644 ),
607645 Argument (
608646 "activation_function" ,
@@ -614,7 +652,11 @@ def fitting_dos():
614652 Argument ("precision" , str , optional = True , default = "float64" , doc = doc_precision ),
615653 Argument ("resnet_dt" , bool , optional = True , default = True , doc = doc_resnet_dt ),
616654 Argument (
617- "trainable" , [list , bool ], optional = True , default = True , doc = doc_trainable
655+ "trainable" ,
656+ [List [bool ], bool ],
657+ optional = True ,
658+ default = True ,
659+ doc = doc_trainable ,
618660 ),
619661 Argument (
620662 "rcond" , [float , type (None )], optional = True , default = None , doc = doc_rcond
@@ -642,7 +684,7 @@ def fitting_polar():
642684 return [
643685 Argument (
644686 "neuron" ,
645- list ,
687+ List [ int ] ,
646688 optional = True ,
647689 default = [120 , 120 , 120 ],
648690 alias = ["n_neuron" ],
@@ -658,12 +700,14 @@ def fitting_polar():
658700 Argument ("resnet_dt" , bool , optional = True , default = True , doc = doc_resnet_dt ),
659701 Argument ("precision" , str , optional = True , default = "default" , doc = doc_precision ),
660702 Argument ("fit_diag" , bool , optional = True , default = True , doc = doc_fit_diag ),
661- Argument ("scale" , [list , float ], optional = True , default = 1.0 , doc = doc_scale ),
703+ Argument (
704+ "scale" , [List [float ], float ], optional = True , default = 1.0 , doc = doc_scale
705+ ),
662706 # Argument("diag_shift", [list,float], optional = True, default = 0.0, doc = doc_diag_shift),
663707 Argument ("shift_diag" , bool , optional = True , default = True , doc = doc_shift_diag ),
664708 Argument (
665709 "sel_type" ,
666- [list , int , None ],
710+ [List [ int ] , int , None ],
667711 optional = True ,
668712 alias = ["pol_type" ],
669713 doc = doc_sel_type ,
@@ -687,7 +731,7 @@ def fitting_dipole():
687731 return [
688732 Argument (
689733 "neuron" ,
690- list ,
734+ List [ int ] ,
691735 optional = True ,
692736 default = [120 , 120 , 120 ],
693737 alias = ["n_neuron" ],
@@ -704,7 +748,7 @@ def fitting_dipole():
704748 Argument ("precision" , str , optional = True , default = "default" , doc = doc_precision ),
705749 Argument (
706750 "sel_type" ,
707- [list , int , None ],
751+ [List [ int ] , int , None ],
708752 optional = True ,
709753 alias = ["dipole_type" ],
710754 doc = doc_sel_type ,
@@ -740,8 +784,10 @@ def modifier_dipole_charge():
740784
741785 return [
742786 Argument ("model_name" , str , optional = False , doc = doc_model_name ),
743- Argument ("model_charge_map" , list , optional = False , doc = doc_model_charge_map ),
744- Argument ("sys_charge_map" , list , optional = False , doc = doc_sys_charge_map ),
787+ Argument (
788+ "model_charge_map" , List [float ], optional = False , doc = doc_model_charge_map
789+ ),
790+ Argument ("sys_charge_map" , List [float ], optional = False , doc = doc_sys_charge_map ),
745791 Argument ("ewald_beta" , float , optional = True , default = 0.4 , doc = doc_ewald_beta ),
746792 Argument ("ewald_h" , float , optional = True , default = 1.0 , doc = doc_ewald_h ),
747793 ]
@@ -770,7 +816,7 @@ def model_compression():
770816
771817 return [
772818 Argument ("model_file" , str , optional = False , doc = doc_model_file ),
773- Argument ("table_config" , list , optional = False , doc = doc_table_config ),
819+ Argument ("table_config" , List [ float ] , optional = False , doc = doc_table_config ),
774820 Argument ("min_nbor_dist" , float , optional = False , doc = doc_min_nbor_dist ),
775821 ]
776822
@@ -814,7 +860,7 @@ def model_args(exclude_hybrid=False):
814860 "model" ,
815861 dict ,
816862 [
817- Argument ("type_map" , list , optional = True , doc = doc_type_map ),
863+ Argument ("type_map" , List [ str ] , optional = True , doc = doc_type_map ),
818864 Argument (
819865 "data_stat_nbatch" ,
820866 int ,
@@ -1456,11 +1502,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
14561502 )
14571503
14581504 args = [
1459- Argument ("systems" , [list , str ], optional = False , default = "." , doc = doc_systems ),
1505+ Argument (
1506+ "systems" , [List [str ], str ], optional = False , default = "." , doc = doc_systems
1507+ ),
14601508 Argument ("set_prefix" , str , optional = True , default = "set" , doc = doc_set_prefix ),
14611509 Argument (
14621510 "batch_size" ,
1463- [list , int , str ],
1511+ [List [ int ] , int , str ],
14641512 optional = True ,
14651513 default = "auto" ,
14661514 doc = doc_batch_size ,
@@ -1477,7 +1525,7 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
14771525 ),
14781526 Argument (
14791527 "sys_probs" ,
1480- list ,
1528+ List [ float ] ,
14811529 optional = True ,
14821530 default = None ,
14831531 doc = doc_sys_probs ,
@@ -1521,11 +1569,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
15211569 doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period."
15221570
15231571 args = [
1524- Argument ("systems" , [list , str ], optional = False , default = "." , doc = doc_systems ),
1572+ Argument (
1573+ "systems" , [List [str ], str ], optional = False , default = "." , doc = doc_systems
1574+ ),
15251575 Argument ("set_prefix" , str , optional = True , default = "set" , doc = doc_set_prefix ),
15261576 Argument (
15271577 "batch_size" ,
1528- [list , int , str ],
1578+ [List [ int ] , int , str ],
15291579 optional = True ,
15301580 default = "auto" ,
15311581 doc = doc_batch_size ,
@@ -1542,7 +1592,7 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
15421592 ),
15431593 Argument (
15441594 "sys_probs" ,
1545- list ,
1595+ List [ float ] ,
15461596 optional = True ,
15471597 default = None ,
15481598 doc = doc_sys_probs ,
0 commit comments