Skip to content

Commit 6d973ef

Browse files
authored
argcheck: restrict the type of elements in a list (#2945)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent e9be507 commit 6d973ef

File tree

2 files changed

+94
-44
lines changed

2 files changed

+94
-44
lines changed

deepmd/utils/argcheck.py

Lines changed: 93 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def type_embedding_args():
5656
doc_trainable = "If the parameters in the embedding net are trainable"
5757

5858
return [
59-
Argument("neuron", list, optional=True, default=[8], doc=doc_neuron),
59+
Argument("neuron", List[int], optional=True, default=[8], doc=doc_neuron),
6060
Argument(
6161
"activation_function",
6262
str,
@@ -77,9 +77,9 @@ def spin_args():
7777
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"
7878

7979
return [
80-
Argument("use_spin", list, doc=doc_use_spin),
81-
Argument("spin_norm", list, doc=doc_spin_norm),
82-
Argument("virtual_len", list, doc=doc_virtual_len),
80+
Argument("use_spin", List[bool], doc=doc_use_spin),
81+
Argument("spin_norm", List[float], doc=doc_spin_norm),
82+
Argument("virtual_len", List[float], doc=doc_virtual_len),
8383
]
8484

8585

@@ -159,10 +159,10 @@ def descrpt_local_frame_args():
159159
- axis_rule[i*6+5]: index of the axis atom defining the second axis. Note that the neighbors with the same class and type are sorted according to their relative distance."
160160

161161
return [
162-
Argument("sel_a", list, optional=False, doc=doc_sel_a),
163-
Argument("sel_r", list, optional=False, doc=doc_sel_r),
162+
Argument("sel_a", List[int], optional=False, doc=doc_sel_a),
163+
Argument("sel_r", List[int], optional=False, doc=doc_sel_r),
164164
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
165-
Argument("axis_rule", list, optional=False, doc=doc_axis_rule),
165+
Argument("axis_rule", List[int], optional=False, doc=doc_axis_rule),
166166
]
167167

168168

@@ -185,10 +185,12 @@ def descrpt_se_a_args():
185185
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
186186

187187
return [
188-
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
188+
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
189189
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
190190
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
191-
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
191+
Argument(
192+
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
193+
),
192194
Argument(
193195
"axis_neuron",
194196
int,
@@ -212,7 +214,11 @@ def descrpt_se_a_args():
212214
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
213215
Argument("seed", [int, None], optional=True, doc=doc_seed),
214216
Argument(
215-
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
217+
"exclude_types",
218+
List[List[int]],
219+
optional=True,
220+
default=[],
221+
doc=doc_exclude_types,
216222
),
217223
Argument(
218224
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
@@ -236,10 +242,12 @@ def descrpt_se_t_args():
236242
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
237243

238244
return [
239-
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
245+
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
240246
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
241247
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
242-
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
248+
Argument(
249+
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
250+
),
243251
Argument(
244252
"activation_function",
245253
str,
@@ -289,10 +297,12 @@ def descrpt_se_r_args():
289297
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"
290298

291299
return [
292-
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
300+
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
293301
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
294302
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
295-
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
303+
Argument(
304+
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
305+
),
296306
Argument(
297307
"activation_function",
298308
str,
@@ -308,7 +318,11 @@ def descrpt_se_r_args():
308318
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
309319
Argument("seed", [int, None], optional=True, doc=doc_seed),
310320
Argument(
311-
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
321+
"exclude_types",
322+
List[List[int]],
323+
optional=True,
324+
default=[],
325+
doc=doc_exclude_types,
312326
),
313327
Argument(
314328
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
@@ -356,10 +370,14 @@ def descrpt_se_atten_common_args():
356370
doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"
357371

358372
return [
359-
Argument("sel", [int, list, str], optional=True, default="auto", doc=doc_sel),
373+
Argument(
374+
"sel", [int, List[int], str], optional=True, default="auto", doc=doc_sel
375+
),
360376
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
361377
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
362-
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
378+
Argument(
379+
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
380+
),
363381
Argument(
364382
"axis_neuron",
365383
int,
@@ -383,7 +401,11 @@ def descrpt_se_atten_common_args():
383401
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
384402
Argument("seed", [int, None], optional=True, doc=doc_seed),
385403
Argument(
386-
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
404+
"exclude_types",
405+
List[List[int]],
406+
optional=True,
407+
default=[],
408+
doc=doc_exclude_types,
387409
),
388410
Argument("attn", int, optional=True, default=128, doc=doc_attn),
389411
Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
@@ -454,8 +476,10 @@ def descrpt_se_a_mask_args():
454476
doc_seed = "Random seed for parameter initialization"
455477

456478
return [
457-
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
458-
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
479+
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
480+
Argument(
481+
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
482+
),
459483
Argument(
460484
"axis_neuron",
461485
int,
@@ -476,7 +500,11 @@ def descrpt_se_a_mask_args():
476500
"type_one_side", bool, optional=True, default=False, doc=doc_type_one_side
477501
),
478502
Argument(
479-
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
503+
"exclude_types",
504+
List[List[int]],
505+
optional=True,
506+
default=[],
507+
doc=doc_exclude_types,
480508
),
481509
Argument("precision", str, optional=True, default="default", doc=doc_precision),
482510
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
@@ -525,7 +553,7 @@ def fitting_ener():
525553
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
526554
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
527555
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
528-
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of tihs list should be equal to len(`neuron`)+1."
556+
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
529557
doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
530558
doc_seed = "Random seed for parameter initialization of the fitting net"
531559
doc_atom_ener = "Specify the atomic energy in vacuum for each type"
@@ -547,7 +575,7 @@ def fitting_ener():
547575
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
548576
Argument(
549577
"neuron",
550-
list,
578+
List[int],
551579
optional=True,
552580
default=[120, 120, 120],
553581
alias=["n_neuron"],
@@ -563,14 +591,24 @@ def fitting_ener():
563591
Argument("precision", str, optional=True, default="default", doc=doc_precision),
564592
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
565593
Argument(
566-
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
594+
"trainable",
595+
[List[bool], bool],
596+
optional=True,
597+
default=True,
598+
doc=doc_trainable,
567599
),
568600
Argument(
569601
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
570602
),
571603
Argument("seed", [int, None], optional=True, doc=doc_seed),
572-
Argument("atom_ener", list, optional=True, default=[], doc=doc_atom_ener),
573-
Argument("layer_name", list, optional=True, doc=doc_layer_name),
604+
Argument(
605+
"atom_ener",
606+
List[Optional[float]],
607+
optional=True,
608+
default=[],
609+
doc=doc_atom_ener,
610+
),
611+
Argument("layer_name", List[str], optional=True, doc=doc_layer_name),
574612
Argument(
575613
"use_aparam_as_mask",
576614
bool,
@@ -602,7 +640,7 @@ def fitting_dos():
602640
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
603641
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
604642
Argument(
605-
"neuron", list, optional=True, default=[120, 120, 120], doc=doc_neuron
643+
"neuron", List[int], optional=True, default=[120, 120, 120], doc=doc_neuron
606644
),
607645
Argument(
608646
"activation_function",
@@ -614,7 +652,11 @@ def fitting_dos():
614652
Argument("precision", str, optional=True, default="float64", doc=doc_precision),
615653
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
616654
Argument(
617-
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
655+
"trainable",
656+
[List[bool], bool],
657+
optional=True,
658+
default=True,
659+
doc=doc_trainable,
618660
),
619661
Argument(
620662
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
@@ -642,7 +684,7 @@ def fitting_polar():
642684
return [
643685
Argument(
644686
"neuron",
645-
list,
687+
List[int],
646688
optional=True,
647689
default=[120, 120, 120],
648690
alias=["n_neuron"],
@@ -658,12 +700,14 @@ def fitting_polar():
658700
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
659701
Argument("precision", str, optional=True, default="default", doc=doc_precision),
660702
Argument("fit_diag", bool, optional=True, default=True, doc=doc_fit_diag),
661-
Argument("scale", [list, float], optional=True, default=1.0, doc=doc_scale),
703+
Argument(
704+
"scale", [List[float], float], optional=True, default=1.0, doc=doc_scale
705+
),
662706
# Argument("diag_shift", [list,float], optional = True, default = 0.0, doc = doc_diag_shift),
663707
Argument("shift_diag", bool, optional=True, default=True, doc=doc_shift_diag),
664708
Argument(
665709
"sel_type",
666-
[list, int, None],
710+
[List[int], int, None],
667711
optional=True,
668712
alias=["pol_type"],
669713
doc=doc_sel_type,
@@ -687,7 +731,7 @@ def fitting_dipole():
687731
return [
688732
Argument(
689733
"neuron",
690-
list,
734+
List[int],
691735
optional=True,
692736
default=[120, 120, 120],
693737
alias=["n_neuron"],
@@ -704,7 +748,7 @@ def fitting_dipole():
704748
Argument("precision", str, optional=True, default="default", doc=doc_precision),
705749
Argument(
706750
"sel_type",
707-
[list, int, None],
751+
[List[int], int, None],
708752
optional=True,
709753
alias=["dipole_type"],
710754
doc=doc_sel_type,
@@ -740,8 +784,10 @@ def modifier_dipole_charge():
740784

741785
return [
742786
Argument("model_name", str, optional=False, doc=doc_model_name),
743-
Argument("model_charge_map", list, optional=False, doc=doc_model_charge_map),
744-
Argument("sys_charge_map", list, optional=False, doc=doc_sys_charge_map),
787+
Argument(
788+
"model_charge_map", List[float], optional=False, doc=doc_model_charge_map
789+
),
790+
Argument("sys_charge_map", List[float], optional=False, doc=doc_sys_charge_map),
745791
Argument("ewald_beta", float, optional=True, default=0.4, doc=doc_ewald_beta),
746792
Argument("ewald_h", float, optional=True, default=1.0, doc=doc_ewald_h),
747793
]
@@ -770,7 +816,7 @@ def model_compression():
770816

771817
return [
772818
Argument("model_file", str, optional=False, doc=doc_model_file),
773-
Argument("table_config", list, optional=False, doc=doc_table_config),
819+
Argument("table_config", List[float], optional=False, doc=doc_table_config),
774820
Argument("min_nbor_dist", float, optional=False, doc=doc_min_nbor_dist),
775821
]
776822

@@ -814,7 +860,7 @@ def model_args(exclude_hybrid=False):
814860
"model",
815861
dict,
816862
[
817-
Argument("type_map", list, optional=True, doc=doc_type_map),
863+
Argument("type_map", List[str], optional=True, doc=doc_type_map),
818864
Argument(
819865
"data_stat_nbatch",
820866
int,
@@ -1456,11 +1502,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
14561502
)
14571503

14581504
args = [
1459-
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
1505+
Argument(
1506+
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
1507+
),
14601508
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
14611509
Argument(
14621510
"batch_size",
1463-
[list, int, str],
1511+
[List[int], int, str],
14641512
optional=True,
14651513
default="auto",
14661514
doc=doc_batch_size,
@@ -1477,7 +1525,7 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
14771525
),
14781526
Argument(
14791527
"sys_probs",
1480-
list,
1528+
List[float],
14811529
optional=True,
14821530
default=None,
14831531
doc=doc_sys_probs,
@@ -1521,11 +1569,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
15211569
doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period."
15221570

15231571
args = [
1524-
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
1572+
Argument(
1573+
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
1574+
),
15251575
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
15261576
Argument(
15271577
"batch_size",
1528-
[list, int, str],
1578+
[List[int], int, str],
15291579
optional=True,
15301580
default="auto",
15311581
doc=doc_batch_size,
@@ -1542,7 +1592,7 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
15421592
),
15431593
Argument(
15441594
"sys_probs",
1545-
list,
1595+
List[float],
15461596
optional=True,
15471597
default=None,
15481598
doc=doc_sys_probs,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
'numpy',
3939
'scipy',
4040
'pyyaml',
41-
'dargs >= 0.3.5',
41+
'dargs >= 0.4.1',
4242
'python-hostlist >= 1.21',
4343
'typing_extensions; python_version < "3.8"',
4444
'importlib_metadata>=1.4; python_version < "3.8"',

0 commit comments

Comments
 (0)