Skip to content

Commit 2ce3276

Browse files
njzjzCopilot
andauthored
fix: add numb_fparam & numb_aparam to dipole & polar fitting (#4405)
Fix #4396. Fix #4397. Fix #4398. Throw errors for TF. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced new parameters `numb_fparam` and `numb_aparam` for improved fitting configurations in both dipole and polar fitting classes. - Added methods to retrieve the values of the new parameters and enhanced input requirement management. - **Documentation** - Updated training documentation to clarify the handling of new parameters and their limitations in the TensorFlow backend. - **Bug Fixes** - Updated test configurations to reflect the new parameter structure, ensuring consistency across tests for dipole and polar models. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent d38c398 commit 2ce3276

File tree

7 files changed

+134
-6
lines changed

7 files changed

+134
-6
lines changed

deepmd/tf/fit/dipole.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
one_layer,
3030
one_layer_rand_seed_shift,
3131
)
32+
from deepmd.utils.data import (
33+
DataRequirementItem,
34+
)
3235
from deepmd.utils.version import (
3336
check_version_compatibility,
3437
)
@@ -51,6 +54,10 @@ class DipoleFittingSeA(Fitting):
5154
resnet_dt : bool
5255
Time-step `dt` in the resnet construction:
5356
y = x + dt * \phi (Wx + b)
57+
numb_fparam
58+
Number of frame parameters
59+
numb_aparam
60+
Number of atomic parameters
5461
sel_type : list[int]
5562
The atom types selected to have an atomic dipole prediction. If is None, all atoms are selected.
5663
seed : int
@@ -75,6 +82,8 @@ def __init__(
7582
embedding_width: int,
7683
neuron: list[int] = [120, 120, 120],
7784
resnet_dt: bool = True,
85+
numb_fparam: int = 0,
86+
numb_aparam: int = 0,
7887
sel_type: Optional[list[int]] = None,
7988
seed: Optional[int] = None,
8089
activation_function: str = "tanh",
@@ -108,6 +117,18 @@ def __init__(
108117
self.mixed_prec = None
109118
self.mixed_types = mixed_types
110119
self.type_map = type_map
120+
self.numb_fparam = numb_fparam
121+
self.numb_aparam = numb_aparam
122+
if numb_fparam > 0:
123+
raise ValueError("numb_fparam is not supported in the dipole fitting")
124+
if numb_aparam > 0:
125+
raise ValueError("numb_aparam is not supported in the dipole fitting")
126+
self.fparam_avg = None
127+
self.fparam_std = None
128+
self.fparam_inv_std = None
129+
self.aparam_avg = None
130+
self.aparam_std = None
131+
self.aparam_inv_std = None
111132

112133
def get_sel_type(self) -> int:
113134
"""Get selected type."""
@@ -372,6 +393,8 @@ def serialize(self, suffix: str) -> dict:
372393
"dim_out": 3,
373394
"neuron": self.n_neuron,
374395
"resnet_dt": self.resnet_dt,
396+
"numb_fparam": self.numb_fparam,
397+
"numb_aparam": self.numb_aparam,
375398
"activation_function": self.activation_function_name,
376399
"precision": self.fitting_precision.name,
377400
"exclude_types": [],
@@ -412,3 +435,29 @@ def deserialize(cls, data: dict, suffix: str):
412435
suffix=suffix,
413436
)
414437
return fitting
438+
439+
@property
440+
def input_requirement(self) -> list[DataRequirementItem]:
441+
"""Return data requirements needed for the model input."""
442+
data_requirement = []
443+
if self.numb_fparam > 0:
444+
data_requirement.append(
445+
DataRequirementItem(
446+
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
447+
)
448+
)
449+
if self.numb_aparam > 0:
450+
data_requirement.append(
451+
DataRequirementItem(
452+
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
453+
)
454+
)
455+
return data_requirement
456+
457+
def get_numb_fparam(self) -> int:
458+
"""Get the number of frame parameters."""
459+
return self.numb_fparam
460+
461+
def get_numb_aparam(self) -> int:
462+
"""Get the number of atomic parameters."""
463+
return self.numb_aparam

deepmd/tf/fit/polar.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
one_layer,
3535
one_layer_rand_seed_shift,
3636
)
37+
from deepmd.utils.data import (
38+
DataRequirementItem,
39+
)
3740
from deepmd.utils.version import (
3841
check_version_compatibility,
3942
)
@@ -56,6 +59,10 @@ class PolarFittingSeA(Fitting):
5659
resnet_dt : bool
5760
Time-step `dt` in the resnet construction:
5861
y = x + dt * \phi (Wx + b)
62+
numb_fparam
63+
Number of frame parameters
64+
numb_aparam
65+
Number of atomic parameters
5966
sel_type : list[int]
6067
The atom types selected to have an atomic polarizability prediction. If is None, all atoms are selected.
6168
fit_diag : bool
@@ -86,6 +93,8 @@ def __init__(
8693
embedding_width: int,
8794
neuron: list[int] = [120, 120, 120],
8895
resnet_dt: bool = True,
96+
numb_fparam: int = 0,
97+
numb_aparam: int = 0,
8998
sel_type: Optional[list[int]] = None,
9099
fit_diag: bool = True,
91100
scale: Optional[list[float]] = None,
@@ -151,6 +160,18 @@ def __init__(
151160
self.mixed_prec = None
152161
self.mixed_types = mixed_types
153162
self.type_map = type_map
163+
self.numb_fparam = numb_fparam
164+
self.numb_aparam = numb_aparam
165+
if numb_fparam > 0:
166+
raise ValueError("numb_fparam is not supported in the dipole fitting")
167+
if numb_aparam > 0:
168+
raise ValueError("numb_aparam is not supported in the dipole fitting")
169+
self.fparam_avg = None
170+
self.fparam_std = None
171+
self.fparam_inv_std = None
172+
self.aparam_avg = None
173+
self.aparam_std = None
174+
self.aparam_inv_std = None
154175

155176
def get_sel_type(self) -> list[int]:
156177
"""Get selected atom types."""
@@ -565,6 +586,8 @@ def serialize(self, suffix: str) -> dict:
565586
"dim_out": 3,
566587
"neuron": self.n_neuron,
567588
"resnet_dt": self.resnet_dt,
589+
"numb_fparam": self.numb_fparam,
590+
"numb_aparam": self.numb_aparam,
568591
"activation_function": self.activation_function_name,
569592
"precision": self.fitting_precision.name,
570593
"exclude_types": [],
@@ -777,3 +800,29 @@ def get_loss(self, loss: dict, lr) -> Loss:
777800
atomic=False,
778801
label_name="polarizability",
779802
)
803+
804+
@property
805+
def input_requirement(self) -> list[DataRequirementItem]:
806+
"""Return data requirements needed for the model input."""
807+
data_requirement = []
808+
if self.numb_fparam > 0:
809+
data_requirement.append(
810+
DataRequirementItem(
811+
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
812+
)
813+
)
814+
if self.numb_aparam > 0:
815+
data_requirement.append(
816+
DataRequirementItem(
817+
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
818+
)
819+
)
820+
return data_requirement
821+
822+
def get_numb_fparam(self) -> int:
823+
"""Get the number of frame parameters."""
824+
return self.numb_fparam
825+
826+
def get_numb_aparam(self) -> int:
827+
"""Get the number of atomic parameters."""
828+
return self.numb_aparam

deepmd/utils/argcheck.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,8 @@ def fitting_property():
15951595

15961596
@fitting_args_plugin.register("polar", doc=doc_polar)
15971597
def fitting_polar():
1598+
doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams."
1599+
doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams."
15981600
doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
15991601
doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
16001602
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
@@ -1609,6 +1611,20 @@ def fitting_polar():
16091611
doc_shift_diag = "Whether to shift the diagonal of polar, which is beneficial to training. Default is true."
16101612

16111613
return [
1614+
Argument(
1615+
"numb_fparam",
1616+
int,
1617+
optional=True,
1618+
default=0,
1619+
doc=doc_only_pt_supported + doc_numb_fparam,
1620+
),
1621+
Argument(
1622+
"numb_aparam",
1623+
int,
1624+
optional=True,
1625+
default=0,
1626+
doc=doc_only_pt_supported + doc_numb_aparam,
1627+
),
16121628
Argument(
16131629
"neuron",
16141630
list[int],
@@ -1649,13 +1665,29 @@ def fitting_polar():
16491665

16501666
@fitting_args_plugin.register("dipole", doc=doc_dipole)
16511667
def fitting_dipole():
1668+
doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams."
1669+
doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams."
16521670
doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
16531671
doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
16541672
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
16551673
doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
16561674
doc_sel_type = "The atom types for which the atomic dipole will be provided. If not set, all types will be selected."
16571675
doc_seed = "Random seed for parameter initialization of the fitting net"
16581676
return [
1677+
Argument(
1678+
"numb_fparam",
1679+
int,
1680+
optional=True,
1681+
default=0,
1682+
doc=doc_only_pt_supported + doc_numb_fparam,
1683+
),
1684+
Argument(
1685+
"numb_aparam",
1686+
int,
1687+
optional=True,
1688+
default=0,
1689+
doc=doc_only_pt_supported + doc_numb_aparam,
1690+
),
16591691
Argument(
16601692
"neuron",
16611693
list[int],

doc/model/train-fitting-tensor.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,4 @@ During training, at each step when the `lcurve.out` is printed, the system used
247247
248248
To only fit against a subset of atomic types, in the TensorFlow backend, {ref}`fitting_net/sel_type <model[standard]/fitting_net[dipole]/sel_type>` should be set to selected types;
249249
in other backends, {ref}`atom_exclude_types <model/atom_exclude_types>` should be set to excluded types.
250+
The TensorFlow backend does not support {ref}`numb_fparam <model[standard]/fitting_net[dipole]/numb_fparam>` and {ref}`numb_aparam <model[standard]/fitting_net[dipole]/numb_aparam>`.

source/tests/consistent/model/test_dipole.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def data(self) -> dict:
6262
"type": "dipole",
6363
"neuron": [4, 4, 4],
6464
"resnet_dt": True,
65-
# TODO: add numb_fparam argument to dipole fitting
66-
"_numb_fparam": 0,
65+
"numb_fparam": 0,
6766
"precision": "float64",
6867
"seed": 1,
6968
},

source/tests/consistent/model/test_polar.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def data(self) -> dict:
6262
"type": "polar",
6363
"neuron": [4, 4, 4],
6464
"resnet_dt": True,
65-
# TODO: add numb_fparam argument to polar fitting
66-
"_numb_fparam": 0,
65+
"numb_fparam": 0,
6766
"precision": "float64",
6867
"seed": 1,
6968
},

source/tests/consistent/model/test_property.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def data(self) -> dict:
5757
"type": "property",
5858
"neuron": [4, 4, 4],
5959
"resnet_dt": True,
60-
# TODO: add numb_fparam argument to property fitting
61-
"_numb_fparam": 0,
60+
"numb_fparam": 0,
6261
"precision": "float64",
6362
"seed": 1,
6463
},

0 commit comments

Comments
 (0)