diff --git a/deepmd/dpmodel/modifier/__init__.py b/deepmd/dpmodel/modifier/__init__.py new file mode 100644 index 0000000000..d4e8ab56e3 --- /dev/null +++ b/deepmd/dpmodel/modifier/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_modifier import ( + make_base_modifier, +) + +__all__ = [ + "make_base_modifier", +] diff --git a/deepmd/dpmodel/modifier/base_modifier.py b/deepmd/dpmodel/modifier/base_modifier.py new file mode 100644 index 0000000000..9edc4722e1 --- /dev/null +++ b/deepmd/dpmodel/modifier/base_modifier.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import inspect +from abc import ( + ABC, + abstractmethod, +) + +from deepmd.utils.plugin import ( + PluginVariant, + make_plugin_registry, +) + + +def make_base_modifier() -> type[object]: + class BaseModifier(ABC, PluginVariant, make_plugin_registry("modifier")): + """Base class for data modifier.""" + + def __new__(cls, *args, **kwargs): + if cls is BaseModifier: + cls = cls.get_class_by_type(kwargs["type"]) + return super().__new__(cls) + + @abstractmethod + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + pass + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized modifier + """ + if inspect.isabstract(cls): + return cls.get_class_by_type(data["type"]).deserialize(data) + raise NotImplementedError(f"Not implemented in class {cls.__name__}") + + @classmethod + def get_modifier(cls, modifier_params: dict) -> "BaseModifier": + """Get the modifier by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method. + + Parameters + ---------- + modifier_params : dict + The modifier parameters + + Returns + ------- + BaseModifier + The modifier + """ + modifier_params = modifier_params.copy() + modifier_params.pop("type", None) + modifier = cls(**modifier_params) + return modifier + + return BaseModifier diff --git a/deepmd/tf/__init__.py b/deepmd/tf/__init__.py index 933729fde2..cca5e54e7a 100644 --- a/deepmd/tf/__init__.py +++ b/deepmd/tf/__init__.py @@ -21,7 +21,7 @@ DeepEval, DeepPotential, ) -from .infer.data_modifier import ( +from .modifier import ( DipoleChargeModifier, ) diff --git a/deepmd/tf/entrypoints/train.py b/deepmd/tf/entrypoints/train.py index 1762f1049a..b12e4fe1af 100755 --- a/deepmd/tf/entrypoints/train.py +++ b/deepmd/tf/entrypoints/train.py @@ -4,6 +4,7 @@ Can handle local or distributed training. """ +import copy import json import logging import time @@ -20,12 +21,12 @@ reset_default_tf_session_config, tf, ) -from deepmd.tf.infer.data_modifier import ( - DipoleChargeModifier, -) from deepmd.tf.model.model import ( Model, ) +from deepmd.tf.modifier import ( + BaseModifier, +) from deepmd.tf.train.run_options import ( RunOptions, ) @@ -275,18 +276,13 @@ def _do_work( def get_modifier(modi_data=None): - modifier: Optional[DipoleChargeModifier] + modifier: Optional[BaseModifier] if modi_data is not None: - if modi_data["type"] == "dipole_charge": - modifier = DipoleChargeModifier( - modi_data["model_name"], - modi_data["model_charge_map"], - modi_data["sys_charge_map"], - modi_data["ewald_h"], - modi_data["ewald_beta"], - ) - else: - raise RuntimeError("unknown modifier type " + str(modi_data["type"])) + modifier_params = copy.deepcopy(modi_data) + modifier_type = modifier_params.pop("type") + modifier = BaseModifier.get_class_by_type(modifier_type).get_modifier( + modifier_params + ) else: modifier = None return modifier diff --git a/deepmd/tf/infer/__init__.py b/deepmd/tf/infer/__init__.py index de8a77976e..ca9464ec43 100644 --- a/deepmd/tf/infer/__init__.py +++ b/deepmd/tf/infer/__init__.py @@ -8,9 +8,6 @@ DeepEval, ) -from .data_modifier import ( - DipoleChargeModifier, -) from .deep_dipole import ( DeepDipole, ) @@ -43,7 +40,6 @@ "DeepPot", "DeepPotential", "DeepWFC", - "DipoleChargeModifier", "EwaldRecp", "calc_model_devi", ] diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index 8e5e3deea5..d594a47115 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -139,7 +139,7 @@ def __init__( # looks ugly... if self.modifier_type == "dipole_charge": - from deepmd.tf.infer.data_modifier import ( + from deepmd.tf.modifier import ( DipoleChargeModifier, ) diff --git a/deepmd/tf/modifier/__init__.py b/deepmd/tf/modifier/__init__.py new file mode 100644 index 0000000000..2441b693bc --- /dev/null +++ b/deepmd/tf/modifier/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_modifier import ( + BaseModifier, +) +from .dipole_charge import ( + DipoleChargeModifier, +) + +__all__ = [ + "BaseModifier", + "DipoleChargeModifier", +] diff --git a/deepmd/tf/modifier/base_modifier.py b/deepmd/tf/modifier/base_modifier.py new file mode 100644 index 0000000000..4e214e0835 --- /dev/null +++ b/deepmd/tf/modifier/base_modifier.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.modifier.base_modifier import ( + make_base_modifier, +) +from deepmd.tf.infer import ( + DeepPot, +) + + +class BaseModifier(DeepPot, make_base_modifier()): + def __init__(self, *args, **kwargs) -> None: + """Construct a basic model for different tasks.""" + DeepPot.__init__(self, *args, **kwargs) diff --git a/deepmd/tf/infer/data_modifier.py b/deepmd/tf/modifier/dipole_charge.py similarity index 93% rename from deepmd/tf/infer/data_modifier.py rename to deepmd/tf/modifier/dipole_charge.py index ddb1af68d7..d40c9ccd2f 100644 --- a/deepmd/tf/infer/data_modifier.py +++ b/deepmd/tf/modifier/dipole_charge.py @@ -17,6 +17,9 @@ from deepmd.tf.infer.ewald_recp import ( EwaldRecp, ) +from deepmd.tf.modifier.base_modifier import ( + BaseModifier, +) from deepmd.tf.utils.data import ( DeepmdData, ) @@ -25,7 +28,8 @@ ) -class DipoleChargeModifier(DeepDipole): +@BaseModifier.register("dipole_charge") +class DipoleChargeModifier(DeepDipole, BaseModifier): """Parameters ---------- model_name @@ -40,6 +44,9 @@ class DipoleChargeModifier(DeepDipole): Splitting parameter of the Ewald sum. Unit: A^{-1} """ + def __new__(cls, *args, model_name=None, **kwargs): + return super().__new__(cls, model_name) + def __init__( self, model_name: str, @@ -82,6 +89,44 @@ def __init__( self.force = None self.ntypes = len(self.sel_a) + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Modifier", + "type": self.modifier_prefix, + "@version": 3, + "model_name": self.model_name, + "model_charge_map": self.model_charge_map, + "sys_charge_map": self.sys_charge_map, + "ewald_h": self.ewald_h, + "ewald_beta": self.ewald_beta, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized modifier + """ + data = data.copy() + modifier = cls(**data) + return modifier + def build_fv_graph(self) -> tf.Tensor: """Build the computational graph for the force and virial inference.""" with tf.variable_scope("modifier_attr"): diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 32c0766265..47071066ae 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -55,6 +55,8 @@ doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter." doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter." doc_polar = "Fit an atomic polarizability model. Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter." +# modifier +doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction." def list_to_doc(xx): @@ -2015,6 +2017,10 @@ def fitting_variant_type_args(): # --- Modifier configurations: --- # +modifier_args_plugin = ArgsPlugin() + + +@modifier_args_plugin.register("dipole_charge", doc=doc_dipole_charge) def modifier_dipole_charge(): doc_model_name = "The name of the frozen dipole model file." doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model[standard]/fitting_net[dipole]/sel_type')}. " @@ -2035,14 +2041,9 @@ def modifier_dipole_charge(): def modifier_variant_type_args(): doc_modifier_type = "The type of modifier." - doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction." return Variant( "type", - [ - Argument( - "dipole_charge", dict, modifier_dipole_charge(), doc=doc_dipole_charge - ), - ], + modifier_args_plugin.get_all_argument(), optional=False, doc=doc_modifier_type, ) diff --git a/source/tests/tf/test_data_modifier.py b/source/tests/tf/test_data_modifier.py index db11fa5c2d..7a7558793e 100644 --- a/source/tests/tf/test_data_modifier.py +++ b/source/tests/tf/test_data_modifier.py @@ -7,7 +7,7 @@ GLOBAL_NP_FLOAT_PRECISION, tf, ) -from deepmd.tf.infer.data_modifier import ( +from deepmd.tf.modifier import ( DipoleChargeModifier, ) from deepmd.tf.train.run_options import ( @@ -97,11 +97,11 @@ def test_fv(self) -> None: def _test_fv(self) -> None: dcm = DipoleChargeModifier( - str(tests_path / os.path.join(modifier_datapath, "dipole.pb")), - [-8], - [6, 1], - 1, - 0.25, + model_name=str(tests_path / os.path.join(modifier_datapath, "dipole.pb")), + model_charge_map=[-8], + sys_charge_map=[6, 1], + ewald_h=1, + ewald_beta=0.25, ) data = Data() coord, box, atype = data.get_data() diff --git a/source/tests/tf/test_data_modifier_shuffle.py b/source/tests/tf/test_data_modifier_shuffle.py index 002b4f5746..49b46ead6a 100644 --- a/source/tests/tf/test_data_modifier_shuffle.py +++ b/source/tests/tf/test_data_modifier_shuffle.py @@ -8,12 +8,12 @@ GLOBAL_NP_FLOAT_PRECISION, tf, ) -from deepmd.tf.infer.data_modifier import ( - DipoleChargeModifier, -) from deepmd.tf.infer.deep_dipole import ( DeepDipole, ) +from deepmd.tf.modifier import ( + DipoleChargeModifier, +) from deepmd.tf.train.run_options import ( RunOptions, ) @@ -197,11 +197,11 @@ def test_z_dipole(self) -> None: def test_modify(self) -> None: dcm = DipoleChargeModifier( - os.path.join(modifier_datapath, "dipole.pb"), - [-1, -3], - [1, 1, 1, 1, 1], - 1, - 0.25, + model_name=os.path.join(modifier_datapath, "dipole.pb"), + model_charge_map=[-1, -3], + sys_charge_map=[1, 1, 1, 1, 1], + ewald_h=1, + ewald_beta=0.25, ) ve0, vf0, vv0 = dcm.eval(self.coords0, self.box0, self.atom_types0) ve1, vf1, vv1 = dcm.eval(self.coords1, self.box1, self.atom_types1) diff --git a/source/tests/tf/test_dipolecharge.py b/source/tests/tf/test_dipolecharge.py index d4ad3254bc..71c46446f6 100644 --- a/source/tests/tf/test_dipolecharge.py +++ b/source/tests/tf/test_dipolecharge.py @@ -7,7 +7,7 @@ from deepmd.tf.env import ( GLOBAL_NP_FLOAT_PRECISION, ) -from deepmd.tf.infer import ( +from deepmd.tf.modifier import ( DipoleChargeModifier, ) from deepmd.tf.utils.convert import ( @@ -32,7 +32,11 @@ def setUpClass(cls) -> None: "dipolecharge_d.pb", ) cls.dp = DipoleChargeModifier( - "dipolecharge_d.pb", [-1.0, -3.0], [1.0, 1.0, 1.0, 1.0, 1.0], 4.0, 0.2 + model_name="dipolecharge_d.pb", + model_charge_map=[-1.0, -3.0], + sys_charge_map=[1.0, 1.0, 1.0, 1.0, 1.0], + ewald_h=4.0, + ewald_beta=0.2, ) def setUp(self) -> None: