Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/dpmodel/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
make_base_modifier,
)

__all__ = [
"make_base_modifier",
]
74 changes: 74 additions & 0 deletions deepmd/dpmodel/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)

from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


def make_base_modifier() -> type[object]:
class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")):
"""Base class for data modifier."""

def __new__(cls, *args, **kwargs):
if cls is BaseModifier:
cls = cls.get_class_by_type(kwargs["type"])

Check warning on line 20 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L20

Added line #L20 was not covered by tests
return super().__new__(cls)

@abstractmethod
def serialize(self) -> dict:
"""Serialize the modifier.

Returns
-------
dict
The serialized data
"""
pass

Check warning on line 32 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L32

Added line #L32 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
"""Deserialize the modifier.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized modifier
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError(f"Not implemented in class {cls.__name__}")

Check warning on line 50 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L48-L50

Added lines #L48 - L50 were not covered by tests

@classmethod
def get_modifier(cls, modifier_params: dict) -> "BaseModifier":
"""Get the modifier by the parameters.

By default, all the parameters are directly passed to the constructor.
If not, override this method.

Parameters
----------
modifier_params : dict
The modifier parameters

Returns
-------
BaseModifier
The modifier
"""
modifier_params = modifier_params.copy()
modifier_params.pop("type", None)
modifier = cls(**modifier_params)
return modifier

Check warning on line 72 in deepmd/dpmodel/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/modifier/base_modifier.py#L69-L72

Added lines #L69 - L72 were not covered by tests

return BaseModifier
2 changes: 1 addition & 1 deletion deepmd/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DeepEval,
DeepPotential,
)
from .infer.data_modifier import (
from .modifier import (
DipoleChargeModifier,
)

Expand Down
24 changes: 10 additions & 14 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Can handle local or distributed training.
"""

import copy
import json
import logging
import time
Expand All @@ -20,12 +21,12 @@
reset_default_tf_session_config,
tf,
)
from deepmd.tf.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.modifier import (
BaseModifier,
)
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -275,18 +276,13 @@


def get_modifier(modi_data=None):
modifier: Optional[DipoleChargeModifier]
modifier: Optional[BaseModifier]
if modi_data is not None:
if modi_data["type"] == "dipole_charge":
modifier = DipoleChargeModifier(
modi_data["model_name"],
modi_data["model_charge_map"],
modi_data["sys_charge_map"],
modi_data["ewald_h"],
modi_data["ewald_beta"],
)
else:
raise RuntimeError("unknown modifier type " + str(modi_data["type"]))
modifier_params = copy.deepcopy(modi_data)
modifier_type = modifier_params.pop("type")
modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier(

Check warning on line 283 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L281-L283

Added lines #L281 - L283 were not covered by tests
modifier_params
)
else:
modifier = None
return modifier
Expand Down
4 changes: 0 additions & 4 deletions deepmd/tf/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
DeepEval,
)

from .data_modifier import (
DipoleChargeModifier,
)
from .deep_dipole import (
DeepDipole,
)
Expand Down Expand Up @@ -43,7 +40,6 @@
"DeepPot",
"DeepPotential",
"DeepWFC",
"DipoleChargeModifier",
"EwaldRecp",
"calc_model_devi",
]
2 changes: 1 addition & 1 deletion deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@

# looks ugly...
if self.modifier_type == "dipole_charge":
from deepmd.tf.infer.data_modifier import (
from deepmd.tf.modifier import (

Check warning on line 142 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L142

Added line #L142 was not covered by tests
DipoleChargeModifier,
)

Expand Down
12 changes: 12 additions & 0 deletions deepmd/tf/modifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
BaseModifier,
)
from .dipole_charge import (
DipoleChargeModifier,
)

__all__ = [
"BaseModifier",
"DipoleChargeModifier",
]
13 changes: 13 additions & 0 deletions deepmd/tf/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.modifier.base_modifier import (
make_base_modifier,
)
from deepmd.tf.infer import (
DeepPot,
)


class BaseModifier(DeepPot, make_base_modifier()):
def __init__(self, *args, **kwargs) -> None:
"""Construct a basic model for different tasks."""
DeepPot.__init__(self, *args, **kwargs)

Check warning on line 13 in deepmd/tf/modifier/base_modifier.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/base_modifier.py#L13

Added line #L13 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from deepmd.tf.infer.ewald_recp import (
EwaldRecp,
)
from deepmd.tf.modifier.base_modifier import (
BaseModifier,
)
from deepmd.tf.utils.data import (
DeepmdData,
)
Expand All @@ -25,7 +28,8 @@
)


class DipoleChargeModifier(DeepDipole):
@BaseModifier.register("dipole_charge")
class DipoleChargeModifier(DeepDipole, BaseModifier):
"""Parameters
----------
model_name
Expand All @@ -40,6 +44,9 @@
Splitting parameter of the Ewald sum. Unit: A^{-1}
"""

def __new__(cls, *args, model_name=None, **kwargs):
return super().__new__(cls, model_name)

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -82,6 +89,44 @@
self.force = None
self.ntypes = len(self.sel_a)

def serialize(self) -> dict:
"""Serialize the modifier.

Returns
-------
dict
The serialized data
"""
data = {

Check warning on line 100 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L100

Added line #L100 was not covered by tests
"@class": "Modifier",
"type": self.modifier_prefix,
"@version": 3,
"model_name": self.model_name,
"model_charge_map": self.model_charge_map,
"sys_charge_map": self.sys_charge_map,
"ewald_h": self.ewald_h,
"ewald_beta": self.ewald_beta,
}
return data

Check warning on line 110 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L110

Added line #L110 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
"""Deserialize the modifier.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized modifier
"""
data = data.copy()
modifier = cls(**data)
return modifier

Check warning on line 128 in deepmd/tf/modifier/dipole_charge.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/modifier/dipole_charge.py#L126-L128

Added lines #L126 - L128 were not covered by tests

def build_fv_graph(self) -> tf.Tensor:
"""Build the computational graph for the force and virial inference."""
with tf.variable_scope("modifier_attr"):
Expand Down
13 changes: 7 additions & 6 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter."
doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter."
doc_polar = "Fit an atomic polarizability model. Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter."
# modifier
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."


def list_to_doc(xx):
Expand Down Expand Up @@ -2015,6 +2017,10 @@ def fitting_variant_type_args():


# --- Modifier configurations: --- #
modifier_args_plugin = ArgsPlugin()


@modifier_args_plugin.register("dipole_charge", doc=doc_dipole_charge)
def modifier_dipole_charge():
doc_model_name = "The name of the frozen dipole model file."
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model[standard]/fitting_net[dipole]/sel_type')}. "
Expand All @@ -2035,14 +2041,9 @@ def modifier_dipole_charge():

def modifier_variant_type_args():
doc_modifier_type = "The type of modifier."
doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction."
return Variant(
"type",
[
Argument(
"dipole_charge", dict, modifier_dipole_charge(), doc=doc_dipole_charge
),
],
modifier_args_plugin.get_all_argument(),
optional=False,
doc=doc_modifier_type,
)
Expand Down
12 changes: 6 additions & 6 deletions source/tests/tf/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
GLOBAL_NP_FLOAT_PRECISION,
tf,
)
from deepmd.tf.infer.data_modifier import (
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.train.run_options import (
Expand Down Expand Up @@ -97,11 +97,11 @@ def test_fv(self) -> None:

def _test_fv(self) -> None:
dcm = DipoleChargeModifier(
str(tests_path / os.path.join(modifier_datapath, "dipole.pb")),
[-8],
[6, 1],
1,
0.25,
model_name=str(tests_path / os.path.join(modifier_datapath, "dipole.pb")),
model_charge_map=[-8],
sys_charge_map=[6, 1],
ewald_h=1,
ewald_beta=0.25,
)
data = Data()
coord, box, atype = data.get_data()
Expand Down
16 changes: 8 additions & 8 deletions source/tests/tf/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
GLOBAL_NP_FLOAT_PRECISION,
tf,
)
from deepmd.tf.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.tf.infer.deep_dipole import (
DeepDipole,
)
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -197,11 +197,11 @@ def test_z_dipole(self) -> None:

def test_modify(self) -> None:
dcm = DipoleChargeModifier(
os.path.join(modifier_datapath, "dipole.pb"),
[-1, -3],
[1, 1, 1, 1, 1],
1,
0.25,
model_name=os.path.join(modifier_datapath, "dipole.pb"),
model_charge_map=[-1, -3],
sys_charge_map=[1, 1, 1, 1, 1],
ewald_h=1,
ewald_beta=0.25,
)
ve0, vf0, vv0 = dcm.eval(self.coords0, self.box0, self.atom_types0)
ve1, vf1, vv1 = dcm.eval(self.coords1, self.box1, self.atom_types1)
Expand Down
8 changes: 6 additions & 2 deletions source/tests/tf/test_dipolecharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.tf.infer import (
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.utils.convert import (
Expand All @@ -32,7 +32,11 @@ def setUpClass(cls) -> None:
"dipolecharge_d.pb",
)
cls.dp = DipoleChargeModifier(
"dipolecharge_d.pb", [-1.0, -3.0], [1.0, 1.0, 1.0, 1.0, 1.0], 4.0, 0.2
model_name="dipolecharge_d.pb",
model_charge_map=[-1.0, -3.0],
sys_charge_map=[1.0, 1.0, 1.0, 1.0, 1.0],
ewald_h=4.0,
ewald_beta=0.2,
)

def setUp(self) -> None:
Expand Down