Skip to content

Commit 4c8600a

Browse files
authored
feat: add plugin mode for data modifier (#4621)
- add plugin mode for data modifier
- adapt existing data modifier (dipole_charge in TF backend)
- adapt unit tests for data modifier

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced a unified modifiers framework that supports dynamic instantiation and enhanced dipole charge modifications, improving data handling and model evaluation.
  - Added new classes `BaseModifier`, `DipoleChargeModifier`, and `DescrptDPA3` for better management of modifier functionalities and descriptor computations.
  - Implemented new activation function "silu" across various components, enhancing flexibility in neural network configurations.
  - Added new model configuration `model_dpa3` for testing purposes.
- **Refactor**
  - Updated the training pipeline to adopt the generic modifier approach, replacing previous type-specific logic.
  - Restructured module organization for better clarity in imports and enhanced argument handling for modifiers, transitioning to a more dynamic retrieval process.
- **Tests**
  - Improved clarity in test configurations by switching to explicitly named parameters for modifier setup.
  - Updated import paths for `DipoleChargeModifier` in test files to reflect new module organization.
  - Expanded testing coverage for the `DescrptDPA3` descriptor and added new test classes to validate functionality across different configurations.
  - Introduced new test class `TestDataMixType` to validate functionality of the `DeepmdData` class with various type maps.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5dd215c commit 4c8600a

File tree

13 files changed

+192
-43
lines changed

13 files changed

+192
-43
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_modifier import (
3+
make_base_modifier,
4+
)
5+
6+
__all__ = [
7+
"make_base_modifier",
8+
]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import inspect
3+
from abc import (
4+
ABC,
5+
abstractmethod,
6+
)
7+
8+
from deepmd.utils.plugin import (
9+
PluginVariant,
10+
make_plugin_registry,
11+
)
12+
13+
14+
def make_base_modifier() -> type[object]:
    """Build the abstract ``BaseModifier`` class.

    The class is defined inside this function, so every call returns a newly
    created class object that mixes in ``ABC``, ``PluginVariant`` and the
    result of ``make_plugin_registry("modifier")``.

    Returns
    -------
    type[object]
        The ``BaseModifier`` class.
    """

    class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")):
        """Base class for data modifier."""

        def __new__(cls, *args, **kwargs):
            """Allocate an instance, dispatching on ``type`` for the base class.

            When ``BaseModifier`` itself is instantiated, the concrete class is
            looked up with ``get_class_by_type(kwargs["type"])``; subclasses
            are allocated as they are.

            Raises
            ------
            KeyError
                If ``BaseModifier`` is instantiated directly without a
                ``type`` keyword argument.
            """
            if cls is BaseModifier:
                try:
                    modifier_type = kwargs["type"]
                except KeyError:
                    # Same exception type as the plain lookup, but with a
                    # message that tells the user what is missing.
                    raise KeyError(
                        "the type of the modifier should be set by `type`"
                    ) from None
                cls = cls.get_class_by_type(modifier_type)
            return super().__new__(cls)

        @abstractmethod
        def serialize(self) -> dict:
            """Serialize the modifier.

            Returns
            -------
            dict
                The serialized data
            """
            pass

        @classmethod
        def deserialize(cls, data: dict) -> "BaseModifier":
            """Deserialize the modifier.

            If ``cls`` is abstract, the call is forwarded to the
            ``deserialize`` of the class found by
            ``get_class_by_type(data["type"])``. A non-abstract class that
            reaches this implementation has not provided its own
            ``deserialize``.

            Parameters
            ----------
            data : dict
                The serialized data

            Returns
            -------
            BaseModifier
                The deserialized modifier

            Raises
            ------
            NotImplementedError
                If ``cls`` is not abstract and does not override this method.
            """
            if inspect.isabstract(cls):
                return cls.get_class_by_type(data["type"]).deserialize(data)
            raise NotImplementedError(f"Not implemented in class {cls.__name__}")

        @classmethod
        def get_modifier(cls, modifier_params: dict) -> "BaseModifier":
            """Get the modifier by the parameters.

            By default, all the parameters are directly passed to the constructor.
            If not, override this method.

            Parameters
            ----------
            modifier_params : dict
                The modifier parameters

            Returns
            -------
            BaseModifier
                The modifier
            """
            # Work on a shallow copy so the caller's dict keeps its "type" key.
            modifier_params = modifier_params.copy()
            modifier_params.pop("type", None)
            modifier = cls(**modifier_params)
            return modifier

    return BaseModifier

deepmd/tf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
DeepEval,
2222
DeepPotential,
2323
)
24-
from .infer.data_modifier import (
24+
from .modifier import (
2525
DipoleChargeModifier,
2626
)
2727

deepmd/tf/entrypoints/train.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Can handle local or distributed training.
55
"""
66

7+
import copy
78
import json
89
import logging
910
import time
@@ -20,12 +21,12 @@
2021
reset_default_tf_session_config,
2122
tf,
2223
)
23-
from deepmd.tf.infer.data_modifier import (
24-
DipoleChargeModifier,
25-
)
2624
from deepmd.tf.model.model import (
2725
Model,
2826
)
27+
from deepmd.tf.modifier import (
28+
BaseModifier,
29+
)
2930
from deepmd.tf.train.run_options import (
3031
RunOptions,
3132
)
@@ -275,18 +276,13 @@ def _do_work(
275276

276277

277278
def get_modifier(modi_data=None):
278-
modifier: Optional[DipoleChargeModifier]
279+
modifier: Optional[BaseModifier]
279280
if modi_data is not None:
280-
if modi_data["type"] == "dipole_charge":
281-
modifier = DipoleChargeModifier(
282-
modi_data["model_name"],
283-
modi_data["model_charge_map"],
284-
modi_data["sys_charge_map"],
285-
modi_data["ewald_h"],
286-
modi_data["ewald_beta"],
287-
)
288-
else:
289-
raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
281+
modifier_params = copy.deepcopy(modi_data)
282+
modifier_type = modifier_params.pop("type")
283+
modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier(
284+
modifier_params
285+
)
290286
else:
291287
modifier = None
292288
return modifier

deepmd/tf/infer/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
DeepEval,
99
)
1010

11-
from .data_modifier import (
12-
DipoleChargeModifier,
13-
)
1411
from .deep_dipole import (
1512
DeepDipole,
1613
)
@@ -43,7 +40,6 @@
4340
"DeepPot",
4441
"DeepPotential",
4542
"DeepWFC",
46-
"DipoleChargeModifier",
4743
"EwaldRecp",
4844
"calc_model_devi",
4945
]

deepmd/tf/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __init__(
139139

140140
# looks ugly...
141141
if self.modifier_type == "dipole_charge":
142-
from deepmd.tf.infer.data_modifier import (
142+
from deepmd.tf.modifier import (
143143
DipoleChargeModifier,
144144
)
145145

deepmd/tf/modifier/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_modifier import (
3+
BaseModifier,
4+
)
5+
from .dipole_charge import (
6+
DipoleChargeModifier,
7+
)
8+
9+
__all__ = [
10+
"BaseModifier",
11+
"DipoleChargeModifier",
12+
]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.modifier.base_modifier import (
3+
make_base_modifier,
4+
)
5+
from deepmd.tf.infer import (
6+
DeepPot,
7+
)
8+
9+
10+
class BaseModifier(DeepPot, make_base_modifier()):
    """Modifier base class combining ``DeepPot`` with the class returned by
    ``make_base_modifier()``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the ``DeepPot`` part of the object.

        All positional and keyword arguments are handed to
        ``DeepPot.__init__`` unchanged.
        """
        # DeepPot.__init__ is named explicitly (not reached through super())
        # so the call target does not depend on a subclass's MRO.
        DeepPot.__init__(self, *args, **kwargs)
Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from deepmd.tf.infer.ewald_recp import (
1818
EwaldRecp,
1919
)
20+
from deepmd.tf.modifier.base_modifier import (
21+
BaseModifier,
22+
)
2023
from deepmd.tf.utils.data import (
2124
DeepmdData,
2225
)
@@ -25,7 +28,8 @@
2528
)
2629

2730

28-
class DipoleChargeModifier(DeepDipole):
31+
@BaseModifier.register("dipole_charge")
32+
class DipoleChargeModifier(DeepDipole, BaseModifier):
2933
"""Parameters
3034
----------
3135
model_name
@@ -40,6 +44,9 @@ class DipoleChargeModifier(DeepDipole):
4044
Splitting parameter of the Ewald sum. Unit: A^{-1}
4145
"""
4246

47+
def __new__(cls, *args, model_name=None, **kwargs):
    """Allocate the instance, forwarding only ``model_name`` to the parent ``__new__``.

    Positional arguments and all other keyword arguments are accepted but
    not passed on.
    """
    # NOTE(review): a model name given positionally lands in *args, so the
    # parent __new__ receives None in that case — confirm this is intended.
    instance = super().__new__(cls, model_name)
    return instance
49+
4350
def __init__(
4451
self,
4552
model_name: str,
@@ -82,6 +89,44 @@ def __init__(
8289
self.force = None
8390
self.ntypes = len(self.sel_a)
8491

92+
def serialize(self) -> dict:
    """Export the modifier's parameters as a plain dictionary.

    Returns
    -------
    dict
        The metadata entries ``@class``, ``type`` and ``@version`` followed
        by ``model_name``, ``model_charge_map``, ``sys_charge_map``,
        ``ewald_h`` and ``ewald_beta``, each read from the instance
        attribute of the same name.
    """
    # Metadata entries come first.
    serialized = {
        "@class": "Modifier",
        "type": self.modifier_prefix,
        "@version": 3,
    }
    # The remaining entries mirror instance attributes one-to-one.
    for field in (
        "model_name",
        "model_charge_map",
        "sys_charge_map",
        "ewald_h",
        "ewald_beta",
    ):
        serialized[field] = getattr(self, field)
    return serialized
111+
112+
@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
    """Rebuild a modifier from a dictionary such as the one ``serialize`` returns.

    The bookkeeping entries ``@class``, ``type`` and ``@version`` are
    dropped; every remaining entry is passed to the constructor as a
    keyword argument. The input dictionary is not modified.

    Parameters
    ----------
    data : dict
        The serialized data

    Returns
    -------
    BaseModifier
        The deserialized modifier
    """
    params = data.copy()
    # serialize() adds these entries for bookkeeping. Forwarding them as
    # keyword arguments would break the round trip unless the constructor
    # accepted them (NOTE(review): __init__ is assumed to take only the
    # model/charge/Ewald parameters — confirm).
    for meta_key in ("@class", "type", "@version"):
        params.pop(meta_key, None)
    return cls(**params)
129+
85130
def build_fv_graph(self) -> tf.Tensor:
86131
"""Build the computational graph for the force and virial inference."""
87132
with tf.variable_scope("modifier_attr"):

deepmd/utils/argcheck.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter."
5656
doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter."
5757
# Help text for the `polar` fitting type, shown in the generated argument documentation.
doc_polar = "Fit an atomic polarizability model. Global polarizability labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file either has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter."
58+
# modifier
59+
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."
5860

5961

6062
def list_to_doc(xx):
@@ -2015,6 +2017,10 @@ def fitting_variant_type_args():
20152017

20162018

20172019
# --- Modifier configurations: --- #
2020+
modifier_args_plugin = ArgsPlugin()
2021+
2022+
2023+
@modifier_args_plugin.register("dipole_charge", doc=doc_dipole_charge)
20182024
def modifier_dipole_charge():
20192025
doc_model_name = "The name of the frozen dipole model file."
20202026
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model[standard]/fitting_net[dipole]/sel_type')}. "
@@ -2035,14 +2041,9 @@ def modifier_dipole_charge():
20352041

20362042
def modifier_variant_type_args():
20372043
doc_modifier_type = "The type of modifier."
2038-
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."
20392044
return Variant(
20402045
"type",
2041-
[
2042-
Argument(
2043-
"dipole_charge", dict, modifier_dipole_charge(), doc=doc_dipole_charge
2044-
),
2045-
],
2046+
modifier_args_plugin.get_all_argument(),
20462047
optional=False,
20472048
doc=doc_modifier_type,
20482049
)

0 commit comments

Comments
 (0)