Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/dpmodel/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
make_base_modifier,
)

__all__ = [
"make_base_modifier",
]
74 changes: 74 additions & 0 deletions deepmd/dpmodel/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)

from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


def make_base_modifier() -> type[object]:
class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")):
"""Base class for data modifier."""

def __new__(cls, *args, **kwargs):
if cls is BaseModifier:
cls = cls.get_class_by_type(kwargs["type"])
return super().__new__(cls)

@abstractmethod
def serialize(self) -> dict:
"""Serialize the modifier.

Returns
-------
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
"""Deserialize the modifier.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized modifier
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError(f"Not implemented in class {cls.__name__}")

@classmethod
def get_modifier(cls, modifier_params: dict) -> "BaseModifier":
"""Get the modifier by the parameters.

By default, all the parameters are directly passed to the constructor.
If not, override this method.

Parameters
----------
modifier_params : dict
The modifier parameters

Returns
-------
BaseModifier
The modifier
"""
modifier_params = modifier_params.copy()
modifier_params.pop("type", None)
modifier = cls(**modifier_params)
return modifier

return BaseModifier
24 changes: 10 additions & 14 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Can handle local or distributed training.
"""

import copy
import json
import logging
import time
Expand All @@ -20,12 +21,12 @@
reset_default_tf_session_config,
tf,
)
from deepmd.tf.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.modifier import (
BaseModifier,
)
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -275,18 +276,13 @@ def _do_work(


def get_modifier(modi_data=None):
modifier: Optional[DipoleChargeModifier]
modifier: Optional[BaseModifier]
if modi_data is not None:
if modi_data["type"] == "dipole_charge":
modifier = DipoleChargeModifier(
modi_data["model_name"],
modi_data["model_charge_map"],
modi_data["sys_charge_map"],
modi_data["ewald_h"],
modi_data["ewald_beta"],
)
else:
raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
modifier_params = copy.deepcopy(modi_data)
modifier_type = modifier_params.pop("type")
modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier(
modifier_params
)
else:
modifier = None
return modifier
Expand Down
12 changes: 12 additions & 0 deletions deepmd/tf/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
BaseModifier,
)
from .dipole_charge import (
DipoleChargeModifier,
)

__all__ = [
"BaseModifier",
"DipoleChargeModifier",
]
13 changes: 13 additions & 0 deletions deepmd/tf/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.modifier.base_modifier import (
make_base_modifier,
)
from deepmd.tf.infer import (
DeepPot,
)


class BaseModifier(DeepPot, make_base_modifier()):
def __init__(self, *args, **kwargs) -> None:
"""Construct a basic model for different tasks."""
DeepPot.__init__(self, *args, **kwargs)
Loading
Loading