diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 56af5f4f43..a2cbef3eee 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -363,21 +363,25 @@ def compute_or_load_stat( self, merged: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, + compute_or_load_out_stat: bool = True, ) -> NoReturn: """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. Parameters ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. + merged + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ raise NotImplementedError diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 5fa787b17b..62c7d78d75 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -285,6 +285,7 @@ def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, + compute_or_load_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -300,6 +301,9 @@ def compute_or_load_stat( The lazy sampled function to get data frames from different data systems. stat_file_path The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ if stat_file_path is not None and self.type_map is not None: # descriptors and fitting net with different type_map @@ -323,7 +327,8 @@ def wrapped_sampler(): self.fitting_net.compute_input_stats( wrapped_sampler, protection=self.data_stat_protect ) - self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 3d894dc3a0..9b20d80516 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools from typing import ( - Callable, Optional, Union, ) @@ -319,6 +319,10 @@ def apply_out_stat( The atom types. nf x nloc """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + out_bias[kk][atype] return ret @staticmethod @@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool: """ return False - def compute_or_load_out_stat( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - stat_file_path: Optional[DPPath] = None, - ) -> None: - """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. - - """ - for md in self.models: - md.compute_or_load_out_stat(merged, stat_file_path) - def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, + compute_or_load_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -507,9 +488,34 @@ def compute_or_load_stat( The lazy sampled function to get data frames from different data systems. stat_file_path The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ for md in self.models: - md.compute_or_load_stat(sampled_func, stat_file_path) + md.compute_or_load_stat( + sampled_func, stat_file_path, compute_or_load_out_stat=False + ) + + if stat_file_path is not None and self.type_map is not None: + # descriptors and fitting net with different type_map + # should not share the same parameters + stat_file_path /= " ".join(self.type_map) + + @functools.lru_cache + def wrapped_sampler(): + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + return sampled + + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 62b47afb32..8f73d81d76 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -224,26 +224,31 @@ def deserialize(cls, data) -> "PairTabAtomicModel": def compute_or_load_stat( self, - merged: Union[Callable[[], list[dict]], list[dict]], + sampled_func: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, + compute_or_load_out_stat: bool = True, ) -> None: """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. Parameters ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. + sampled_func + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ - self.compute_or_load_out_stat(merged, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(sampled_func, stat_file_path) def forward_atomic( self, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index c5dbdfd9dd..9b2b896d02 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -510,15 +510,31 @@ def collect_single_finetune_params( if i != "_extra_state" and f".{_model_key}." in i ] for item_key in target_keys: - if _new_fitting and (".descriptor." not in item_key): + new_key = item_key.replace( + f".{_model_key}.", f".{_model_key_from}." + ) + use_random_initialization = _new_fitting and ( + ".descriptor." not in item_key + ) + if ( + not use_random_initialization + and new_key not in _origin_state_dict + ): + # for ZBL models finetuning from standard models + if ".models.0." in new_key: + new_key = new_key.replace(".models.0.", ".") + elif ".models.1." in new_key: + use_random_initialization = True + else: + raise KeyError( + f"Key {new_key} not found in pretrained model." + ) + if use_random_initialization: # print(f'Keep {item_key} in old model!') _new_state_dict[item_key] = ( _random_state_dict[item_key].clone().detach() ) else: - new_key = item_key.replace( - f".{_model_key}.", f".{_model_key_from}." - ) # print(f'Replace {item_key} with {new_key} in pretrained_model!') _new_state_dict[item_key] = ( _origin_state_dict[new_key].clone().detach() diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py index 90758526b9..0f92f1253d 100644 --- a/source/tests/pt/model/test_linear_atomic_model_stat.py +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -233,16 +233,11 @@ def test_linear_atomic_model_stat_with_bias(self) -> None: linear_model.compute_or_load_out_stat( self.merged_output_stat, stat_file_path=self.stat_file_path ) - # bias applied to sub atomic models. ener_bias = np.array([1.0, 3.0]).reshape(2, 1) - linear_ret = [] - for idx, md in enumerate(linear_model.models): - ret = md.forward_common_atomic(*args) - ret = to_numpy_array(ret["energy"]) - linear_ret.append(ret_no_bias[idx] + ener_bias[at]) - np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret) + ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"]) + np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret) # linear model not adding bias again ret1 = linear_model.forward_common_atomic(*args) ret1 = to_numpy_array(ret1["energy"]) - np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1) + np.testing.assert_almost_equal(ret, ret1) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 3df95e4b14..c57c896197 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -30,6 +30,8 @@ class DPTrainTest: + test_zbl_from_standard: bool = False + def test_dp_train(self) -> None: # test training from scratch trainer = get_trainer(deepcopy(self.config)) @@ -95,6 +97,34 @@ def test_dp_train(self) -> None: state_dict_finetuned_random[state_key], ) + if self.test_zbl_from_standard: + # test fine-tuning using zbl from standard model + finetune_model = ( + self.config["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + self.config_zbl["model"], finetune_links = get_finetune_rules( + finetune_model, + self.config_zbl["model"], + ) + trainer_finetune_zbl = get_trainer( + deepcopy(self.config_zbl), + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict() + for state_key in state_dict_finetuned_zbl: + if "out_bias" not in state_key and "out_std" not in state_key: + original_key = state_key + if ".models.0." in state_key: + original_key = state_key.replace(".models.0.", ".") + if ".models.1." not in state_key: + torch.testing.assert_close( + state_dict_trained[original_key], + state_dict_finetuned_zbl[state_key], + ) + # check running + trainer_finetune_zbl.run() + # check running trainer_finetune.run() trainer_finetune_empty.run() @@ -222,6 +252,18 @@ def setUp(self) -> None: self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + self.test_zbl_from_standard = True + + input_json_zbl = str(Path(__file__).parent / "water/zbl.json") + with open(input_json_zbl) as f: + self.config_zbl = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config_zbl["training"]["training_data"]["systems"] = data_file + self.config_zbl["training"]["validation_data"]["systems"] = data_file + self.config_zbl["model"] = deepcopy(model_zbl) + self.config_zbl["training"]["numb_steps"] = 1 + self.config_zbl["training"]["save_freq"] = 1 + def tearDown(self) -> None: DPTrainTest.tearDown(self)