Skip to content

Commit a4e30cc

Browse files
Chore: refactor dpmodel model (#4296)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced several new atomic models: DPDipoleAtomicModel, DPDOSAtomicModel, DPEnergyAtomicModel, and DPPolarAtomicModel. - Enhanced the public API to include these new models for easier access. - Added new model classes: DipoleModel, PolarModel, and DOSModel, extending the functionality of the framework. - Implemented unit tests for Density of States (DOS) models, ensuring functionality across different computational backends. - **Bug Fixes** - Improved error handling in atomic model constructors to provide clearer messages for type requirements. - **Documentation** - Added a license identifier to the dos_model.py file. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 447ea94 commit a4e30cc

20 files changed

+561
-50
lines changed

deepmd/dpmodel/atomic_model/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717
from .base_atomic_model import (
1818
BaseAtomicModel,
1919
)
20+
from .dipole_atomic_model import (
21+
DPDipoleAtomicModel,
22+
)
23+
from .dos_atomic_model import (
24+
DPDOSAtomicModel,
25+
)
2026
from .dp_atomic_model import (
2127
DPAtomicModel,
2228
)
29+
from .energy_atomic_model import (
30+
DPEnergyAtomicModel,
31+
)
2332
from .linear_atomic_model import (
2433
DPZBLLinearEnergyAtomicModel,
2534
LinearEnergyAtomicModel,
@@ -30,12 +39,19 @@
3039
from .pairtab_atomic_model import (
3140
PairTabAtomicModel,
3241
)
42+
from .polar_atomic_model import (
43+
DPPolarAtomicModel,
44+
)
3345

3446
__all__ = [
3547
"make_base_atomic_model",
3648
"BaseAtomicModel",
3749
"DPAtomicModel",
50+
"DPEnergyAtomicModel",
3851
"PairTabAtomicModel",
3952
"LinearEnergyAtomicModel",
4053
"DPZBLLinearEnergyAtomicModel",
54+
"DPDOSAtomicModel",
55+
"DPPolarAtomicModel",
56+
"DPDipoleAtomicModel",
4157
]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import numpy as np
3+
4+
from deepmd.dpmodel.fitting.dipole_fitting import (
5+
DipoleFitting,
6+
)
7+
8+
from .dp_atomic_model import (
9+
DPAtomicModel,
10+
)
11+
12+
13+
class DPDipoleAtomicModel(DPAtomicModel):
14+
def __init__(self, descriptor, fitting, type_map, **kwargs):
15+
if not isinstance(fitting, DipoleFitting):
16+
raise TypeError(
17+
"fitting must be an instance of DipoleFitting for DPDipoleAtomicModel"
18+
)
19+
super().__init__(descriptor, fitting, type_map, **kwargs)
20+
21+
def apply_out_stat(
22+
self,
23+
ret: dict[str, np.ndarray],
24+
atype: np.ndarray,
25+
):
26+
# dipole not applying bias
27+
return ret
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.fitting.dos_fitting import (
3+
DOSFittingNet,
4+
)
5+
6+
from .dp_atomic_model import (
7+
DPAtomicModel,
8+
)
9+
10+
11+
class DPDOSAtomicModel(DPAtomicModel):
12+
def __init__(self, descriptor, fitting, type_map, **kwargs):
13+
if not isinstance(fitting, DOSFittingNet):
14+
raise TypeError(
15+
"fitting must be an instance of DOSFittingNet for DPDOSAtomicModel"
16+
)
17+
super().__init__(descriptor, fitting, type_map, **kwargs)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.fitting.ener_fitting import (
3+
EnergyFittingNet,
4+
InvarFitting,
5+
)
6+
7+
from .dp_atomic_model import (
8+
DPAtomicModel,
9+
)
10+
11+
12+
class DPEnergyAtomicModel(DPAtomicModel):
13+
def __init__(self, descriptor, fitting, type_map, **kwargs):
14+
if not (
15+
isinstance(fitting, EnergyFittingNet) or isinstance(fitting, InvarFitting)
16+
):
17+
raise TypeError(
18+
"fitting must be an instance of EnergyFittingNet or InvarFitting for DPEnergyAtomicModel"
19+
)
20+
super().__init__(descriptor, fitting, type_map, **kwargs)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
import numpy as np
4+
5+
from deepmd.dpmodel.fitting.polarizability_fitting import (
6+
PolarFitting,
7+
)
8+
9+
from .dp_atomic_model import (
10+
DPAtomicModel,
11+
)
12+
13+
14+
class DPPolarAtomicModel(DPAtomicModel):
15+
def __init__(self, descriptor, fitting, type_map, **kwargs):
16+
if not isinstance(fitting, PolarFitting):
17+
raise TypeError(
18+
"fitting must be an instance of PolarFitting for DPPolarAtomicModel"
19+
)
20+
super().__init__(descriptor, fitting, type_map, **kwargs)
21+
22+
def apply_out_stat(
23+
self,
24+
ret: dict[str, np.ndarray],
25+
atype: np.ndarray,
26+
):
27+
"""Apply the stat to each atomic output.
28+
29+
Parameters
30+
----------
31+
ret
32+
The returned dict by the forward_atomic method
33+
atype
34+
The atom types. nf x nloc
35+
36+
"""
37+
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
38+
39+
if self.fitting_net.shift_diag:
40+
nframes, nloc = atype.shape
41+
dtype = out_bias[self.bias_keys[0]].dtype
42+
for kk in self.bias_keys:
43+
ntypes = out_bias[kk].shape[0]
44+
temp = np.zeros(ntypes, dtype=dtype)
45+
temp = np.mean(
46+
np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
47+
axis=1,
48+
)
49+
modified_bias = temp[atype]
50+
51+
# (nframes, nloc, 1)
52+
modified_bias = (
53+
modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype])
54+
)
55+
56+
eye = np.eye(3, dtype=dtype)
57+
eye = np.tile(eye, (nframes, nloc, 1, 1))
58+
# (nframes, nloc, 3, 3)
59+
modified_bias = modified_bias[..., np.newaxis] * eye
60+
61+
# nf x nloc x odims, out_bias: ntypes x odims
62+
ret[kk] = ret[kk] + modified_bias
63+
return ret

deepmd/dpmodel/atomic_model/property_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010

1111
class DPPropertyAtomicModel(DPAtomicModel):
12-
def __init__(self, descriptor, fitting, type_map, **kwargs) -> None:
13-
assert isinstance(fitting, PropertyFittingNet)
12+
def __init__(self, descriptor, fitting, type_map, **kwargs):
13+
if not isinstance(fitting, PropertyFittingNet):
14+
raise TypeError(
15+
"fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel"
16+
)
1417
super().__init__(descriptor, fitting, type_map, **kwargs)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
4+
from deepmd.dpmodel.atomic_model import (
5+
DPDipoleAtomicModel,
6+
)
7+
from deepmd.dpmodel.model.base_model import (
8+
BaseModel,
9+
)
10+
11+
from .dp_model import (
12+
DPModelCommon,
13+
)
14+
from .make_model import (
15+
make_model,
16+
)
17+
18+
DPDipoleModel_ = make_model(DPDipoleAtomicModel)
19+
20+
21+
@BaseModel.register("dipole")
22+
class DipoleModel(DPModelCommon, DPDipoleModel_):
23+
model_type = "dipole"
24+
25+
def __init__(
26+
self,
27+
*args,
28+
**kwargs,
29+
):
30+
DPModelCommon.__init__(self)
31+
DPDipoleModel_.__init__(self, *args, **kwargs)

deepmd/dpmodel/model/dos_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
from deepmd.dpmodel.atomic_model import (
4+
DPDOSAtomicModel,
5+
)
6+
from deepmd.dpmodel.model.base_model import (
7+
BaseModel,
8+
)
9+
10+
from .dp_model import (
11+
DPModelCommon,
12+
)
13+
from .make_model import (
14+
make_model,
15+
)
16+
17+
DPDOSModel_ = make_model(DPDOSAtomicModel)
18+
19+
20+
@BaseModel.register("dos")
21+
class DOSModel(DPModelCommon, DPDOSModel_):
22+
model_type = "dos"
23+
24+
def __init__(
25+
self,
26+
*args,
27+
**kwargs,
28+
):
29+
DPModelCommon.__init__(self)
30+
DPDOSModel_.__init__(self, *args, **kwargs)

deepmd/dpmodel/model/ener_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
3-
DPAtomicModel,
2+
from deepmd.dpmodel.atomic_model import (
3+
DPEnergyAtomicModel,
44
)
55
from deepmd.dpmodel.model.base_model import (
66
BaseModel,
@@ -13,7 +13,7 @@
1313
make_model,
1414
)
1515

16-
DPEnergyModel_ = make_model(DPAtomicModel)
16+
DPEnergyModel_ = make_model(DPEnergyAtomicModel)
1717

1818

1919
@BaseModel.register("ener")

deepmd/dpmodel/model/model.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import copy
3+
24
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
35
DPAtomicModel,
46
)
@@ -8,18 +10,33 @@
810
from deepmd.dpmodel.descriptor.base_descriptor import (
911
BaseDescriptor,
1012
)
13+
from deepmd.dpmodel.fitting.base_fitting import (
14+
BaseFitting,
15+
)
1116
from deepmd.dpmodel.fitting.ener_fitting import (
1217
EnergyFittingNet,
1318
)
1419
from deepmd.dpmodel.model.base_model import (
1520
BaseModel,
1621
)
22+
from deepmd.dpmodel.model.dipole_model import (
23+
DipoleModel,
24+
)
25+
from deepmd.dpmodel.model.dos_model import (
26+
DOSModel,
27+
)
1728
from deepmd.dpmodel.model.dp_zbl_model import (
1829
DPZBLModel,
1930
)
2031
from deepmd.dpmodel.model.ener_model import (
2132
EnergyModel,
2233
)
34+
from deepmd.dpmodel.model.polar_model import (
35+
PolarModel,
36+
)
37+
from deepmd.dpmodel.model.property_model import (
38+
PropertyModel,
39+
)
2340
from deepmd.dpmodel.model.spin_model import (
2441
SpinModel,
2542
)
@@ -28,6 +45,29 @@
2845
)
2946

3047

48+
def _get_standard_model_components(data, ntypes):
49+
# descriptor
50+
data["descriptor"]["ntypes"] = ntypes
51+
data["descriptor"]["type_map"] = copy.deepcopy(data["type_map"])
52+
descriptor = BaseDescriptor(**data["descriptor"])
53+
# fitting
54+
fitting_net = data.get("fitting_net", {})
55+
fitting_net["type"] = fitting_net.get("type", "ener")
56+
fitting_net["ntypes"] = descriptor.get_ntypes()
57+
fitting_net["type_map"] = copy.deepcopy(data["type_map"])
58+
fitting_net["mixed_types"] = descriptor.mixed_types()
59+
if fitting_net["type"] in ["dipole", "polar"]:
60+
fitting_net["embedding_width"] = descriptor.get_dim_emb()
61+
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
62+
grad_force = "direct" not in fitting_net["type"]
63+
if not grad_force:
64+
fitting_net["out_dim"] = descriptor.get_dim_emb()
65+
if "ener" in fitting_net["type"]:
66+
fitting_net["return_energy"] = True
67+
fitting = BaseFitting(**fitting_net)
68+
return descriptor, fitting, fitting_net["type"]
69+
70+
3171
def get_standard_model(data: dict) -> EnergyModel:
3272
"""Get a EnergyModel from a dictionary.
3373
@@ -40,29 +80,33 @@ def get_standard_model(data: dict) -> EnergyModel:
4080
raise ValueError(
4181
"In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details."
4282
)
43-
data["descriptor"]["type_map"] = data["type_map"]
44-
data["descriptor"]["ntypes"] = len(data["type_map"])
45-
fitting_type = data["fitting_net"].pop("type")
46-
data["fitting_net"]["type_map"] = data["type_map"]
47-
descriptor = BaseDescriptor(
48-
**data["descriptor"],
49-
)
50-
if fitting_type == "ener":
51-
fitting = EnergyFittingNet(
52-
ntypes=descriptor.get_ntypes(),
53-
dim_descrpt=descriptor.get_dim_out(),
54-
mixed_types=descriptor.mixed_types(),
55-
**data["fitting_net"],
56-
)
83+
data = copy.deepcopy(data)
84+
ntypes = len(data["type_map"])
85+
descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes)
86+
atom_exclude_types = data.get("atom_exclude_types", [])
87+
pair_exclude_types = data.get("pair_exclude_types", [])
88+
89+
if fitting_net_type == "dipole":
90+
modelcls = DipoleModel
91+
elif fitting_net_type == "polar":
92+
modelcls = PolarModel
93+
elif fitting_net_type == "dos":
94+
modelcls = DOSModel
95+
elif fitting_net_type in ["ener", "direct_force_ener"]:
96+
modelcls = EnergyModel
97+
elif fitting_net_type == "property":
98+
modelcls = PropertyModel
5799
else:
58-
raise ValueError(f"Unknown fitting type {fitting_type}")
59-
return EnergyModel(
100+
raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")
101+
102+
model = modelcls(
60103
descriptor=descriptor,
61104
fitting=fitting,
62105
type_map=data["type_map"],
63-
atom_exclude_types=data.get("atom_exclude_types", []),
64-
pair_exclude_types=data.get("pair_exclude_types", []),
106+
atom_exclude_types=atom_exclude_types,
107+
pair_exclude_types=pair_exclude_types,
65108
)
109+
return model
66110

67111

68112
def get_zbl_model(data: dict) -> DPZBLModel:

0 commit comments

Comments
 (0)