Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/dpmodel/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
make_base_modifier,
)

__all__ = [
"make_base_modifier",
]
74 changes: 74 additions & 0 deletions deepmd/dpmodel/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)

from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


def make_base_modifier() -> type[object]:
class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")):
"""Base class for data modifier."""

def __new__(cls, *args, **kwargs):
if cls is BaseModifier:
cls = cls.get_class_by_type(kwargs["type"])

Check warning on line 20 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L20

Added line #L20 was not covered by tests
return super().__new__(cls)

@abstractmethod
def serialize(self) -> dict:
"""Serialize the modifier.

Returns
-------
dict
The serialized data
"""
pass

Check warning on line 32 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L32

Added line #L32 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
"""Deserialize the modifier.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized modifier
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError(f"Not implemented in class {cls.__name__}")

Check warning on line 50 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L48-L50

Added lines #L48 - L50 were not covered by tests

@classmethod
def get_modifier(cls, modifier_params: dict) -> "BaseModifier":
"""Get the modifier by the parameters.

By default, all the parameters are directly passed to the constructor.
If not, override this method.

Parameters
----------
modifier_params : dict
The modifier parameters

Returns
-------
BaseModifier
The modifier
"""
modifier_params = modifier_params.copy()
modifier_params.pop("type", None)
modifier = cls(**modifier_params)
return modifier

Check warning on line 72 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L69-L72

Added lines #L69 - L72 were not covered by tests

return BaseModifier
2 changes: 1 addition & 1 deletion deepmd/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DeepEval,
DeepPotential,
)
from .infer.data_modifier import (
from .modifier import (
DipoleChargeModifier,
)

Expand Down
24 changes: 10 additions & 14 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Can handle local or distributed training.
"""

import copy
import json
import logging
import time
Expand All @@ -20,12 +21,12 @@
reset_default_tf_session_config,
tf,
)
from deepmd.tf.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.modifier import (
BaseModifier,
)
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -275,18 +276,13 @@


def get_modifier(modi_data=None):
modifier: Optional[DipoleChargeModifier]
modifier: Optional[BaseModifier]
if modi_data is not None:
if modi_data["type"] == "dipole_charge":
modifier = DipoleChargeModifier(
modi_data["model_name"],
modi_data["model_charge_map"],
modi_data["sys_charge_map"],
modi_data["ewald_h"],
modi_data["ewald_beta"],
)
else:
raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
modifier_params = copy.deepcopy(modi_data)
modifier_type = modifier_params.pop("type")
modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier(

Check warning on line 283 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L281-L283

Added lines #L281 - L283 were not covered by tests
modifier_params
)
else:
modifier = None
return modifier
Expand Down
4 changes: 0 additions & 4 deletions deepmd/tf/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
DeepEval,
)

from .data_modifier import (
DipoleChargeModifier,
)
from .deep_dipole import (
DeepDipole,
)
Expand Down Expand Up @@ -43,7 +40,6 @@
"DeepPot",
"DeepPotential",
"DeepWFC",
"DipoleChargeModifier",
"EwaldRecp",
"calc_model_devi",
]
2 changes: 1 addition & 1 deletion deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@

# looks ugly...
if self.modifier_type == "dipole_charge":
from deepmd.tf.infer.data_modifier import (
from deepmd.tf.modifier import (

Check warning on line 142 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L142

Added line #L142 was not covered by tests
DipoleChargeModifier,
)

Expand Down
12 changes: 12 additions & 0 deletions deepmd/tf/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
BaseModifier,
)
from .dipole_charge import (
DipoleChargeModifier,
)

__all__ = [
"BaseModifier",
"DipoleChargeModifier",
]
13 changes: 13 additions & 0 deletions deepmd/tf/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.modifier.base_modifier import (
make_base_modifier,
)
from deepmd.tf.infer import (
DeepPot,
)


class BaseModifier(DeepPot, make_base_modifier()):
def __init__(self, *args, **kwargs) -> None:
"""Construct a basic model for different tasks."""
DeepPot.__init__(self, *args, **kwargs)

Check warning on line 13 in deepmd/tf/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/base_modifier.py#L13

Added line #L13 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from deepmd.tf.infer.ewald_recp import (
EwaldRecp,
)
from deepmd.tf.modifier.base_modifier import (
BaseModifier,
)
from deepmd.tf.utils.data import (
DeepmdData,
)
Expand All @@ -25,7 +28,8 @@
)


class DipoleChargeModifier(DeepDipole):
@BaseModifier.register("dipole_charge")
class DipoleChargeModifier(DeepDipole, BaseModifier):
"""Parameters
----------
model_name
Expand All @@ -40,6 +44,12 @@
Splitting parameter of the Ewald sum. Unit: A^{-1}
"""

def __new__(cls, *args, **kwargs):
model_file = kwargs.get("model_name", None)
if model_file is None:
raise TypeError("Missing required argument: 'model_name'")

Check warning on line 50 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L50

Added line #L50 was not covered by tests
return super().__new__(cls, model_file)

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -82,6 +92,44 @@
self.force = None
self.ntypes = len(self.sel_a)

def serialize(self) -> dict:
"""Serialize the modifier.
Returns
-------
dict
The serialized data
"""
data = {

Check warning on line 103 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L103

Added line #L103 was not covered by tests
"@class": "Modifier",
"type": self.modifier_prefix,
"@version": 3,
"model_name": self.model_name,
"model_charge_map": self.model_charge_map,
"sys_charge_map": self.sys_charge_map,
"ewald_h": self.ewald_h,
"ewald_beta": self.ewald_beta,
}
return data

Check warning on line 113 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L113

Added line #L113 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
"""Deserialize the modifier.
Parameters
----------
data : dict
The serialized data
Returns
-------
BaseModel
The deserialized modifier
"""
data = data.copy()
modifier = cls(**data)
return modifier

Check warning on line 131 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L129-L131

Added lines #L129 - L131 were not covered by tests

def build_fv_graph(self) -> tf.Tensor:
"""Build the computational graph for the force and virial inference."""
with tf.variable_scope("modifier_attr"):
Expand Down
13 changes: 7 additions & 6 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter."
doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter."
doc_polar = "Fit an atomic polarizability model. Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter."
# modifier
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."


def list_to_doc(xx):
Expand Down Expand Up @@ -1782,6 +1784,10 @@ def fitting_variant_type_args():


# --- Modifier configurations: --- #
modifier_args_plugin = ArgsPlugin()


@modifier_args_plugin.register("dipole_charge", doc=doc_dipole_charge)
def modifier_dipole_charge():
doc_model_name = "The name of the frozen dipole model file."
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model[standard]/fitting_net[dipole]/sel_type')}. "
Expand All @@ -1802,14 +1808,9 @@ def modifier_dipole_charge():

def modifier_variant_type_args():
doc_modifier_type = "The type of modifier."
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."
return Variant(
"type",
[
Argument(
"dipole_charge", dict, modifier_dipole_charge(), doc=doc_dipole_charge
),
],
modifier_args_plugin.get_all_argument(),
optional=False,
doc=doc_modifier_type,
)
Expand Down
12 changes: 6 additions & 6 deletions source/tests/tf/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
GLOBAL_NP_FLOAT_PRECISION,
tf,
)
from deepmd.tf.infer.data_modifier import (
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.train.run_options import (
Expand Down Expand Up @@ -97,11 +97,11 @@ def test_fv(self) -> None:

def _test_fv(self) -> None:
dcm = DipoleChargeModifier(
str(tests_path / os.path.join(modifier_datapath, "dipole.pb")),
[-8],
[6, 1],
1,
0.25,
model_name=str(tests_path / os.path.join(modifier_datapath, "dipole.pb")),
model_charge_map=[-8],
sys_charge_map=[6, 1],
ewald_h=1,
ewald_beta=0.25,
)
data = Data()
coord, box, atype = data.get_data()
Expand Down
16 changes: 8 additions & 8 deletions source/tests/tf/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
GLOBAL_NP_FLOAT_PRECISION,
tf,
)
from deepmd.tf.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.tf.infer.deep_dipole import (
DeepDipole,
)
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -197,11 +197,11 @@ def test_z_dipole(self) -> None:

def test_modify(self) -> None:
dcm = DipoleChargeModifier(
os.path.join(modifier_datapath, "dipole.pb"),
[-1, -3],
[1, 1, 1, 1, 1],
1,
0.25,
model_name=os.path.join(modifier_datapath, "dipole.pb"),
model_charge_map=[-1, -3],
sys_charge_map=[1, 1, 1, 1, 1],
ewald_h=1,
ewald_beta=0.25,
)
ve0, vf0, vv0 = dcm.eval(self.coords0, self.box0, self.atom_types0)
ve1, vf1, vv1 = dcm.eval(self.coords1, self.box1, self.atom_types1)
Expand Down
8 changes: 6 additions & 2 deletions source/tests/tf/test_dipolecharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.tf.infer import (
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.utils.convert import (
Expand All @@ -32,7 +32,11 @@ def setUpClass(cls) -> None:
"dipolecharge_d.pb",
)
cls.dp = DipoleChargeModifier(
"dipolecharge_d.pb", [-1.0, -3.0], [1.0, 1.0, 1.0, 1.0, 1.0], 4.0, 0.2
model_name="dipolecharge_d.pb",
model_charge_map=[-1.0, -3.0],
sys_charge_map=[1.0, 1.0, 1.0, 1.0, 1.0],
ewald_h=4.0,
ewald_beta=0.2,
)

def setUp(self) -> None:
Expand Down