From e906fac5ecb2d142a7b27fa3ee1a1042bee08998 Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 14:17:44 +0800 Subject: [PATCH 01/29] fix: zbl mix type model --- examples/water/zbl/input.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/water/zbl/input.json b/examples/water/zbl/input.json index cb5602d92d..1c951f3de5 100644 --- a/examples/water/zbl/input.json +++ b/examples/water/zbl/input.json @@ -10,7 +10,7 @@ "H" ], "descriptor": { - "type": "se_e2_a", + "type": "se_atten", "sel": [ 46, 92 From 91cf86123689e28e45e939badc976f8be57c7235 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:34:20 +0800 Subject: [PATCH 02/29] feat: add linear model --- .../model/atomic_model/linear_atomic_model.py | 35 +++- deepmd/pt/model/model/__init__.py | 52 ++++++ deepmd/pt/model/model/dp_linear_model.py | 160 ++++++++++++++++++ 3 files changed, 242 insertions(+), 5 deletions(-) create mode 100644 deepmd/pt/model/model/dp_linear_model.py diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index d88c4c3af5..ee7baabff3 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -48,12 +48,15 @@ class LinearEnergyAtomicModel(BaseAtomicModel): type_map : list[str] Mapping atom type to the name (str) of the type. For example `type_map[1]` gives the name of the type 1. + weights : Optional[Union[str,list[float]]] + Weights of the models. If str, must be `sum` or `mean`. If list, must be a list of float. """ def __init__( self, models: list[BaseAtomicModel], type_map: list[str], + weights: Optional[Union[str,list[float]]]="mean", **kwargs, ): super().__init__(type_map, **kwargs) @@ -89,6 +92,14 @@ def __init__( ) self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE) # pylint: disable=no-explicit-dtype + if isinstance(weights, str): + assert weights in ["sum","mean"] + elif isinstance(weights, list): + assert len(weights) == len(models) + else: + raise ValueError(f"'weights' must be a string ('sum' or 'mean') or a list of float of length {len(models)}.") + self.weights = weights + def mixed_types(self) -> bool: """If true, the model 1. assumes total number of atoms aligned across frames; @@ -334,13 +345,27 @@ def _compute_weight( self, extended_coord, extended_atype, nlists_ ) -> list[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" + nmodels = len(self.models) nframes, nloc, _ = nlists_[0].shape - return [ - torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) - / nmodels - for _ in range(nmodels) - ] + if isinstance(self.weights, str): + if self.weights == "sum": + return [ + torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + for _ in range(nmodels) + ] + elif self.weights == "mean": + return [ + torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + / nmodels + for _ in range(nmodels) + ] + elif isinstance(self.weights, list): + return [ + torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) * w + for w in self.weights + ] + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1c81d42013..45f74ea131 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -42,6 +42,9 @@ from .dp_zbl_model import ( DPZBLModel, ) +from .dp_linear_model import( + DPLinearModel +) from .ener_model import ( EnergyModel, ) @@ -104,6 +107,53 @@ def get_spin_model(model_params): backbone_model = get_standard_model(model_params) return SpinEnergyModel(backbone_model=backbone_model, spin=spin) +def get_linear_model(model_params): + model_params = copy.deepcopy(model_params) + weights = model_params.get("weights", "mean") + list_of_models =[] + ntypes = len(model_params["type_map"]) + for sub_model_params in model_params["models"]: + if "descriptor" in sub_model_params: + + # descriptor + sub_model_params["descriptor"]["ntypes"] = ntypes + sub_model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + descriptor = BaseDescriptor(**sub_model_params["descriptor"]) + # fitting + fitting_net = sub_model_params.get("fitting_net", {}) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(sub_model_params["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = BaseFitting(**fitting_net) + list_of_models.append(DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])) + + else: # must be pairtab + assert "type" in sub_model_params and sub_model_params["type"] == "pairtab", "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model" + list_of_models.append(PairTabAtomicModel( + sub_model_params["tab_file"], + sub_model_params["rcut"], + sub_model_params["sel"], + type_map=model_params["type_map"], + )) + + atom_exclude_types = model_params.get("atom_exclude_types", []) + pair_exclude_types = model_params.get("pair_exclude_types", []) + return DPLinearModel( + models = list_of_models, + type_map=model_params["type_map"], + weights=weights, + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) def get_zbl_model(model_params): model_params = copy.deepcopy(model_params) @@ -247,6 +297,8 @@ def get_model(model_params): return get_zbl_model(model_params) else: return get_standard_model(model_params) + elif model_type == "linear_ener": + return get_linear_model(model_params) else: return BaseModel.get_class_by_type(model_type).get_model(model_params) diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py new file mode 100644 index 0000000000..3c8d872a00 --- /dev/null +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) +from typing import ( + Optional, +) + +import torch + +from deepmd.pt.model.atomic_model import ( + LinearEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPLinearModel_ = make_model(LinearEnergyAtomicModel) + + +@BaseModel.register("linear_ener") +class DPLinearModel(DPLinearModel_): + model_type = "ener" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def translated_output_def(self): + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": deepcopy(out_def_data["energy"]), + "energy": deepcopy(out_def_data["energy_redu"]), + } + if self.do_grad_r("energy"): + output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = deepcopy(out_def_data["mask"]) + return output_def + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + return model_predict + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist From 7d3044c1a654c495618e4797eb7d6661fead6894 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 07:34:57 +0000 Subject: [PATCH 03/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../model/atomic_model/linear_atomic_model.py | 21 +++++---- deepmd/pt/model/model/__init__.py | 43 +++++++++++-------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index ee7baabff3..94bdca382c 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -56,7 +56,7 @@ def __init__( self, models: list[BaseAtomicModel], type_map: list[str], - weights: Optional[Union[str,list[float]]]="mean", + weights: Optional[Union[str, list[float]]] = "mean", **kwargs, ): super().__init__(type_map, **kwargs) @@ -93,11 +93,13 @@ def __init__( self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE) # pylint: disable=no-explicit-dtype if isinstance(weights, str): - assert weights in ["sum","mean"] + assert weights in ["sum", "mean"] elif isinstance(weights, list): assert len(weights) == len(models) else: - raise ValueError(f"'weights' must be a string ('sum' or 'mean') or a list of float of length {len(models)}.") + raise ValueError( + f"'weights' must be a string ('sum' or 'mean') or a list of float of length {len(models)}." + ) self.weights = weights def mixed_types(self) -> bool: @@ -345,27 +347,30 @@ def _compute_weight( self, extended_coord, extended_atype, nlists_ ) -> list[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" - nmodels = len(self.models) nframes, nloc, _ = nlists_[0].shape if isinstance(self.weights, str): if self.weights == "sum": return [ - torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + torch.ones( + (nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE + ) for _ in range(nmodels) ] elif self.weights == "mean": return [ - torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + torch.ones( + (nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE + ) / nmodels for _ in range(nmodels) ] elif isinstance(self.weights, list): return [ - torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) * w + torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) + * w for w in self.weights ] - def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 45f74ea131..4097fe9036 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -36,15 +36,15 @@ from .dos_model import ( DOSModel, ) +from .dp_linear_model import ( + DPLinearModel, +) from .dp_model import ( DPModelCommon, ) from .dp_zbl_model import ( DPZBLModel, ) -from .dp_linear_model import( - DPLinearModel -) from .ener_model import ( EnergyModel, ) @@ -107,17 +107,19 @@ def get_spin_model(model_params): backbone_model = get_standard_model(model_params) return SpinEnergyModel(backbone_model=backbone_model, spin=spin) + def get_linear_model(model_params): model_params = copy.deepcopy(model_params) weights = model_params.get("weights", "mean") - list_of_models =[] + list_of_models = [] ntypes = len(model_params["type_map"]) for sub_model_params in model_params["models"]: if "descriptor" in sub_model_params: - # descriptor sub_model_params["descriptor"]["ntypes"] = ntypes - sub_model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + sub_model_params["descriptor"]["type_map"] = copy.deepcopy( + model_params["type_map"] + ) descriptor = BaseDescriptor(**sub_model_params["descriptor"]) # fitting fitting_net = sub_model_params.get("fitting_net", {}) @@ -134,27 +136,34 @@ def get_linear_model(model_params): if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True fitting = BaseFitting(**fitting_net) - list_of_models.append(DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])) + list_of_models.append( + DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) + ) + + else: # must be pairtab + assert ( + "type" in sub_model_params and sub_model_params["type"] == "pairtab" + ), "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model" + list_of_models.append( + PairTabAtomicModel( + sub_model_params["tab_file"], + sub_model_params["rcut"], + sub_model_params["sel"], + type_map=model_params["type_map"], + ) + ) - else: # must be pairtab - assert "type" in sub_model_params and sub_model_params["type"] == "pairtab", "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model" - list_of_models.append(PairTabAtomicModel( - sub_model_params["tab_file"], - sub_model_params["rcut"], - sub_model_params["sel"], - type_map=model_params["type_map"], - )) - atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) return DPLinearModel( - models = list_of_models, + models=list_of_models, type_map=model_params["type_map"], weights=weights, atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, ) + def get_zbl_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) From 082ab7417185243da6a5bf46fab9f71a11bc5d7c Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 16:12:58 +0800 Subject: [PATCH 04/29] fix: dftd3 example --- deepmd/pt/model/model/__init__.py | 2 +- deepmd/pt/model/model/dp_linear_model.py | 12 +++++++++--- examples/water/d3/dftd3.txt | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 4097fe9036..aa13386289 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -125,7 +125,7 @@ def get_linear_model(model_params): fitting_net = sub_model_params.get("fitting_net", {}) fitting_net["type"] = fitting_net.get("type", "ener") fitting_net["ntypes"] = descriptor.get_ntypes() - fitting_net["type_map"] = copy.deepcopy(sub_model_params["type_map"]) + fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) fitting_net["mixed_types"] = descriptor.mixed_types() if fitting_net["type"] in ["dipole", "polar"]: fitting_net["embedding_width"] = descriptor.get_dim_emb() diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index 3c8d872a00..804f089560 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -154,7 +154,13 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( - train_data, type_map, local_jdata["dpmodel"] - ) + type_map = local_jdata_cpy["type_map"] + min_nbor_dist = None + for idx,sub_model in enumerate(local_jdata_cpy["models"]): + if "tab_file" not in sub_model: + sub_model, temp_min = DPModelCommon.update_sel( + train_data, type_map, local_jdata["models"][idx] + ) + if min_nbor_dist is None or temp_min <= min_nbor_dist: + min_nbor_dist = temp_min return local_jdata_cpy, min_nbor_dist diff --git a/examples/water/d3/dftd3.txt b/examples/water/d3/dftd3.txt index bbc9726134..09e5fb697a 100644 --- a/examples/water/d3/dftd3.txt +++ b/examples/water/d3/dftd3.txt @@ -97,4 +97,4 @@ 9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06 9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06 9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06 -1.000000000000000000e+01 0.000000000000000e00e+00 0.000000000000000e00e+00 0.000000000000000e00e+00 +1.000000000000000000e+01 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 From 0104e18181549b707f64e0626a21dbcea95c176e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:13:35 +0000 Subject: [PATCH 05/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/model/dp_linear_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index 804f089560..79cddcb35d 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -156,7 +156,7 @@ def update_sel( local_jdata_cpy = local_jdata.copy() type_map = local_jdata_cpy["type_map"] min_nbor_dist = None - for idx,sub_model in enumerate(local_jdata_cpy["models"]): + for idx, sub_model in enumerate(local_jdata_cpy["models"]): if "tab_file" not in sub_model: sub_model, temp_min = DPModelCommon.update_sel( train_data, type_map, local_jdata["models"][idx] From 21580e1795cbfa14812a1e743d40a0ccf262043e Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 16:18:00 +0800 Subject: [PATCH 06/29] feat: add pt example --- examples/water/d3/input_pt.json | 96 +++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 examples/water/d3/input_pt.json diff --git a/examples/water/d3/input_pt.json b/examples/water/d3/input_pt.json new file mode 100644 index 0000000000..c2d9304a7e --- /dev/null +++ b/examples/water/d3/input_pt.json @@ -0,0 +1,96 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "linear_ener", + "weights": "sum", + "type_map": [ + "O", + "H" + ], + "models": [ + { + "descriptor": { + "type": "se_atten", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "type_one_side": true, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + { + "type": "pairtab", + "tab_file": "dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + "_comment10": "that's all" +} From 08fcb55d52567c17c0b38e583422e167806cd55e Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 16:30:08 +0800 Subject: [PATCH 07/29] fix: jit --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 94bdca382c..597ea7e7c0 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -344,7 +344,10 @@ def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": return super().deserialize(data) def _compute_weight( - self, extended_coord, extended_atype, nlists_ + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlists_: list[torch.Tensor], ) -> list[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" nmodels = len(self.models) @@ -365,12 +368,16 @@ def _compute_weight( / nmodels for _ in range(nmodels) ] + else: + raise ValueError("`weights` must be 'sum' or 'mean' when provided as a string.") elif isinstance(self.weights, list): return [ torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) * w for w in self.weights ] + else: + raise NotImplementedError def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" From 739670e83b40ee50a09fd044d0402175a98594b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:30:45 +0000 Subject: [PATCH 08/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 597ea7e7c0..cba6df7312 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -369,7 +369,9 @@ def _compute_weight( for _ in range(nmodels) ] else: - raise ValueError("`weights` must be 'sum' or 'mean' when provided as a string.") + raise ValueError( + "`weights` must be 'sum' or 'mean' when provided as a string." + ) elif isinstance(self.weights, list): return [ torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE) From 05225623c47efd77b504337d78e18aa52dec41c1 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:00:46 +0800 Subject: [PATCH 09/29] feat: add UTs --- .../universal/common/cases/model/model.py | 18 ++++ source/tests/universal/pt/model/test_model.py | 98 +++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/source/tests/universal/common/cases/model/model.py b/source/tests/universal/common/cases/model/model.py index c31f5cd889..a24bf12751 100644 --- a/source/tests/universal/common/cases/model/model.py +++ b/source/tests/universal/common/cases/model/model.py @@ -27,6 +27,24 @@ def setUpClass(cls) -> None: cls.rprec_dict = {} cls.epsilon_dict = {} +class LinearEnerModelTest(ModelTestCase): + @classmethod + def setUpClass(cls) -> None: + cls.expected_rcut = 5.0 + cls.expected_type_map = ["O", "H"] + cls.expected_dim_fparam = 0 + cls.expected_dim_aparam = 0 + cls.expected_sel_type = [0, 1] + cls.expected_aparam_nall = False + cls.expected_model_output_type = ["energy", "mask"] + cls.model_output_equivariant = [] + cls.expected_sel = [46, 92] + cls.expected_sel_mix = sum(cls.expected_sel) + cls.expected_has_message_passing = False + cls.aprec_dict = {} + cls.rprec_dict = {} + cls.epsilon_dict = {} + class DipoleModelTest(ModelTestCase): @classmethod diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 41df0cf762..a77a21f929 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -20,6 +20,7 @@ DipoleModel, DOSModel, DPZBLModel, + DPLinearModel, EnergyModel, PolarModel, PropertyModel, @@ -47,6 +48,7 @@ PropertyModelTest, SpinEnerModelTest, ZBLModelTest, + LinearEnerModelTest, ) from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, @@ -803,3 +805,99 @@ def setUpClass(cls): cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybridMixed, DescrptHybrid), + (DescriptorParamHybridMixedTTebd, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], + ), # fitting_class_param & class + ), +) +class TestLinearEnergyModelPT(unittest.TestCase, LinearEnerModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + LinearEnerModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + cls.aprec_dict["test_smooth"] = 1e-5 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds1,ds2 = Descrpt(**cls.input_dict_ds), Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds1.get_dim_out(), + mixed_types=ds1.mixed_types(), + type_map=cls.expected_type_map, + ) + ft1 = Fitting( + **cls.input_dict_ft, + ) + ft2 = Fitting( + **cls.input_dict_ft, + ) + dp_model1 = DPAtomicModel( + ds1, + ft1, + type_map=cls.expected_type_map, + ) + dp_model2 = DPAtomicModel( + ds2, + ft2, + type_map=cls.expected_type_map, + ) + cls.module = DPLinearModel( + dp_model, + pt_model, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds1.has_message_passing() + cls.expected_dim_fparam = ft1.get_dim_fparam() + cls.expected_dim_aparam = ft1.get_dim_aparam() \ No newline at end of file From 8b1cb8c1f8dcad4648528a7841a9bd77b1055624 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:02:34 +0000 Subject: [PATCH 10/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/universal/common/cases/model/model.py | 1 + source/tests/universal/pt/model/test_model.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/source/tests/universal/common/cases/model/model.py b/source/tests/universal/common/cases/model/model.py index a24bf12751..cee69d9d6c 100644 --- a/source/tests/universal/common/cases/model/model.py +++ b/source/tests/universal/common/cases/model/model.py @@ -27,6 +27,7 @@ def setUpClass(cls) -> None: cls.rprec_dict = {} cls.epsilon_dict = {} + class LinearEnerModelTest(ModelTestCase): @classmethod def setUpClass(cls) -> None: diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index a77a21f929..f17e12dc50 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -19,8 +19,8 @@ from deepmd.pt.model.model import ( DipoleModel, DOSModel, - DPZBLModel, DPLinearModel, + DPZBLModel, EnergyModel, PolarModel, PropertyModel, @@ -44,11 +44,11 @@ DipoleModelTest, DosModelTest, EnerModelTest, + LinearEnerModelTest, PolarModelTest, PropertyModelTest, SpinEnerModelTest, ZBLModelTest, - LinearEnerModelTest, ) from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, @@ -806,6 +806,7 @@ def setUpClass(cls): cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + @parameterized( des_parameterized=( ( @@ -860,7 +861,7 @@ def setUpClass(cls): if skiptest: raise cls.skipTest(cls, skip_reason) - ds1,ds2 = Descrpt(**cls.input_dict_ds), Descrpt(**cls.input_dict_ds) + ds1, ds2 = Descrpt(**cls.input_dict_ds), Descrpt(**cls.input_dict_ds) cls.input_dict_ft = FittingParam( ntypes=len(cls.expected_type_map), dim_descrpt=ds1.get_dim_out(), @@ -900,4 +901,4 @@ def setUpClass(cls): cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds1.has_message_passing() cls.expected_dim_fparam = ft1.get_dim_fparam() - cls.expected_dim_aparam = ft1.get_dim_aparam() \ No newline at end of file + cls.expected_dim_aparam = ft1.get_dim_aparam() From f2753e7ae55ddb902c6dddeae6365bd67000ed8b Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 18:10:26 +0800 Subject: [PATCH 11/29] fix: UTs --- source/tests/universal/pt/model/test_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index f17e12dc50..65bd981ccf 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -885,8 +885,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - dp_model, - pt_model, + [dp_model1,dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models From 63e70175cd666b6e7910c36401d1cfe6e5f489f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:11:18 +0000 Subject: [PATCH 12/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/universal/pt/model/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 65bd981ccf..55fb8b644e 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -885,7 +885,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - [dp_model1,dp_model2], + [dp_model1, dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models From 0514eae707dba85a4985dae1dc521d862ff63b4f Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 19:27:58 +0800 Subject: [PATCH 13/29] fix: UTs --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index cba6df7312..3837ed64d4 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -333,7 +333,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel": data = copy.deepcopy(data) - check_version_compatibility(data.get("@version", 2), 2, 1) + check_version_compatibility(data.pop("@version", 2), 2, 1) data.pop("@class", None) data.pop("type", None) models = [ From f41df5b02f199b44057afcc796ef70c9dd1d0a4e Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 20:34:27 +0800 Subject: [PATCH 14/29] fix: sel type UT --- source/tests/universal/pt/model/test_model.py | 1403 +++++++++-------- 1 file changed, 702 insertions(+), 701 deletions(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 55fb8b644e..261f45d3ec 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -105,706 +105,706 @@ ] -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], - *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], - *[ - (param_func, DescrptSeTTebd) - for param_func in DescriptorParamSeTTebdList - ], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybrid, DescrptHybrid), - (DescriptorParamHybridMixed, DescrptHybrid), - (DescriptorParamHybridMixedTTebd, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamSeR, DescrptSeR), - (DescriptorParamSeT, DescrptSeT), - (DescriptorParamSeTTebd, DescrptSeTTebd), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], - ), # fitting_class_param & class - ), -) -class TestEnergyModelPT(unittest.TestCase, EnerModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - EnerModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - if Descrpt in [DescrptDPA2]: - cls.epsilon_dict["test_smooth"] = 1e-8 - if Descrpt in [DescrptSeT, DescrptSeTTebd]: - # computational expensive - cls.expected_sel = [i // 4 for i in cls.expected_sel] - cls.expected_rcut = cls.expected_rcut / 2 - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - ) - ft = Fitting( - **cls.input_dict_ft, - ) - cls.module = EnergyModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], - *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], - *[ - (param_func, DescrptSeTTebd) - for param_func in DescriptorParamSeTTebdList - ], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybrid, DescrptHybrid), - (DescriptorParamHybridMixed, DescrptHybrid), - (DescriptorParamHybridMixedTTebd, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamDos, DOSFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamSeR, DescrptSeR), - (DescriptorParamSeT, DescrptSeT), - (DescriptorParamSeTTebd, DescrptSeTTebd), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], - ), # fitting_class_param & class - ), -) -class TestDosModelPT(unittest.TestCase, DosModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - DosModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - cls.aprec_dict["test_smooth"] = 1e-4 - if Descrpt in [DescrptDPA2]: - cls.epsilon_dict["test_smooth"] = 1e-8 - if Descrpt in [DescrptSeT, DescrptSeTTebd]: - # computational expensive - cls.expected_sel = [i // 4 for i in cls.expected_sel] - cls.expected_rcut = cls.expected_rcut / 2 - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - ) - ft = Fitting( - **cls.input_dict_ft, - ) - cls.module = DOSModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybrid, DescrptHybrid), - (DescriptorParamHybridMixed, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamDipole, DipoleFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], - ), # fitting_class_param & class - ), -) -class TestDipoleModelPT(unittest.TestCase, DipoleModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - DipoleModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - if Descrpt in [DescrptDPA2]: - cls.epsilon_dict["test_smooth"] = 1e-8 - cls.aprec_dict["test_forward"] = 1e-10 # for dipole force when near zero - cls.aprec_dict["test_rot"] = 1e-10 # for dipole force when near zero - cls.aprec_dict["test_trans"] = 1e-10 # for dipole force when near zero - cls.aprec_dict["test_permutation"] = 1e-10 # for dipole force when near zero - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - embedding_width=ds.get_dim_emb(), - ) - ft = Fitting( - **cls.input_dict_ft, - ) - cls.module = DipoleModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybrid, DescrptHybrid), - (DescriptorParamHybridMixed, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamPolar, PolarFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], - ), # fitting_class_param & class - ), -) -class TestPolarModelPT(unittest.TestCase, PolarModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - PolarModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - if Descrpt in [DescrptDPA2]: - cls.epsilon_dict["test_smooth"] = 1e-8 - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - embedding_width=ds.get_dim_emb(), - ) - ft = Fitting( - **cls.input_dict_ft, - ) - cls.module = PolarModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybridMixed, DescrptHybrid), - (DescriptorParamHybridMixedTTebd, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], - ), # fitting_class_param & class - ), -) -class TestZBLModelPT(unittest.TestCase, ZBLModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - ZBLModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - # zbl weights not so smooth - cls.aprec_dict["test_smooth"] = 5e-2 - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - ) - ft = Fitting( - **cls.input_dict_ft, - ) - dp_model = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - pt_model = PairTabAtomicModel( - cls.tab_file["use_srtab"], - cls.expected_rcut, - cls.expected_sel, - type_map=cls.expected_type_map, - ) - cls.module = DPZBLModel( - dp_model, - pt_model, - sw_rmin=cls.tab_file["sw_rmin"], - sw_rmax=cls.tab_file["sw_rmax"], - smin_alpha=cls.tab_file["smin_alpha"], - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], - *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], - *[ - (param_func, DescrptSeTTebd) - for param_func in DescriptorParamSeTTebdList - ], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - # (DescriptorParamHybrid, DescrptHybrid), - # unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor - (DescriptorParamHybridMixed, DescrptHybrid), - (DescriptorParamHybridMixedTTebd, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamSeR, DescrptSeR), - (DescriptorParamSeT, DescrptSeT), - (DescriptorParamSeTTebd, DescrptSeTTebd), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], - ), # fitting_class_param & class - ), -) -class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - SpinEnerModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - cls.epsilon_dict["test_smooth"] = 1e-6 - cls.aprec_dict["test_smooth"] = 5e-5 - # set special precision - if Descrpt in [DescrptDPA2, DescrptHybrid]: - cls.epsilon_dict["test_smooth"] = 1e-8 - if Descrpt in [DescrptSeT, DescrptSeTTebd]: - # computational expensive - cls.expected_sel = [i // 4 for i in cls.expected_sel] - cls.expected_rcut = cls.expected_rcut / 2 - - spin = Spin( - use_spin=cls.spin_dict["use_spin"], - virtual_scale=cls.spin_dict["virtual_scale"], - ) - spin_type_map = cls.expected_type_map + [ - item + "_spin" for item in cls.expected_type_map - ] - if Descrpt in [DescrptSeA, DescrptSeR, DescrptSeT]: - spin_sel = cls.expected_sel + cls.expected_sel - else: - spin_sel = cls.expected_sel - pair_exclude_types = spin.get_pair_exclude_types() - atom_exclude_types = spin.get_atom_exclude_types() - cls.input_dict_ds = DescriptorParam( - len(spin_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - spin_sel, - spin_type_map, - env_protection=1e-6, - exclude_types=pair_exclude_types, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(spin_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=spin_type_map, - ) - ft = Fitting( - **cls.input_dict_ft, - ) - backbone_model = EnergyModel( - ds, - ft, - type_map=spin_type_map, - atom_exclude_types=atom_exclude_types, - pair_exclude_types=pair_exclude_types, - ) - cls.module = SpinEnergyModel(backbone_model=backbone_model, spin=spin) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() - - -@parameterized( - des_parameterized=( - ( - *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], - *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], - *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], - (DescriptorParamHybrid, DescrptHybrid), - (DescriptorParamHybridMixed, DescrptHybrid), - ), # descrpt_class_param & class - ((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class - ), - fit_parameterized=( - ( - (DescriptorParamSeA, DescrptSeA), - (DescriptorParamDPA1, DescrptDPA1), - (DescriptorParamDPA2, DescrptDPA2), - ), # descrpt_class_param & class - ( - *[ - (param_func, PropertyFittingNet) - for param_func in FittingParamPropertyList - ], - ), # fitting_class_param & class - ), -) -class TestPropertyModelPT(unittest.TestCase, PropertyModelTest, PTTestCase): - @property - def modules_to_test(self): - skip_test_jit = getattr(self, "skip_test_jit", False) - modules = PTTestCase.modules_to_test.fget(self) - if not skip_test_jit: - # for Model, we can test script module API - modules += [ - self._script_module - if hasattr(self, "_script_module") - else self.script_module - ] - return modules - - @classmethod - def setUpClass(cls): - PropertyModelTest.setUpClass() - (DescriptorParam, Descrpt) = cls.param[0] - (FittingParam, Fitting) = cls.param[1] - # set special precision - if Descrpt in [DescrptDPA2]: - cls.epsilon_dict["test_smooth"] = 1e-8 - cls.input_dict_ds = DescriptorParam( - len(cls.expected_type_map), - cls.expected_rcut, - cls.expected_rcut / 2, - cls.expected_sel, - cls.expected_type_map, - ) - - # set skip tests - skiptest, skip_reason = skip_model_tests(cls) - if skiptest: - raise cls.skipTest(cls, skip_reason) - - ds = Descrpt(**cls.input_dict_ds) - cls.input_dict_ft = FittingParam( - ntypes=len(cls.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - type_map=cls.expected_type_map, - embedding_width=ds.get_dim_emb(), - ) - ft = Fitting( - **cls.input_dict_ft, - ) - cls.module = PropertyModel( - ds, - ft, - type_map=cls.expected_type_map, - ) - # only test jit API once for different models - if ( - DescriptorParam not in defalut_des_param - or FittingParam not in defalut_fit_param - ): - cls.skip_test_jit = True - else: - with torch.jit.optimized_execution(False): - cls._script_module = torch.jit.script(cls.module) - cls.output_def = cls.module.translated_output_def() - cls.expected_has_message_passing = ds.has_message_passing() - cls.expected_sel_type = ft.get_sel_type() - cls.expected_dim_fparam = ft.get_dim_fparam() - cls.expected_dim_aparam = ft.get_dim_aparam() +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], +# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], +# *[ +# (param_func, DescrptSeTTebd) +# for param_func in DescriptorParamSeTTebdList +# ], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybrid, DescrptHybrid), +# (DescriptorParamHybridMixed, DescrptHybrid), +# (DescriptorParamHybridMixedTTebd, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamSeR, DescrptSeR), +# (DescriptorParamSeT, DescrptSeT), +# (DescriptorParamSeTTebd, DescrptSeTTebd), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], +# ), # fitting_class_param & class +# ), +# ) +# class TestEnergyModelPT(unittest.TestCase, EnerModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# EnerModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# if Descrpt in [DescrptDPA2]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# if Descrpt in [DescrptSeT, DescrptSeTTebd]: +# # computational expensive +# cls.expected_sel = [i // 4 for i in cls.expected_sel] +# cls.expected_rcut = cls.expected_rcut / 2 +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# cls.module = EnergyModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], +# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], +# *[ +# (param_func, DescrptSeTTebd) +# for param_func in DescriptorParamSeTTebdList +# ], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybrid, DescrptHybrid), +# (DescriptorParamHybridMixed, DescrptHybrid), +# (DescriptorParamHybridMixedTTebd, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamDos, DOSFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamSeR, DescrptSeR), +# (DescriptorParamSeT, DescrptSeT), +# (DescriptorParamSeTTebd, DescrptSeTTebd), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], +# ), # fitting_class_param & class +# ), +# ) +# class TestDosModelPT(unittest.TestCase, DosModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# DosModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# cls.aprec_dict["test_smooth"] = 1e-4 +# if Descrpt in [DescrptDPA2]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# if Descrpt in [DescrptSeT, DescrptSeTTebd]: +# # computational expensive +# cls.expected_sel = [i // 4 for i in cls.expected_sel] +# cls.expected_rcut = cls.expected_rcut / 2 +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# cls.module = DOSModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybrid, DescrptHybrid), +# (DescriptorParamHybridMixed, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamDipole, DipoleFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], +# ), # fitting_class_param & class +# ), +# ) +# class TestDipoleModelPT(unittest.TestCase, DipoleModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# DipoleModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# if Descrpt in [DescrptDPA2]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# cls.aprec_dict["test_forward"] = 1e-10 # for dipole force when near zero +# cls.aprec_dict["test_rot"] = 1e-10 # for dipole force when near zero +# cls.aprec_dict["test_trans"] = 1e-10 # for dipole force when near zero +# cls.aprec_dict["test_permutation"] = 1e-10 # for dipole force when near zero +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# embedding_width=ds.get_dim_emb(), +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# cls.module = DipoleModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybrid, DescrptHybrid), +# (DescriptorParamHybridMixed, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamPolar, PolarFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], +# ), # fitting_class_param & class +# ), +# ) +# class TestPolarModelPT(unittest.TestCase, PolarModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# PolarModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# if Descrpt in [DescrptDPA2]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# embedding_width=ds.get_dim_emb(), +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# cls.module = PolarModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybridMixed, DescrptHybrid), +# (DescriptorParamHybridMixedTTebd, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], +# ), # fitting_class_param & class +# ), +# ) +# class TestZBLModelPT(unittest.TestCase, ZBLModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# ZBLModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# # zbl weights not so smooth +# cls.aprec_dict["test_smooth"] = 5e-2 +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# dp_model = DPAtomicModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# pt_model = PairTabAtomicModel( +# cls.tab_file["use_srtab"], +# cls.expected_rcut, +# cls.expected_sel, +# type_map=cls.expected_type_map, +# ) +# cls.module = DPZBLModel( +# dp_model, +# pt_model, +# sw_rmin=cls.tab_file["sw_rmin"], +# sw_rmax=cls.tab_file["sw_rmax"], +# smin_alpha=cls.tab_file["smin_alpha"], +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], +# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], +# *[ +# (param_func, DescrptSeTTebd) +# for param_func in DescriptorParamSeTTebdList +# ], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# # (DescriptorParamHybrid, DescrptHybrid), +# # unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor +# (DescriptorParamHybridMixed, DescrptHybrid), +# (DescriptorParamHybridMixedTTebd, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamSeR, DescrptSeR), +# (DescriptorParamSeT, DescrptSeT), +# (DescriptorParamSeTTebd, DescrptSeTTebd), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], +# ), # fitting_class_param & class +# ), +# ) +# class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# SpinEnerModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# cls.epsilon_dict["test_smooth"] = 1e-6 +# cls.aprec_dict["test_smooth"] = 5e-5 +# # set special precision +# if Descrpt in [DescrptDPA2, DescrptHybrid]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# if Descrpt in [DescrptSeT, DescrptSeTTebd]: +# # computational expensive +# cls.expected_sel = [i // 4 for i in cls.expected_sel] +# cls.expected_rcut = cls.expected_rcut / 2 + +# spin = Spin( +# use_spin=cls.spin_dict["use_spin"], +# virtual_scale=cls.spin_dict["virtual_scale"], +# ) +# spin_type_map = cls.expected_type_map + [ +# item + "_spin" for item in cls.expected_type_map +# ] +# if Descrpt in [DescrptSeA, DescrptSeR, DescrptSeT]: +# spin_sel = cls.expected_sel + cls.expected_sel +# else: +# spin_sel = cls.expected_sel +# pair_exclude_types = spin.get_pair_exclude_types() +# atom_exclude_types = spin.get_atom_exclude_types() +# cls.input_dict_ds = DescriptorParam( +# len(spin_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# spin_sel, +# spin_type_map, +# env_protection=1e-6, +# exclude_types=pair_exclude_types, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(spin_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=spin_type_map, +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# backbone_model = EnergyModel( +# ds, +# ft, +# type_map=spin_type_map, +# atom_exclude_types=atom_exclude_types, +# pair_exclude_types=pair_exclude_types, +# ) +# cls.module = SpinEnergyModel(backbone_model=backbone_model, spin=spin) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() + + +# @parameterized( +# des_parameterized=( +# ( +# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], +# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], +# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], +# (DescriptorParamHybrid, DescrptHybrid), +# (DescriptorParamHybridMixed, DescrptHybrid), +# ), # descrpt_class_param & class +# ((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class +# ), +# fit_parameterized=( +# ( +# (DescriptorParamSeA, DescrptSeA), +# (DescriptorParamDPA1, DescrptDPA1), +# (DescriptorParamDPA2, DescrptDPA2), +# ), # descrpt_class_param & class +# ( +# *[ +# (param_func, PropertyFittingNet) +# for param_func in FittingParamPropertyList +# ], +# ), # fitting_class_param & class +# ), +# ) +# class TestPropertyModelPT(unittest.TestCase, PropertyModelTest, PTTestCase): +# @property +# def modules_to_test(self): +# skip_test_jit = getattr(self, "skip_test_jit", False) +# modules = PTTestCase.modules_to_test.fget(self) +# if not skip_test_jit: +# # for Model, we can test script module API +# modules += [ +# self._script_module +# if hasattr(self, "_script_module") +# else self.script_module +# ] +# return modules + +# @classmethod +# def setUpClass(cls): +# PropertyModelTest.setUpClass() +# (DescriptorParam, Descrpt) = cls.param[0] +# (FittingParam, Fitting) = cls.param[1] +# # set special precision +# if Descrpt in [DescrptDPA2]: +# cls.epsilon_dict["test_smooth"] = 1e-8 +# cls.input_dict_ds = DescriptorParam( +# len(cls.expected_type_map), +# cls.expected_rcut, +# cls.expected_rcut / 2, +# cls.expected_sel, +# cls.expected_type_map, +# ) + +# # set skip tests +# skiptest, skip_reason = skip_model_tests(cls) +# if skiptest: +# raise cls.skipTest(cls, skip_reason) + +# ds = Descrpt(**cls.input_dict_ds) +# cls.input_dict_ft = FittingParam( +# ntypes=len(cls.expected_type_map), +# dim_descrpt=ds.get_dim_out(), +# mixed_types=ds.mixed_types(), +# type_map=cls.expected_type_map, +# embedding_width=ds.get_dim_emb(), +# ) +# ft = Fitting( +# **cls.input_dict_ft, +# ) +# cls.module = PropertyModel( +# ds, +# ft, +# type_map=cls.expected_type_map, +# ) +# # only test jit API once for different models +# if ( +# DescriptorParam not in defalut_des_param +# or FittingParam not in defalut_fit_param +# ): +# cls.skip_test_jit = True +# else: +# with torch.jit.optimized_execution(False): +# cls._script_module = torch.jit.script(cls.module) +# cls.output_def = cls.module.translated_output_def() +# cls.expected_has_message_passing = ds.has_message_passing() +# cls.expected_sel_type = ft.get_sel_type() +# cls.expected_dim_fparam = ft.get_dim_fparam() +# cls.expected_dim_aparam = ft.get_dim_aparam() @parameterized( @@ -885,7 +885,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - [dp_model1, dp_model2], + [dp_model1,dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models @@ -901,3 +901,4 @@ def setUpClass(cls): cls.expected_has_message_passing = ds1.has_message_passing() cls.expected_dim_fparam = ft1.get_dim_fparam() cls.expected_dim_aparam = ft1.get_dim_aparam() + cls.expected_sel_type = ft1.get_sel_type() From 000c1c8f8c16da8dd246cdee114580f6f3f6a4d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:35:03 +0000 Subject: [PATCH 15/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/universal/pt/model/test_model.py | 36 +------------------ 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 261f45d3ec..e8eb288762 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -5,50 +5,24 @@ from deepmd.pt.model.atomic_model import ( DPAtomicModel, - PairTabAtomicModel, ) from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, DescrptHybrid, - DescrptSeA, - DescrptSeR, - DescrptSeT, - DescrptSeTTebd, ) from deepmd.pt.model.model import ( - DipoleModel, - DOSModel, DPLinearModel, - DPZBLModel, - EnergyModel, - PolarModel, - PropertyModel, - SpinEnergyModel, ) from deepmd.pt.model.task import ( - DipoleFittingNet, - DOSFittingNet, EnergyFittingNet, - PolarFittingNet, - PropertyFittingNet, -) -from deepmd.utils.spin import ( - Spin, ) from ....consistent.common import ( parameterized, ) from ...common.cases.model.model import ( - DipoleModelTest, - DosModelTest, - EnerModelTest, LinearEnerModelTest, - PolarModelTest, - PropertyModelTest, - SpinEnerModelTest, - ZBLModelTest, ) from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, @@ -59,25 +33,17 @@ DescriptorParamHybridMixed, DescriptorParamHybridMixedTTebd, DescriptorParamSeA, - DescriptorParamSeAList, DescriptorParamSeR, - DescriptorParamSeRList, DescriptorParamSeT, - DescriptorParamSeTList, DescriptorParamSeTTebd, - DescriptorParamSeTTebdList, ) from ...dpmodel.fitting.test_fitting import ( FittingParamDipole, - FittingParamDipoleList, FittingParamDos, - FittingParamDosList, FittingParamEnergy, FittingParamEnergyList, FittingParamPolar, - FittingParamPolarList, FittingParamProperty, - FittingParamPropertyList, ) from ...dpmodel.model.test_model import ( skip_model_tests, @@ -885,7 +851,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - [dp_model1,dp_model2], + [dp_model1, dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models From 189961c2a4c559a764d223a98c8926c77faa4cc0 Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 21:25:20 +0800 Subject: [PATCH 16/29] fix: UT sel type dtype to long --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 3837ed64d4..a5856a9935 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,7 +404,7 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor(model.get_sel_type(), dtype=torch.int32) + torch.as_tensor(model.get_sel_type(), dtype=torch.int64) for model in self.models ] ) From 97156410e0e552fe85785efd02056c51a0e20e99 Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 23:29:23 +0800 Subject: [PATCH 17/29] fix: revert dtype change --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index a5856a9935..3837ed64d4 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,7 +404,7 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor(model.get_sel_type(), dtype=torch.int64) + torch.as_tensor(model.get_sel_type(), dtype=torch.int32) for model in self.models ] ) From c8e86fef039bdaee5d00bf7c40486016213e992a Mon Sep 17 00:00:00 2001 From: anyangml Date: Tue, 8 Oct 2024 23:31:54 +0800 Subject: [PATCH 18/29] fix: revert ut change --- source/tests/universal/pt/model/test_model.py | 1436 +++++++++-------- 1 file changed, 735 insertions(+), 701 deletions(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index e8eb288762..04ec9b6f94 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -5,24 +5,50 @@ from deepmd.pt.model.atomic_model import ( DPAtomicModel, + PairTabAtomicModel, ) from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, DescrptHybrid, + DescrptSeA, + DescrptSeR, + DescrptSeT, + DescrptSeTTebd, ) from deepmd.pt.model.model import ( + DipoleModel, + DOSModel, DPLinearModel, + DPZBLModel, + EnergyModel, + PolarModel, + PropertyModel, + SpinEnergyModel, ) from deepmd.pt.model.task import ( + DipoleFittingNet, + DOSFittingNet, EnergyFittingNet, + PolarFittingNet, + PropertyFittingNet, +) +from deepmd.utils.spin import ( + Spin, ) from ....consistent.common import ( parameterized, ) from ...common.cases.model.model import ( + DipoleModelTest, + DosModelTest, + EnerModelTest, LinearEnerModelTest, + PolarModelTest, + PropertyModelTest, + SpinEnerModelTest, + ZBLModelTest, ) from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, @@ -33,17 +59,25 @@ DescriptorParamHybridMixed, DescriptorParamHybridMixedTTebd, DescriptorParamSeA, + DescriptorParamSeAList, DescriptorParamSeR, + DescriptorParamSeRList, DescriptorParamSeT, + DescriptorParamSeTList, DescriptorParamSeTTebd, + DescriptorParamSeTTebdList, ) from ...dpmodel.fitting.test_fitting import ( FittingParamDipole, + FittingParamDipoleList, FittingParamDos, + FittingParamDosList, FittingParamEnergy, FittingParamEnergyList, FittingParamPolar, + FittingParamPolarList, FittingParamProperty, + FittingParamPropertyList, ) from ...dpmodel.model.test_model import ( skip_model_tests, @@ -71,706 +105,706 @@ ] -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], -# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], -# *[ -# (param_func, DescrptSeTTebd) -# for param_func in DescriptorParamSeTTebdList -# ], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybrid, DescrptHybrid), -# (DescriptorParamHybridMixed, DescrptHybrid), -# (DescriptorParamHybridMixedTTebd, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamSeR, DescrptSeR), -# (DescriptorParamSeT, DescrptSeT), -# (DescriptorParamSeTTebd, DescrptSeTTebd), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], -# ), # fitting_class_param & class -# ), -# ) -# class TestEnergyModelPT(unittest.TestCase, EnerModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# EnerModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# if Descrpt in [DescrptDPA2]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# if Descrpt in [DescrptSeT, DescrptSeTTebd]: -# # computational expensive -# cls.expected_sel = [i // 4 for i in cls.expected_sel] -# cls.expected_rcut = cls.expected_rcut / 2 -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# cls.module = EnergyModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], -# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], -# *[ -# (param_func, DescrptSeTTebd) -# for param_func in DescriptorParamSeTTebdList -# ], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybrid, DescrptHybrid), -# (DescriptorParamHybridMixed, DescrptHybrid), -# (DescriptorParamHybridMixedTTebd, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamDos, DOSFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamSeR, DescrptSeR), -# (DescriptorParamSeT, DescrptSeT), -# (DescriptorParamSeTTebd, DescrptSeTTebd), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], -# ), # fitting_class_param & class -# ), -# ) -# class TestDosModelPT(unittest.TestCase, DosModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# DosModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# cls.aprec_dict["test_smooth"] = 1e-4 -# if Descrpt in [DescrptDPA2]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# if Descrpt in [DescrptSeT, DescrptSeTTebd]: -# # computational expensive -# cls.expected_sel = [i // 4 for i in cls.expected_sel] -# cls.expected_rcut = cls.expected_rcut / 2 -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# cls.module = DOSModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybrid, DescrptHybrid), -# (DescriptorParamHybridMixed, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamDipole, DipoleFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], -# ), # fitting_class_param & class -# ), -# ) -# class TestDipoleModelPT(unittest.TestCase, DipoleModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# DipoleModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# if Descrpt in [DescrptDPA2]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# cls.aprec_dict["test_forward"] = 1e-10 # for dipole force when near zero -# cls.aprec_dict["test_rot"] = 1e-10 # for dipole force when near zero -# cls.aprec_dict["test_trans"] = 1e-10 # for dipole force when near zero -# cls.aprec_dict["test_permutation"] = 1e-10 # for dipole force when near zero -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# embedding_width=ds.get_dim_emb(), -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# cls.module = DipoleModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybrid, DescrptHybrid), -# (DescriptorParamHybridMixed, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamPolar, PolarFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], -# ), # fitting_class_param & class -# ), -# ) -# class TestPolarModelPT(unittest.TestCase, PolarModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# PolarModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# if Descrpt in [DescrptDPA2]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# embedding_width=ds.get_dim_emb(), -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# cls.module = PolarModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybridMixed, DescrptHybrid), -# (DescriptorParamHybridMixedTTebd, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], -# ), # fitting_class_param & class -# ), -# ) -# class TestZBLModelPT(unittest.TestCase, ZBLModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# ZBLModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# # zbl weights not so smooth -# cls.aprec_dict["test_smooth"] = 5e-2 -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# dp_model = DPAtomicModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# pt_model = PairTabAtomicModel( -# cls.tab_file["use_srtab"], -# cls.expected_rcut, -# cls.expected_sel, -# type_map=cls.expected_type_map, -# ) -# cls.module = DPZBLModel( -# dp_model, -# pt_model, -# sw_rmin=cls.tab_file["sw_rmin"], -# sw_rmax=cls.tab_file["sw_rmax"], -# smin_alpha=cls.tab_file["smin_alpha"], -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], -# *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], -# *[ -# (param_func, DescrptSeTTebd) -# for param_func in DescriptorParamSeTTebdList -# ], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# # (DescriptorParamHybrid, DescrptHybrid), -# # unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor -# (DescriptorParamHybridMixed, DescrptHybrid), -# (DescriptorParamHybridMixedTTebd, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamSeR, DescrptSeR), -# (DescriptorParamSeT, DescrptSeT), -# (DescriptorParamSeTTebd, DescrptSeTTebd), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], -# ), # fitting_class_param & class -# ), -# ) -# class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# SpinEnerModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# cls.epsilon_dict["test_smooth"] = 1e-6 -# cls.aprec_dict["test_smooth"] = 5e-5 -# # set special precision -# if Descrpt in [DescrptDPA2, DescrptHybrid]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# if Descrpt in [DescrptSeT, DescrptSeTTebd]: -# # computational expensive -# cls.expected_sel = [i // 4 for i in cls.expected_sel] -# cls.expected_rcut = cls.expected_rcut / 2 - -# spin = Spin( -# use_spin=cls.spin_dict["use_spin"], -# virtual_scale=cls.spin_dict["virtual_scale"], -# ) -# spin_type_map = cls.expected_type_map + [ -# item + "_spin" for item in cls.expected_type_map -# ] -# if Descrpt in [DescrptSeA, DescrptSeR, DescrptSeT]: -# spin_sel = cls.expected_sel + cls.expected_sel -# else: -# spin_sel = cls.expected_sel -# pair_exclude_types = spin.get_pair_exclude_types() -# atom_exclude_types = spin.get_atom_exclude_types() -# cls.input_dict_ds = DescriptorParam( -# len(spin_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# spin_sel, -# spin_type_map, -# env_protection=1e-6, -# exclude_types=pair_exclude_types, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(spin_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=spin_type_map, -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# backbone_model = EnergyModel( -# ds, -# ft, -# type_map=spin_type_map, -# atom_exclude_types=atom_exclude_types, -# pair_exclude_types=pair_exclude_types, -# ) -# cls.module = SpinEnergyModel(backbone_model=backbone_model, spin=spin) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() - - -# @parameterized( -# des_parameterized=( -# ( -# *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], -# *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], -# *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], -# (DescriptorParamHybrid, DescrptHybrid), -# (DescriptorParamHybridMixed, DescrptHybrid), -# ), # descrpt_class_param & class -# ((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class -# ), -# fit_parameterized=( -# ( -# (DescriptorParamSeA, DescrptSeA), -# (DescriptorParamDPA1, DescrptDPA1), -# (DescriptorParamDPA2, DescrptDPA2), -# ), # descrpt_class_param & class -# ( -# *[ -# (param_func, PropertyFittingNet) -# for param_func in FittingParamPropertyList -# ], -# ), # fitting_class_param & class -# ), -# ) -# class TestPropertyModelPT(unittest.TestCase, PropertyModelTest, PTTestCase): -# @property -# def modules_to_test(self): -# skip_test_jit = getattr(self, "skip_test_jit", False) -# modules = PTTestCase.modules_to_test.fget(self) -# if not skip_test_jit: -# # for Model, we can test script module API -# modules += [ -# self._script_module -# if hasattr(self, "_script_module") -# else self.script_module -# ] -# return modules - -# @classmethod -# def setUpClass(cls): -# PropertyModelTest.setUpClass() -# (DescriptorParam, Descrpt) = cls.param[0] -# (FittingParam, Fitting) = cls.param[1] -# # set special precision -# if Descrpt in [DescrptDPA2]: -# cls.epsilon_dict["test_smooth"] = 1e-8 -# cls.input_dict_ds = DescriptorParam( -# len(cls.expected_type_map), -# cls.expected_rcut, -# cls.expected_rcut / 2, -# cls.expected_sel, -# cls.expected_type_map, -# ) - -# # set skip tests -# skiptest, skip_reason = skip_model_tests(cls) -# if skiptest: -# raise cls.skipTest(cls, skip_reason) - -# ds = Descrpt(**cls.input_dict_ds) -# cls.input_dict_ft = FittingParam( -# ntypes=len(cls.expected_type_map), -# dim_descrpt=ds.get_dim_out(), -# mixed_types=ds.mixed_types(), -# type_map=cls.expected_type_map, -# embedding_width=ds.get_dim_emb(), -# ) -# ft = Fitting( -# **cls.input_dict_ft, -# ) -# cls.module = PropertyModel( -# ds, -# ft, -# type_map=cls.expected_type_map, -# ) -# # only test jit API once for different models -# if ( -# DescriptorParam not in defalut_des_param -# or FittingParam not in defalut_fit_param -# ): -# cls.skip_test_jit = True -# else: -# with torch.jit.optimized_execution(False): -# cls._script_module = torch.jit.script(cls.module) -# cls.output_def = cls.module.translated_output_def() -# cls.expected_has_message_passing = ds.has_message_passing() -# cls.expected_sel_type = ft.get_sel_type() -# cls.expected_dim_fparam = ft.get_dim_fparam() -# cls.expected_dim_aparam = ft.get_dim_aparam() +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], + *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], + *[ + (param_func, DescrptSeTTebd) + for param_func in DescriptorParamSeTTebdList + ], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + (DescriptorParamHybridMixedTTebd, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamSeR, DescrptSeR), + (DescriptorParamSeT, DescrptSeT), + (DescriptorParamSeTTebd, DescrptSeTTebd), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], + ), # fitting_class_param & class + ), +) +class TestEnergyModelPT(unittest.TestCase, EnerModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + EnerModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + if Descrpt in [DescrptDPA2]: + cls.epsilon_dict["test_smooth"] = 1e-8 + if Descrpt in [DescrptSeT, DescrptSeTTebd]: + # computational expensive + cls.expected_sel = [i // 4 for i in cls.expected_sel] + cls.expected_rcut = cls.expected_rcut / 2 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = EnergyModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], + *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], + *[ + (param_func, DescrptSeTTebd) + for param_func in DescriptorParamSeTTebdList + ], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + (DescriptorParamHybridMixedTTebd, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamDos, DOSFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamSeR, DescrptSeR), + (DescriptorParamSeT, DescrptSeT), + (DescriptorParamSeTTebd, DescrptSeTTebd), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], + ), # fitting_class_param & class + ), +) +class TestDosModelPT(unittest.TestCase, DosModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + DosModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + cls.aprec_dict["test_smooth"] = 1e-4 + if Descrpt in [DescrptDPA2]: + cls.epsilon_dict["test_smooth"] = 1e-8 + if Descrpt in [DescrptSeT, DescrptSeTTebd]: + # computational expensive + cls.expected_sel = [i // 4 for i in cls.expected_sel] + cls.expected_rcut = cls.expected_rcut / 2 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = DOSModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamDipole, DipoleFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], + ), # fitting_class_param & class + ), +) +class TestDipoleModelPT(unittest.TestCase, DipoleModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + DipoleModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + if Descrpt in [DescrptDPA2]: + cls.epsilon_dict["test_smooth"] = 1e-8 + cls.aprec_dict["test_forward"] = 1e-10 # for dipole force when near zero + cls.aprec_dict["test_rot"] = 1e-10 # for dipole force when near zero + cls.aprec_dict["test_trans"] = 1e-10 # for dipole force when near zero + cls.aprec_dict["test_permutation"] = 1e-10 # for dipole force when near zero + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + embedding_width=ds.get_dim_emb(), + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = DipoleModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamPolar, PolarFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], + ), # fitting_class_param & class + ), +) +class TestPolarModelPT(unittest.TestCase, PolarModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + PolarModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + if Descrpt in [DescrptDPA2]: + cls.epsilon_dict["test_smooth"] = 1e-8 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + embedding_width=ds.get_dim_emb(), + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = PolarModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybridMixed, DescrptHybrid), + (DescriptorParamHybridMixedTTebd, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], + ), # fitting_class_param & class + ), +) +class TestZBLModelPT(unittest.TestCase, ZBLModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + ZBLModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + # zbl weights not so smooth + cls.aprec_dict["test_smooth"] = 5e-2 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + ) + ft = Fitting( + **cls.input_dict_ft, + ) + dp_model = DPAtomicModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + pt_model = PairTabAtomicModel( + cls.tab_file["use_srtab"], + cls.expected_rcut, + cls.expected_sel, + type_map=cls.expected_type_map, + ) + cls.module = DPZBLModel( + dp_model, + pt_model, + sw_rmin=cls.tab_file["sw_rmin"], + sw_rmax=cls.tab_file["sw_rmax"], + smin_alpha=cls.tab_file["smin_alpha"], + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList], + *[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList], + *[ + (param_func, DescrptSeTTebd) + for param_func in DescriptorParamSeTTebdList + ], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + # (DescriptorParamHybrid, DescrptHybrid), + # unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor + (DescriptorParamHybridMixed, DescrptHybrid), + (DescriptorParamHybridMixedTTebd, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamSeR, DescrptSeR), + (DescriptorParamSeT, DescrptSeT), + (DescriptorParamSeTTebd, DescrptSeTTebd), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], + ), # fitting_class_param & class + ), +) +class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + SpinEnerModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + cls.epsilon_dict["test_smooth"] = 1e-6 + cls.aprec_dict["test_smooth"] = 5e-5 + # set special precision + if Descrpt in [DescrptDPA2, DescrptHybrid]: + cls.epsilon_dict["test_smooth"] = 1e-8 + if Descrpt in [DescrptSeT, DescrptSeTTebd]: + # computational expensive + cls.expected_sel = [i // 4 for i in cls.expected_sel] + cls.expected_rcut = cls.expected_rcut / 2 + + spin = Spin( + use_spin=cls.spin_dict["use_spin"], + virtual_scale=cls.spin_dict["virtual_scale"], + ) + spin_type_map = cls.expected_type_map + [ + item + "_spin" for item in cls.expected_type_map + ] + if Descrpt in [DescrptSeA, DescrptSeR, DescrptSeT]: + spin_sel = cls.expected_sel + cls.expected_sel + else: + spin_sel = cls.expected_sel + pair_exclude_types = spin.get_pair_exclude_types() + atom_exclude_types = spin.get_atom_exclude_types() + cls.input_dict_ds = DescriptorParam( + len(spin_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + spin_sel, + spin_type_map, + env_protection=1e-6, + exclude_types=pair_exclude_types, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(spin_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=spin_type_map, + ) + ft = Fitting( + **cls.input_dict_ft, + ) + backbone_model = EnergyModel( + ds, + ft, + type_map=spin_type_map, + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + cls.module = SpinEnergyModel(backbone_model=backbone_model, spin=spin) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamSeA, DescrptSeA), + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[ + (param_func, PropertyFittingNet) + for param_func in FittingParamPropertyList + ], + ), # fitting_class_param & class + ), +) +class TestPropertyModelPT(unittest.TestCase, PropertyModelTest, PTTestCase): + @property + def modules_to_test(self): + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module + ] + return modules + + @classmethod + def setUpClass(cls): + PropertyModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + # set special precision + if Descrpt in [DescrptDPA2]: + cls.epsilon_dict["test_smooth"] = 1e-8 + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + embedding_width=ds.get_dim_emb(), + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = PropertyModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + # only test jit API once for different models + if ( + DescriptorParam not in defalut_des_param + or FittingParam not in defalut_fit_param + ): + cls.skip_test_jit = True + else: + with torch.jit.optimized_execution(False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() @parameterized( @@ -851,7 +885,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - [dp_model1, dp_model2], + [dp_model1,dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models From a9357841fbc37b8e175ec5732791aff4a3e435b0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:32:35 +0000 Subject: [PATCH 19/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/universal/pt/model/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 04ec9b6f94..dc38efe184 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -885,7 +885,7 @@ def setUpClass(cls): type_map=cls.expected_type_map, ) cls.module = DPLinearModel( - [dp_model1,dp_model2], + [dp_model1, dp_model2], type_map=cls.expected_type_map, ) # only test jit API once for different models From b664e55e439a885ffed9f347fd7348be22f803fe Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:41:36 +0800 Subject: [PATCH 20/29] fix: rename, fix UT device --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 2 +- deepmd/pt/model/model/__init__.py | 5 +++-- deepmd/pt/model/model/dp_linear_model.py | 2 +- doc/model/linear.md | 4 ++-- source/tests/common/test_examples.py | 1 + source/tests/universal/pt/model/test_model.py | 4 ++-- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 3837ed64d4..e09d6b13d0 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,7 +404,7 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor(model.get_sel_type(), dtype=torch.int32) + torch.as_tensor(model.get_sel_type(), dtype=torch.int32, device=env.DEVICE) for model in self.models ] ) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index aa13386289..26aefa6201 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -37,7 +37,7 @@ DOSModel, ) from .dp_linear_model import ( - DPLinearModel, + LinearEnergyModel, ) from .dp_model import ( DPModelCommon, @@ -155,7 +155,7 @@ def get_linear_model(model_params): atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) - return DPLinearModel( + return LinearEnergyModel( models=list_of_models, type_map=model_params["type_map"], weights=weights, @@ -326,4 +326,5 @@ def get_model(model_params): "DPZBLModel", "make_model", "make_hessian_model", + "LinearEnergyModel", ] diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index 79cddcb35d..ef2e84bd19 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -29,7 +29,7 @@ @BaseModel.register("linear_ener") -class DPLinearModel(DPLinearModel_): +class LinearEnergyModel(DPLinearModel_): model_type = "ener" def __init__( diff --git a/doc/model/linear.md b/doc/model/linear.md index 3891559d90..47fdd1750b 100644 --- a/doc/model/linear.md +++ b/doc/model/linear.md @@ -1,7 +1,7 @@ -## Linear model {{ tensorflow_icon }} +## Linear model {{ tensorflow_icon }} {{ pytorch_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }} ::: One can linearly combine existing models with arbitrary coefficients: diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 6abb482824..cc2a7ad487 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -34,6 +34,7 @@ p_examples / "water" / "hybrid" / "input.json", p_examples / "water" / "dplr" / "train" / "dw.json", p_examples / "water" / "dplr" / "train" / "ener.json", + p_examples / "water" / "d3" / "input_pt.json", p_examples / "water" / "linear" / "input.json", p_examples / "nopbc" / "train" / "input.json", p_examples / "water_tensor" / "dipole" / "dipole_input.json", diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index dc38efe184..f6337c7459 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -19,7 +19,7 @@ from deepmd.pt.model.model import ( DipoleModel, DOSModel, - DPLinearModel, + LinearEnergyModel, DPZBLModel, EnergyModel, PolarModel, @@ -884,7 +884,7 @@ def setUpClass(cls): ft2, type_map=cls.expected_type_map, ) - cls.module = DPLinearModel( + cls.module = LinearEnergyModel( [dp_model1, dp_model2], type_map=cls.expected_type_map, ) From 8f06bb50fb304c0754383d50843abfc3327cae38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 03:42:10 +0000 Subject: [PATCH 21/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 4 +++- source/tests/universal/pt/model/test_model.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index e09d6b13d0..013fa40ef0 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,7 +404,9 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor(model.get_sel_type(), dtype=torch.int32, device=env.DEVICE) + torch.as_tensor( + model.get_sel_type(), dtype=torch.int32, device=env.DEVICE + ) for model in self.models ] ) diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index f6337c7459..81c32eb94c 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -19,9 +19,9 @@ from deepmd.pt.model.model import ( DipoleModel, DOSModel, - LinearEnergyModel, DPZBLModel, EnergyModel, + LinearEnergyModel, PolarModel, PropertyModel, SpinEnergyModel, From af16e659bf2291332bf74c99ae5ef4fc5027fefd Mon Sep 17 00:00:00 2001 From: anyangml Date: Wed, 9 Oct 2024 12:37:59 +0800 Subject: [PATCH 22/29] change get_sel_type dtype to int64 --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 013fa40ef0..b0813b7b02 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,9 +404,7 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor( - model.get_sel_type(), dtype=torch.int32, device=env.DEVICE - ) + torch.as_tensor(model.get_sel_type(), dtype=torch.int64, device=env.DEVICE) for model in self.models ] ) From 34e3c97342fa6d35eb81cab1fb79d153c075c290 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 04:39:56 +0000 Subject: [PATCH 23/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index b0813b7b02..8d27fbcac4 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -404,7 +404,9 @@ def get_sel_type(self) -> list[int]: return torch.unique( torch.cat( [ - torch.as_tensor(model.get_sel_type(), dtype=torch.int64, device=env.DEVICE) + torch.as_tensor( + model.get_sel_type(), dtype=torch.int64, device=env.DEVICE + ) for model in self.models ] ) From 34843a6de03f05773a19b6c9e2edbee209e2e026 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 9 Oct 2024 13:57:28 +0800 Subject: [PATCH 24/29] feat: add test training --- source/tests/pt/model/test_permutation.py | 43 ++++++++ source/tests/pt/model/water/data/d3/dftd3.txt | 100 ++++++++++++++++++ source/tests/pt/model/water/linear_ener.json | 96 +++++++++++++++++ source/tests/pt/test_training.py | 15 +++ 4 files changed, 254 insertions(+) create mode 100644 source/tests/pt/model/water/data/d3/dftd3.txt create mode 100644 source/tests/pt/model/water/linear_ener.json diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 6aec895041..b0fc8ec759 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -98,6 +98,49 @@ "data_stat_nbatch": 20, } +model_linear = { + "type_map": ["O", "H"], + "type": "linear_ener", + "weights": "sum", + "models":[ + { + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + } + }, + { + "type": "pairtab", + "tab_file": f"{CUR_DIR}/water/data/d3/dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + + ], + + "data_stat_nbatch": 20, +} + model_spin = { "type_map": ["O", "H", "B"], "descriptor": { diff --git a/source/tests/pt/model/water/data/d3/dftd3.txt b/source/tests/pt/model/water/data/d3/dftd3.txt new file mode 100644 index 0000000000..09e5fb697a --- /dev/null +++ b/source/tests/pt/model/water/data/d3/dftd3.txt @@ -0,0 +1,100 @@ +1.000000000000000056e-01 -5.836993924755046366e-03 -3.207255698139210940e-03 -1.843064837882633228e-03 +2.000000000000000111e-01 -5.836993806911452108e-03 -3.207255613696154226e-03 -1.843064776130543892e-03 +3.000000000000000444e-01 -5.836992560106194113e-03 -3.207254720510349828e-03 -1.843064123123401392e-03 +4.000000000000000222e-01 -5.836986225627246658e-03 -3.207250184384043221e-03 -1.843060811677158526e-03 +5.000000000000000000e-01 -5.836964436915091821e-03 -3.207234589497737730e-03 -1.843052788205641135e-03 +5.999999999999999778e-01 -5.836905460107320170e-03 -3.207192410957825698e-03 -1.843338972660025360e-03 +7.000000000000000666e-01 -5.836769626930583300e-03 -3.207096085246822614e-03 -1.851839876215982238e-03 +8.000000000000000444e-01 -5.836491030513121618e-03 -3.206924889333430135e-03 -2.035200426069873857e-03 +9.000000000000000222e-01 -5.835967602710929840e-03 -3.206999537190755728e-03 -3.724418810291191088e-03 +1.000000000000000000e+00 -5.835053775792304297e-03 -3.210477055685919626e-03 -4.311009958284344433e-03 +1.100000000000000089e+00 -5.833591489567684953e-03 -3.237527828601436623e-03 -4.381510573223419171e-03 +1.200000000000000178e+00 -5.831652981781070173e-03 -3.454845258034439960e-03 -4.394419437232751843e-03 +1.300000000000000266e+00 -5.830520601296543433e-03 -4.478070067533340692e-03 -4.394683688871586433e-03 +1.400000000000000133e+00 -5.835353622834494637e-03 -5.097530655625692915e-03 -4.389691198859401421e-03 +1.500000000000000222e+00 -5.863290690264541874e-03 -5.215500241204417201e-03 -4.380686516072217034e-03 +1.600000000000000089e+00 -6.007605076700822840e-03 -5.234994618743306349e-03 -4.367337507268855175e-03 +1.700000000000000178e+00 -6.481613230242359684e-03 -5.228094160806716871e-03 -4.348706108547779198e-03 +1.800000000000000266e+00 -6.814114687600298335e-03 -5.208252365588400719e-03 -4.323505520547227775e-03 +1.900000000000000133e+00 -6.876286379079538276e-03 -5.177988357772074675e-03 -4.290186895355558444e-03 +2.000000000000000000e+00 -6.858440816799354217e-03 -5.136887568332395605e-03 -4.246989919717190920e-03 +2.100000000000000089e+00 -6.810730159155128395e-03 -5.083475665301987606e-03 -4.192000168715152505e-03 +2.200000000000000178e+00 -6.742330737387775344e-03 -5.015815334399144516e-03 -4.123231519970332187e-03 +2.300000000000000266e+00 -6.653841351238824232e-03 -4.931782661310191510e-03 -4.038743210125123918e-03 +2.400000000000000355e+00 -6.543651317938833402e-03 -4.829269294496830317e-03 -3.936795390727530070e-03 +2.500000000000000444e+00 -6.409559281498313811e-03 -4.706385522261587705e-03 -3.816040239463167755e-03 +2.600000000000000089e+00 -6.249406635892575460e-03 -4.561685215972477100e-03 -3.675736338668155346e-03 +2.700000000000000178e+00 -6.061478463281754457e-03 -4.394408172892586353e-03 -3.515962176363645990e-03 +2.800000000000000266e+00 -5.844844934626365965e-03 -4.204716954930251029e-03 -3.337792190764940319e-03 +2.900000000000000355e+00 -5.599669004675433479e-03 -3.993889719587391009e-03 -3.143390268473208755e-03 +3.000000000000000444e+00 -5.327453506642119106e-03 -3.764420755089863558e-03 -2.935977648106832729e-03 +3.100000000000000089e+00 -5.031178000843260223e-03 -3.519982860915751074e-03 -2.719650568099894056e-03 +3.200000000000000178e+00 -4.715273672783852794e-03 -3.265225882759082918e-03 -2.499057451653833965e-03 +3.300000000000000266e+00 -4.385404785641488362e-03 -3.005422601424333727e-03 -2.278985743812388717e-03 +3.400000000000000355e+00 -4.048065433713449700e-03 -2.746015696661484231e-03 -2.063937321866260270e-03 +3.500000000000000444e+00 -3.710048572169818114e-03 -2.492149763588673555e-03 -1.857774171128685628e-03 +3.600000000000000089e+00 -3.377881092113224713e-03 -2.248275746149775312e-03 -1.663491260531681313e-03 +3.700000000000000178e+00 -3.057327225182689644e-03 -2.017890114824574810e-03 -1.483133951195727196e-03 +3.800000000000000266e+00 -2.753038981057491941e-03 -1.803430168074075671e-03 -1.317840750738439540e-03 +3.900000000000000355e+00 -2.468388171389931940e-03 -1.606308000309067743e-03 -1.167971059502070875e-03 +4.000000000000000000e+00 -2.205469013267805957e-03 -1.427041871266797194e-03 -1.033273795673775699e-03 +4.099999999999999645e+00 -1.965228953751702902e-03 -1.265437879541002862e-03 -9.130610310879381641e-04 +4.200000000000000178e+00 -1.747673832278765806e-03 -1.120782158543769547e-03 -8.063636493380576522e-04 +4.299999999999999822e+00 -1.552098284175109895e-03 -9.920168984562682292e-04 -7.120580835032176920e-04 +4.399999999999999467e+00 -1.377305748647780163e-03 -8.778864597897169646e-04 -6.289618864203703032e-04 +4.500000000000000000e+00 -1.221797526507303194e-03 -7.770496638083513111e-04 -5.559009474092405914e-04 +4.599999999999999645e+00 -1.083922782809847944e-03 -6.881603844395511003e-04 -4.917533939693695443e-04 +4.700000000000000178e+00 -9.619897379282633162e-04 -6.099214740721333600e-04 -4.354756390957214944e-04 +4.799999999999999822e+00 -8.543428352989788704e-04 -5.411178648690499965e-04 -3.861155118068372257e-04 +4.900000000000000355e+00 -7.594124385866309881e-04 -4.806343247547230249e-04 -3.428165131289927659e-04 +5.000000000000000000e+00 -6.757436744162991990e-04 -4.274624687438948085e-04 -3.048162971647301774e-04 +5.099999999999999645e+00 -6.020102408497160842e-04 -3.807006248475114439e-04 -2.714416410742632600e-04 +5.200000000000000178e+00 -5.370178955485286568e-04 -3.395492294862310413e-04 -2.421014916366724180e-04 +5.299999999999999822e+00 -4.797012289428498875e-04 -3.033036596191310643e-04 -2.162791601488694472e-04 +5.400000000000000355e+00 -4.291163603974148220e-04 -2.713458112672340397e-04 -1.935243599692976007e-04 +5.500000000000000000e+00 -3.844314156775488251e-04 -2.431352896687036106e-04 -1.734455139070628909e-04 +5.599999999999999645e+00 -3.449160478270333653e-04 -2.182007570692257958e-04 -1.557025751017268144e-04 +5.700000000000000178e+00 -3.099308250478081581e-04 -1.961317615550248216e-04 -1.400004825033046053e-04 +5.799999999999999822e+00 -2.789169965744946232e-04 -1.765712194195623135e-04 -1.260832928664179986e-04 +5.900000000000000355e+00 -2.513869308376957498e-04 -1.592086242469989389e-04 -1.137289820430557137e-04 +6.000000000000000000e+00 -2.269153740910770769e-04 -1.437739928526920498e-04 -1.027448796326863424e-04 +6.099999999999999645e+00 -2.051315821421489645e-04 -1.300325201124615134e-04 -9.296368585686617101e-05 +6.200000000000000178e+00 -1.857123177371916057e-04 -1.177798933810950252e-04 -8.424001307773229020e-05 +6.299999999999999822e+00 -1.683756703844696025e-04 -1.068382068924710331e-04 -7.644739339328402133e-05 +6.400000000000000355e+00 -1.528756359693242027e-04 -9.705241326038571551e-05 -6.947569600606938272e-05 +6.500000000000000000e+00 -1.389973847900246836e-04 -8.828725024164000819e-05 -6.322890211399061677e-05 +6.599999999999999645e+00 -1.265531447216864910e-04 -8.042458445906290783e-05 -5.762318996708282935e-05 +6.700000000000000178e+00 -1.153786284350462083e-04 -7.336111861455703732e-05 -5.258528787829301387e-05 +6.799999999999999822e+00 -1.053299381724164837e-04 -6.700641408290249014e-05 -4.805105801105378322e-05 +6.900000000000000355e+00 -9.628088734156424651e-05 -6.128118618925484863e-05 -4.396427848897285724e-05 +7.000000000000000000e+00 -8.812068437769617318e-05 -5.611583465513190913e-05 -4.027559568117881135e-05 +7.099999999999999645e+00 -8.075193047879847589e-05 -5.144917649553730292e-05 -3.694162237238345173e-05 +7.200000000000000178e+00 -7.408888866698059216e-05 -4.722735299269381154e-05 -3.392416093003276399e-05 +7.299999999999999822e+00 -6.805598702152939358e-05 -4.340288624040664528e-05 -3.118953355495664799e-05 +7.400000000000000355e+00 -6.258652380321327402e-05 -3.993386415929820209e-05 -2.870800428147100793e-05 +7.500000000000000000e+00 -5.762154653038724025e-05 -3.678323585657680581e-05 -2.645327961777798337e-05 +7.599999999999999645e+00 -5.310888089285451013e-05 -3.391820178292012157e-05 -2.440207662848582849e-05 +7.700000000000000178e+00 -4.900228873380196631e-05 -3.130968536501008690e-05 -2.253374889733842244e-05 +7.799999999999999822e+00 -4.526073723647751752e-05 -2.893187470658392955e-05 -2.082996220611386151e-05 +7.900000000000000355e+00 -4.184776396661089387e-05 -2.676182459276817884e-05 -1.927441295794754013e-05 +8.000000000000000000e+00 -3.873092458939377268e-05 -2.477911043795883125e-05 -1.785258338919504348e-05 +8.099999999999999645e+00 -3.588131194417033489e-05 -2.296552701898519263e-05 -1.655152847892165657e-05 +8.199999999999999289e+00 -3.327313676038550535e-05 -2.130482586144845337e-05 -1.535969020138178571e-05 +8.300000000000000711e+00 -3.088336167038252842e-05 -1.978248602307972753e-05 -1.426673539356727330e-05 +8.400000000000000355e+00 -2.869138134992016182e-05 -1.838551376555334571e-05 -1.326341404347894562e-05 +8.500000000000000000e+00 -2.667874262351516647e-05 -1.710226724425689568e-05 -1.234143525923654349e-05 +8.599999999999999645e+00 -2.482889923305170253e-05 -1.592230289025118773e-05 -1.149335856644029766e-05 +8.699999999999999289e+00 -2.312699670543422217e-05 -1.483624062390156494e-05 -1.071249851148625846e-05 +8.800000000000000711e+00 -2.155968338640409072e-05 -1.383564543725925445e-05 -9.992840830439041262e-06 +8.900000000000000355e+00 -2.011494424844594071e-05 -1.291292322228744002e-05 -9.328968683880578804e-06 +9.000000000000000000e+00 -1.878195454422273937e-05 -1.206122901303710759e-05 -8.715997664073560640e-06 +9.099999999999999645e+00 -1.755095077450199208e-05 -1.127438605916122122e-05 -8.149518457037649769e-06 +9.199999999999999289e+00 -1.641311678073074286e-05 -1.054681436190059537e-05 -7.625546193173260002e-06 +9.300000000000000711e+00 -1.536048306550537418e-05 -9.873467487129955082e-06 -7.140475649634374041e-06 +9.400000000000000355e+00 -1.438583769617946587e-05 -9.249776627676899564e-06 -6.691041578929334717e-06 +9.500000000000000000e+00 -1.348264736372320039e-05 -8.671601022702781346e-06 -6.274283533910064576e-06 +9.599999999999999645e+00 -1.264498735578012246e-05 -8.135183958678718279e-06 -5.887514641681122086e-06 +9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06 +9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06 +9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06 +1.000000000000000000e+01 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 diff --git a/source/tests/pt/model/water/linear_ener.json b/source/tests/pt/model/water/linear_ener.json new file mode 100644 index 0000000000..c2d9304a7e --- /dev/null +++ b/source/tests/pt/model/water/linear_ener.json @@ -0,0 +1,96 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "linear_ener", + "weights": "sum", + "type_map": [ + "O", + "H" + ], + "models": [ + { + "descriptor": { + "type": "se_atten", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "type_one_side": true, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + { + "type": "pairtab", + "tab_file": "dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + "_comment10": "that's all" +} diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index fa9e5c138a..5b1b30083d 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -26,6 +26,7 @@ model_hybrid, model_se_e2_a, model_zbl, + model_linear, ) @@ -187,6 +188,20 @@ def setUp(self): def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestLinearEnergyModel(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/linear_energy.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_linear) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" From 1579a7eabb5441e90eb65f35f60483d4722aefef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 05:58:23 +0000 Subject: [PATCH 25/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_permutation.py | 64 +++++++++++------------ source/tests/pt/test_training.py | 4 +- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index b0fc8ec759..99851dba3d 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -102,42 +102,40 @@ "type_map": ["O", "H"], "type": "linear_ener", "weights": "sum", - "models":[ + "models": [ { - "descriptor": { - "type": "se_atten", - "sel": 40, - "rcut_smth": 0.5, - "rcut": 4.0, - "neuron": [25, 50, 100], - "axis_neuron": 16, - "attn": 64, - "attn_layer": 2, - "attn_dotr": True, - "attn_mask": False, - "activation_function": "tanh", - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "set_davg_zero": True, - "type_one_side": True, - "seed": 1, + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + }, + { + "type": "pairtab", + "tab_file": f"{CUR_DIR}/water/data/d3/dftd3.txt", + "rcut": 10.0, + "sel": 534, }, - "fitting_net": { - "neuron": [24, 24, 24], - "resnet_dt": True, - "seed": 1, - } - }, - { - "type": "pairtab", - "tab_file": f"{CUR_DIR}/water/data/d3/dftd3.txt", - "rcut": 10.0, - "sel": 534 - } - ], - "data_stat_nbatch": 20, } diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 5b1b30083d..4eb87633cf 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -24,9 +24,9 @@ model_dpa1, model_dpa2, model_hybrid, + model_linear, model_se_e2_a, model_zbl, - model_linear, ) @@ -188,6 +188,7 @@ def setUp(self): def tearDown(self) -> None: DPTrainTest.tearDown(self) + class TestLinearEnergyModel(unittest.TestCase, DPTrainTest): def setUp(self): input_json = str(Path(__file__).parent / "water/linear_energy.json") @@ -203,6 +204,7 @@ def setUp(self): def tearDown(self) -> None: DPTrainTest.tearDown(self) + class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" From 11350e2f9dc040e6cc7b395bcb7b6d421a144727 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:09:48 +0800 Subject: [PATCH 26/29] fix: revert changes --- source/tests/pt/model/test_permutation.py | 40 ------- source/tests/pt/model/water/data/d3/dftd3.txt | 100 ------------------ source/tests/pt/model/water/linear_ener.json | 96 ----------------- source/tests/pt/test_training.py | 15 --- 4 files changed, 251 deletions(-) delete mode 100644 source/tests/pt/model/water/data/d3/dftd3.txt delete mode 100644 source/tests/pt/model/water/linear_ener.json diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 99851dba3d..2d391c7115 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -98,46 +98,6 @@ "data_stat_nbatch": 20, } -model_linear = { - "type_map": ["O", "H"], - "type": "linear_ener", - "weights": "sum", - "models": [ - { - "descriptor": { - "type": "se_atten", - "sel": 40, - "rcut_smth": 0.5, - "rcut": 4.0, - "neuron": [25, 50, 100], - "axis_neuron": 16, - "attn": 64, - "attn_layer": 2, - "attn_dotr": True, - "attn_mask": False, - "activation_function": "tanh", - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "set_davg_zero": True, - "type_one_side": True, - "seed": 1, - }, - "fitting_net": { - "neuron": [24, 24, 24], - "resnet_dt": True, - "seed": 1, - }, - }, - { - "type": "pairtab", - "tab_file": f"{CUR_DIR}/water/data/d3/dftd3.txt", - "rcut": 10.0, - "sel": 534, - }, - ], - "data_stat_nbatch": 20, -} model_spin = { "type_map": ["O", "H", "B"], diff --git a/source/tests/pt/model/water/data/d3/dftd3.txt b/source/tests/pt/model/water/data/d3/dftd3.txt deleted file mode 100644 index 09e5fb697a..0000000000 --- a/source/tests/pt/model/water/data/d3/dftd3.txt +++ /dev/null @@ -1,100 +0,0 @@ -1.000000000000000056e-01 -5.836993924755046366e-03 -3.207255698139210940e-03 -1.843064837882633228e-03 -2.000000000000000111e-01 -5.836993806911452108e-03 -3.207255613696154226e-03 -1.843064776130543892e-03 -3.000000000000000444e-01 -5.836992560106194113e-03 -3.207254720510349828e-03 -1.843064123123401392e-03 -4.000000000000000222e-01 -5.836986225627246658e-03 -3.207250184384043221e-03 -1.843060811677158526e-03 -5.000000000000000000e-01 -5.836964436915091821e-03 -3.207234589497737730e-03 -1.843052788205641135e-03 -5.999999999999999778e-01 -5.836905460107320170e-03 -3.207192410957825698e-03 -1.843338972660025360e-03 -7.000000000000000666e-01 -5.836769626930583300e-03 -3.207096085246822614e-03 -1.851839876215982238e-03 -8.000000000000000444e-01 -5.836491030513121618e-03 -3.206924889333430135e-03 -2.035200426069873857e-03 -9.000000000000000222e-01 -5.835967602710929840e-03 -3.206999537190755728e-03 -3.724418810291191088e-03 -1.000000000000000000e+00 -5.835053775792304297e-03 -3.210477055685919626e-03 -4.311009958284344433e-03 -1.100000000000000089e+00 -5.833591489567684953e-03 -3.237527828601436623e-03 -4.381510573223419171e-03 -1.200000000000000178e+00 -5.831652981781070173e-03 -3.454845258034439960e-03 -4.394419437232751843e-03 -1.300000000000000266e+00 -5.830520601296543433e-03 -4.478070067533340692e-03 -4.394683688871586433e-03 -1.400000000000000133e+00 -5.835353622834494637e-03 -5.097530655625692915e-03 -4.389691198859401421e-03 -1.500000000000000222e+00 -5.863290690264541874e-03 -5.215500241204417201e-03 -4.380686516072217034e-03 -1.600000000000000089e+00 -6.007605076700822840e-03 -5.234994618743306349e-03 -4.367337507268855175e-03 -1.700000000000000178e+00 -6.481613230242359684e-03 -5.228094160806716871e-03 -4.348706108547779198e-03 -1.800000000000000266e+00 -6.814114687600298335e-03 -5.208252365588400719e-03 -4.323505520547227775e-03 -1.900000000000000133e+00 -6.876286379079538276e-03 -5.177988357772074675e-03 -4.290186895355558444e-03 -2.000000000000000000e+00 -6.858440816799354217e-03 -5.136887568332395605e-03 -4.246989919717190920e-03 -2.100000000000000089e+00 -6.810730159155128395e-03 -5.083475665301987606e-03 -4.192000168715152505e-03 -2.200000000000000178e+00 -6.742330737387775344e-03 -5.015815334399144516e-03 -4.123231519970332187e-03 -2.300000000000000266e+00 -6.653841351238824232e-03 -4.931782661310191510e-03 -4.038743210125123918e-03 -2.400000000000000355e+00 -6.543651317938833402e-03 -4.829269294496830317e-03 -3.936795390727530070e-03 -2.500000000000000444e+00 -6.409559281498313811e-03 -4.706385522261587705e-03 -3.816040239463167755e-03 -2.600000000000000089e+00 -6.249406635892575460e-03 -4.561685215972477100e-03 -3.675736338668155346e-03 -2.700000000000000178e+00 -6.061478463281754457e-03 -4.394408172892586353e-03 -3.515962176363645990e-03 -2.800000000000000266e+00 -5.844844934626365965e-03 -4.204716954930251029e-03 -3.337792190764940319e-03 -2.900000000000000355e+00 -5.599669004675433479e-03 -3.993889719587391009e-03 -3.143390268473208755e-03 -3.000000000000000444e+00 -5.327453506642119106e-03 -3.764420755089863558e-03 -2.935977648106832729e-03 -3.100000000000000089e+00 -5.031178000843260223e-03 -3.519982860915751074e-03 -2.719650568099894056e-03 -3.200000000000000178e+00 -4.715273672783852794e-03 -3.265225882759082918e-03 -2.499057451653833965e-03 -3.300000000000000266e+00 -4.385404785641488362e-03 -3.005422601424333727e-03 -2.278985743812388717e-03 -3.400000000000000355e+00 -4.048065433713449700e-03 -2.746015696661484231e-03 -2.063937321866260270e-03 -3.500000000000000444e+00 -3.710048572169818114e-03 -2.492149763588673555e-03 -1.857774171128685628e-03 -3.600000000000000089e+00 -3.377881092113224713e-03 -2.248275746149775312e-03 -1.663491260531681313e-03 -3.700000000000000178e+00 -3.057327225182689644e-03 -2.017890114824574810e-03 -1.483133951195727196e-03 -3.800000000000000266e+00 -2.753038981057491941e-03 -1.803430168074075671e-03 -1.317840750738439540e-03 -3.900000000000000355e+00 -2.468388171389931940e-03 -1.606308000309067743e-03 -1.167971059502070875e-03 -4.000000000000000000e+00 -2.205469013267805957e-03 -1.427041871266797194e-03 -1.033273795673775699e-03 -4.099999999999999645e+00 -1.965228953751702902e-03 -1.265437879541002862e-03 -9.130610310879381641e-04 -4.200000000000000178e+00 -1.747673832278765806e-03 -1.120782158543769547e-03 -8.063636493380576522e-04 -4.299999999999999822e+00 -1.552098284175109895e-03 -9.920168984562682292e-04 -7.120580835032176920e-04 -4.399999999999999467e+00 -1.377305748647780163e-03 -8.778864597897169646e-04 -6.289618864203703032e-04 -4.500000000000000000e+00 -1.221797526507303194e-03 -7.770496638083513111e-04 -5.559009474092405914e-04 -4.599999999999999645e+00 -1.083922782809847944e-03 -6.881603844395511003e-04 -4.917533939693695443e-04 -4.700000000000000178e+00 -9.619897379282633162e-04 -6.099214740721333600e-04 -4.354756390957214944e-04 -4.799999999999999822e+00 -8.543428352989788704e-04 -5.411178648690499965e-04 -3.861155118068372257e-04 -4.900000000000000355e+00 -7.594124385866309881e-04 -4.806343247547230249e-04 -3.428165131289927659e-04 -5.000000000000000000e+00 -6.757436744162991990e-04 -4.274624687438948085e-04 -3.048162971647301774e-04 -5.099999999999999645e+00 -6.020102408497160842e-04 -3.807006248475114439e-04 -2.714416410742632600e-04 -5.200000000000000178e+00 -5.370178955485286568e-04 -3.395492294862310413e-04 -2.421014916366724180e-04 -5.299999999999999822e+00 -4.797012289428498875e-04 -3.033036596191310643e-04 -2.162791601488694472e-04 -5.400000000000000355e+00 -4.291163603974148220e-04 -2.713458112672340397e-04 -1.935243599692976007e-04 -5.500000000000000000e+00 -3.844314156775488251e-04 -2.431352896687036106e-04 -1.734455139070628909e-04 -5.599999999999999645e+00 -3.449160478270333653e-04 -2.182007570692257958e-04 -1.557025751017268144e-04 -5.700000000000000178e+00 -3.099308250478081581e-04 -1.961317615550248216e-04 -1.400004825033046053e-04 -5.799999999999999822e+00 -2.789169965744946232e-04 -1.765712194195623135e-04 -1.260832928664179986e-04 -5.900000000000000355e+00 -2.513869308376957498e-04 -1.592086242469989389e-04 -1.137289820430557137e-04 -6.000000000000000000e+00 -2.269153740910770769e-04 -1.437739928526920498e-04 -1.027448796326863424e-04 -6.099999999999999645e+00 -2.051315821421489645e-04 -1.300325201124615134e-04 -9.296368585686617101e-05 -6.200000000000000178e+00 -1.857123177371916057e-04 -1.177798933810950252e-04 -8.424001307773229020e-05 -6.299999999999999822e+00 -1.683756703844696025e-04 -1.068382068924710331e-04 -7.644739339328402133e-05 -6.400000000000000355e+00 -1.528756359693242027e-04 -9.705241326038571551e-05 -6.947569600606938272e-05 -6.500000000000000000e+00 -1.389973847900246836e-04 -8.828725024164000819e-05 -6.322890211399061677e-05 -6.599999999999999645e+00 -1.265531447216864910e-04 -8.042458445906290783e-05 -5.762318996708282935e-05 -6.700000000000000178e+00 -1.153786284350462083e-04 -7.336111861455703732e-05 -5.258528787829301387e-05 -6.799999999999999822e+00 -1.053299381724164837e-04 -6.700641408290249014e-05 -4.805105801105378322e-05 -6.900000000000000355e+00 -9.628088734156424651e-05 -6.128118618925484863e-05 -4.396427848897285724e-05 -7.000000000000000000e+00 -8.812068437769617318e-05 -5.611583465513190913e-05 -4.027559568117881135e-05 -7.099999999999999645e+00 -8.075193047879847589e-05 -5.144917649553730292e-05 -3.694162237238345173e-05 -7.200000000000000178e+00 -7.408888866698059216e-05 -4.722735299269381154e-05 -3.392416093003276399e-05 -7.299999999999999822e+00 -6.805598702152939358e-05 -4.340288624040664528e-05 -3.118953355495664799e-05 -7.400000000000000355e+00 -6.258652380321327402e-05 -3.993386415929820209e-05 -2.870800428147100793e-05 -7.500000000000000000e+00 -5.762154653038724025e-05 -3.678323585657680581e-05 -2.645327961777798337e-05 -7.599999999999999645e+00 -5.310888089285451013e-05 -3.391820178292012157e-05 -2.440207662848582849e-05 -7.700000000000000178e+00 -4.900228873380196631e-05 -3.130968536501008690e-05 -2.253374889733842244e-05 -7.799999999999999822e+00 -4.526073723647751752e-05 -2.893187470658392955e-05 -2.082996220611386151e-05 -7.900000000000000355e+00 -4.184776396661089387e-05 -2.676182459276817884e-05 -1.927441295794754013e-05 -8.000000000000000000e+00 -3.873092458939377268e-05 -2.477911043795883125e-05 -1.785258338919504348e-05 -8.099999999999999645e+00 -3.588131194417033489e-05 -2.296552701898519263e-05 -1.655152847892165657e-05 -8.199999999999999289e+00 -3.327313676038550535e-05 -2.130482586144845337e-05 -1.535969020138178571e-05 -8.300000000000000711e+00 -3.088336167038252842e-05 -1.978248602307972753e-05 -1.426673539356727330e-05 -8.400000000000000355e+00 -2.869138134992016182e-05 -1.838551376555334571e-05 -1.326341404347894562e-05 -8.500000000000000000e+00 -2.667874262351516647e-05 -1.710226724425689568e-05 -1.234143525923654349e-05 -8.599999999999999645e+00 -2.482889923305170253e-05 -1.592230289025118773e-05 -1.149335856644029766e-05 -8.699999999999999289e+00 -2.312699670543422217e-05 -1.483624062390156494e-05 -1.071249851148625846e-05 -8.800000000000000711e+00 -2.155968338640409072e-05 -1.383564543725925445e-05 -9.992840830439041262e-06 -8.900000000000000355e+00 -2.011494424844594071e-05 -1.291292322228744002e-05 -9.328968683880578804e-06 -9.000000000000000000e+00 -1.878195454422273937e-05 -1.206122901303710759e-05 -8.715997664073560640e-06 -9.099999999999999645e+00 -1.755095077450199208e-05 -1.127438605916122122e-05 -8.149518457037649769e-06 -9.199999999999999289e+00 -1.641311678073074286e-05 -1.054681436190059537e-05 -7.625546193173260002e-06 -9.300000000000000711e+00 -1.536048306550537418e-05 -9.873467487129955082e-06 -7.140475649634374041e-06 -9.400000000000000355e+00 -1.438583769617946587e-05 -9.249776627676899564e-06 -6.691041578929334717e-06 -9.500000000000000000e+00 -1.348264736372320039e-05 -8.671601022702781346e-06 -6.274283533910064576e-06 -9.599999999999999645e+00 -1.264498735578012246e-05 -8.135183958678718279e-06 -5.887514641681122086e-06 -9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06 -9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06 -9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06 -1.000000000000000000e+01 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 diff --git a/source/tests/pt/model/water/linear_ener.json b/source/tests/pt/model/water/linear_ener.json deleted file mode 100644 index c2d9304a7e..0000000000 --- a/source/tests/pt/model/water/linear_ener.json +++ /dev/null @@ -1,96 +0,0 @@ -{ - "_comment1": " model parameters", - "model": { - "type": "linear_ener", - "weights": "sum", - "type_map": [ - "O", - "H" - ], - "models": [ - { - "descriptor": { - "type": "se_atten", - "sel": [ - 46, - 92 - ], - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], - "resnet_dt": false, - "axis_neuron": 16, - "type_one_side": true, - "precision": "float64", - "seed": 1, - "_comment2": " that's all" - }, - "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], - "resnet_dt": true, - "precision": "float64", - "seed": 1, - "_comment3": " that's all" - }, - "_comment4": " that's all" - }, - { - "type": "pairtab", - "tab_file": "dftd3.txt", - "rcut": 10.0, - "sel": 534 - } - ] - }, - "learning_rate": { - "type": "exp", - "decay_steps": 5000, - "start_lr": 0.001, - "stop_lr": 3.51e-8, - "_comment5": "that's all" - }, - "loss": { - "type": "ener", - "start_pref_e": 0.02, - "limit_pref_e": 1, - "start_pref_f": 1000, - "limit_pref_f": 1, - "start_pref_v": 0, - "limit_pref_v": 0, - "_comment6": " that's all" - }, - "training": { - "training_data": { - "systems": [ - "../data/data_0/", - "../data/data_1/", - "../data/data_2/" - ], - "batch_size": "auto", - "_comment7": "that's all" - }, - "validation_data": { - "systems": [ - "../data/data_3" - ], - "batch_size": 1, - "numb_btch": 3, - "_comment8": "that's all" - }, - "numb_steps": 1000000, - "seed": 10, - "disp_file": "lcurve.out", - "disp_freq": 100, - "save_freq": 1000, - "_comment9": "that's all" - }, - "_comment10": "that's all" -} diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 4eb87633cf..d41b21f442 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -189,21 +189,6 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) -class TestLinearEnergyModel(unittest.TestCase, DPTrainTest): - def setUp(self): - input_json = str(Path(__file__).parent / "water/linear_energy.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_linear) - self.config["training"]["numb_steps"] = 1 - self.config["training"]["save_freq"] = 1 - - def tearDown(self) -> None: - DPTrainTest.tearDown(self) - class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" From 5b5e9489325d6d7f6e1ee124f466d81ab6924704 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 06:12:12 +0000 Subject: [PATCH 27/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_training.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index d41b21f442..fa9e5c138a 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -24,7 +24,6 @@ model_dpa1, model_dpa2, model_hybrid, - model_linear, model_se_e2_a, model_zbl, ) @@ -189,7 +188,6 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) - class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" From 576c2898c5a89b3493bd940e1a998a3bcf755f80 Mon Sep 17 00:00:00 2001 From: anyangml Date: Thu, 10 Oct 2024 09:49:19 +0800 Subject: [PATCH 28/29] fix: update zbl example descriptor --- examples/water/zbl/input.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/water/zbl/input.json b/examples/water/zbl/input.json index 1c951f3de5..54586ca0cf 100644 --- a/examples/water/zbl/input.json +++ b/examples/water/zbl/input.json @@ -10,7 +10,7 @@ "H" ], "descriptor": { - "type": "se_atten", + "type": "se_atten_v2", "sel": [ 46, 92 From d3b3342e195a0f26db09a2a41a59ee5e71d97406 Mon Sep 17 00:00:00 2001 From: anyangml Date: Thu, 10 Oct 2024 16:52:54 +0800 Subject: [PATCH 29/29] feat: add linear example --- examples/water/linear/input_pt.json | 124 +++++++++++++++++++++++++++ source/tests/common/test_examples.py | 1 + 2 files changed, 125 insertions(+) create mode 100644 examples/water/linear/input_pt.json diff --git a/examples/water/linear/input_pt.json b/examples/water/linear/input_pt.json new file mode 100644 index 0000000000..e8d8e07136 --- /dev/null +++ b/examples/water/linear/input_pt.json @@ -0,0 +1,124 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "linear_ener", + "weights": "sum", + "type_map": [ + "O", + "H" + ], + "models": [ + { + "descriptor": { + "type": "se_atten", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "type_one_side": true, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + { + "descriptor": { + "type": "se_atten", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "type_one_side": true, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + } + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + "_comment10": "that's all" +} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index cc2a7ad487..246e767f01 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -36,6 +36,7 @@ p_examples / "water" / "dplr" / "train" / "ener.json", p_examples / "water" / "d3" / "input_pt.json", p_examples / "water" / "linear" / "input.json", + p_examples / "water" / "linear" / "input_pt.json", p_examples / "nopbc" / "train" / "input.json", p_examples / "water_tensor" / "dipole" / "dipole_input.json", p_examples / "water_tensor" / "polar" / "polar_input.json",