Skip to content

Commit 134b36a

Browse files
committed
add plugin for data modifier
1 parent 918d4de commit 134b36a

File tree

8 files changed

+620
-25
lines changed

8 files changed

+620
-25
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_modifier import (
3+
make_base_modifier,
4+
)
5+
6+
__all__ = [
7+
"make_base_modifier",
8+
]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import inspect
3+
from abc import (
4+
ABC,
5+
abstractmethod,
6+
)
7+
8+
from deepmd.utils.plugin import (
9+
PluginVariant,
10+
make_plugin_registry,
11+
)
12+
13+
14+
def make_base_modifier() -> type[object]:
15+
class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")):
16+
"""Base class for data modifier."""
17+
18+
def __new__(cls, *args, **kwargs):
19+
if cls is BaseModifier:
20+
cls = cls.get_class_by_type(kwargs["type"])
21+
return super().__new__(cls)
22+
23+
@abstractmethod
24+
def serialize(self) -> dict:
25+
"""Serialize the modifier.
26+
27+
Returns
28+
-------
29+
dict
30+
The serialized data
31+
"""
32+
pass
33+
34+
@classmethod
35+
def deserialize(cls, data: dict) -> "BaseModifier":
36+
"""Deserialize the modifier.
37+
38+
Parameters
39+
----------
40+
data : dict
41+
The serialized data
42+
43+
Returns
44+
-------
45+
BaseModel
46+
The deserialized modifier
47+
"""
48+
if inspect.isabstract(cls):
49+
return cls.get_class_by_type(data["type"]).deserialize(data)
50+
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
51+
52+
@classmethod
53+
def get_modifier(cls, modifier_params: dict) -> "BaseModifier":
54+
"""Get the modifier by the parameters.
55+
56+
By default, all the parameters are directly passed to the constructor.
57+
If not, override this method.
58+
59+
Parameters
60+
----------
61+
modifier_params : dict
62+
The modifier parameters
63+
64+
Returns
65+
-------
66+
BaseModifier
67+
The modifier
68+
"""
69+
modifier_params = modifier_params.copy()
70+
modifier_params.pop("type", None)
71+
modifier = cls(**modifier_params)
72+
return modifier
73+
74+
return BaseModifier

deepmd/tf/entrypoints/train.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Can handle local or distributed training.
55
"""
66

7+
import copy
78
import json
89
import logging
910
import time
@@ -20,12 +21,12 @@
2021
reset_default_tf_session_config,
2122
tf,
2223
)
23-
from deepmd.tf.infer.data_modifier import (
24-
DipoleChargeModifier,
25-
)
2624
from deepmd.tf.model.model import (
2725
Model,
2826
)
27+
from deepmd.tf.modifier import (
28+
BaseModifier,
29+
)
2930
from deepmd.tf.train.run_options import (
3031
RunOptions,
3132
)
@@ -275,18 +276,13 @@ def _do_work(
275276

276277

277278
def get_modifier(modi_data=None):
278-
modifier: Optional[DipoleChargeModifier]
279+
modifier: Optional[BaseModifier]
279280
if modi_data is not None:
280-
if modi_data["type"] == "dipole_charge":
281-
modifier = DipoleChargeModifier(
282-
modi_data["model_name"],
283-
modi_data["model_charge_map"],
284-
modi_data["sys_charge_map"],
285-
modi_data["ewald_h"],
286-
modi_data["ewald_beta"],
287-
)
288-
else:
289-
raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
281+
modifier_params = copy.deepcopy(modi_data)
282+
modifier_type = modifier_params.pop("type")
283+
modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier(
284+
modifier_params
285+
)
290286
else:
291287
modifier = None
292288
return modifier

deepmd/tf/modifier/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_modifier import (
3+
BaseModifier,
4+
)
5+
from .dipole_charge import (
6+
DipoleChargeModifier,
7+
)
8+
9+
__all__ = [
10+
"BaseModifier",
11+
"DipoleChargeModifier",
12+
]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.modifier.base_modifier import (
3+
make_base_modifier,
4+
)
5+
from deepmd.tf.infer import (
6+
DeepPot,
7+
)
8+
9+
10+
class BaseModifier(DeepPot, make_base_modifier()):
11+
def __init__(self, *args, **kwargs) -> None:
12+
"""Construct a basic model for different tasks."""
13+
DeepPot.__init__(self, *args, **kwargs)

0 commit comments

Comments
 (0)