From e446004f6384776cd9c0dbf8ef445124ddcf8f35 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:53:17 +0800 Subject: [PATCH 1/8] Fix state dict key handling for ZBL models Updated the logic for transferring state dict items to correctly handle keys related to ZBL models by replacing '.models.0.' with '.' and ensuring '.models.1.' items are retained. This improves compatibility when loading pretrained models with different model key structures. --- deepmd/pt/train/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 193dcd8cb9..0c6245a080 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -509,7 +509,9 @@ def collect_single_finetune_params( if i != "_extra_state" and f".{_model_key}." in i ] for item_key in target_keys: - if _new_fitting and (".descriptor." not in item_key): + if ( + _new_fitting and (".descriptor." not in item_key) + ) or ".models.1." in item_key: # print(f'Keep {item_key} in old model!') _new_state_dict[item_key] = ( _random_state_dict[item_key].clone().detach() @@ -517,7 +519,7 @@ def collect_single_finetune_params( else: new_key = item_key.replace( f".{_model_key}.", f".{_model_key_from}." - ) + ).replace(".models.0.", ".") # for ZBL models # print(f'Replace {item_key} with {new_key} in pretrained_model!') _new_state_dict[item_key] = ( _origin_state_dict[new_key].clone().detach() From 9c90d9575659c866c43d0193ecb28183fc74bc22 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 28 Jul 2025 16:00:09 +0800 Subject: [PATCH 2/8] Add compute_out_stat flag to stat computation methods Introduces a compute_out_stat parameter to compute_or_load_stat methods in BaseAtomicModel, DPAtomicModel, LinearEnergyAtomicModel, and PairTabAtomicModel. This allows conditional computation of output statistics, improving flexibility and control over the statistics computation process. --- .../model/atomic_model/base_atomic_model.py | 1 + .../pt/model/atomic_model/dp_atomic_model.py | 4 +- .../model/atomic_model/linear_atomic_model.py | 55 ++++++++++--------- .../atomic_model/pairtab_atomic_model.py | 4 +- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 56af5f4f43..f1802eb0f6 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -363,6 +363,7 @@ def compute_or_load_stat( self, merged: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, + compute_out_stat: bool = True, ) -> NoReturn: """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 5fa787b17b..0fe9fdae74 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -285,6 +285,7 @@ def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, + compute_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -323,7 +324,8 @@ def wrapped_sampler(): self.fitting_net.compute_input_stats( wrapped_sampler, protection=self.data_stat_protect ) - self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + if compute_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 3d894dc3a0..7bbb4d410c 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools from typing import ( - Callable, Optional, Union, ) @@ -319,6 +319,10 @@ def apply_out_stat( The atom types. nf x nloc """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + out_bias[kk][atype] return ret @staticmethod @@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool: """ return False - def compute_or_load_out_stat( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - stat_file_path: Optional[DPPath] = None, - ) -> None: - """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. - - """ - for md in self.models: - md.compute_or_load_out_stat(merged, stat_file_path) - def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, + compute_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -509,7 +490,29 @@ def compute_or_load_stat( The dictionary of paths to the statistics files. """ for md in self.models: - md.compute_or_load_stat(sampled_func, stat_file_path) + md.compute_or_load_stat( + sampled_func, stat_file_path, compute_out_stat=False + ) + + if stat_file_path is not None and self.type_map is not None: + # descriptors and fitting net with different type_map + # should not share the same parameters + stat_file_path /= " ".join(self.type_map) + + @functools.lru_cache + def wrapped_sampler(): + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + return sampled + + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 62b47afb32..55b7e43917 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -226,6 +226,7 @@ def compute_or_load_stat( self, merged: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, + compute_out_stat: bool = True, ) -> None: """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. @@ -243,7 +244,8 @@ def compute_or_load_stat( The path to the stat file. """ - self.compute_or_load_out_stat(merged, stat_file_path) + if compute_out_stat: + self.compute_or_load_out_stat(merged, stat_file_path) def forward_atomic( self, From f461f436e0eee358a3096fdb56330c663f4d8caf Mon Sep 17 00:00:00 2001 From: anyangml Date: Mon, 28 Jul 2025 10:02:42 +0000 Subject: [PATCH 3/8] fix: UT --- .../tests/pt/model/test_linear_atomic_model_stat.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py index 90758526b9..180e4d08ae 100644 --- a/source/tests/pt/model/test_linear_atomic_model_stat.py +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -233,16 +233,12 @@ def test_linear_atomic_model_stat_with_bias(self) -> None: linear_model.compute_or_load_out_stat( self.merged_output_stat, stat_file_path=self.stat_file_path ) - # bias applied to sub atomic models. ener_bias = np.array([1.0, 3.0]).reshape(2, 1) - linear_ret = [] - for idx, md in enumerate(linear_model.models): - ret = md.forward_common_atomic(*args) - ret = to_numpy_array(ret["energy"]) - linear_ret.append(ret_no_bias[idx] + ener_bias[at]) - np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret) + ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"]) + np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret) + # linear model not adding bias again ret1 = linear_model.forward_common_atomic(*args) ret1 = to_numpy_array(ret1["energy"]) - np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1) + np.testing.assert_almost_equal(ret, ret1) From 2fdf1c980714aeef414a5c95af63208f6f786b51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Jul 2025 10:04:22 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_linear_atomic_model_stat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py index 180e4d08ae..0f92f1253d 100644 --- a/source/tests/pt/model/test_linear_atomic_model_stat.py +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -236,7 +236,6 @@ def test_linear_atomic_model_stat_with_bias(self) -> None: ener_bias = np.array([1.0, 3.0]).reshape(2, 1) ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"]) np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret) - # linear model not adding bias again ret1 = linear_model.forward_common_atomic(*args) From a735ca74e5fd6c662d23a67dc77a745cd1fdd3e4 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:25:57 +0800 Subject: [PATCH 5/8] Add ZBL model fine-tuning from standard models Enhanced Trainer to support fine-tuning ZBL models from standard models by handling key mapping and random state initialization. Added corresponding tests to verify ZBL fine-tuning behavior and ensure correct state dict transfer in test_training.py. --- deepmd/pt/train/training.py | 24 ++++++++++++++---- source/tests/pt/test_training.py | 42 ++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 484ba6ef16..59b5c7a7e5 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -510,17 +510,31 @@ def collect_single_finetune_params( if i != "_extra_state" and f".{_model_key}." in i ] for item_key in target_keys: + new_key = item_key.replace( + f".{_model_key}.", f".{_model_key_from}." + ) + use_random_state = _new_fitting and ( + ".descriptor." not in item_key + ) if ( - _new_fitting and (".descriptor." not in item_key) - ) or ".models.1." in item_key: + not use_random_state + and new_key not in _origin_state_dict + ): + # for ZBL models finetuning from standard models + if ".models.0." in new_key: + new_key = new_key.replace(".models.0.", ".") + elif ".models.1." in new_key: + use_random_state = True + else: + raise KeyError( + f"Key {new_key} not found in pretrained model." + ) + if use_random_state: # print(f'Keep {item_key} in old model!') _new_state_dict[item_key] = ( _random_state_dict[item_key].clone().detach() ) else: - new_key = item_key.replace( - f".{_model_key}.", f".{_model_key_from}." - ).replace(".models.0.", ".") # for ZBL models # print(f'Replace {item_key} with {new_key} in pretrained_model!') _new_state_dict[item_key] = ( _origin_state_dict[new_key].clone().detach() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 3df95e4b14..c57c896197 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -30,6 +30,8 @@ class DPTrainTest: + test_zbl_from_standard: bool = False + def test_dp_train(self) -> None: # test training from scratch trainer = get_trainer(deepcopy(self.config)) @@ -95,6 +97,34 @@ def test_dp_train(self) -> None: state_dict_finetuned_random[state_key], ) + if self.test_zbl_from_standard: + # test fine-tuning using zbl from standard model + finetune_model = ( + self.config["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + self.config_zbl["model"], finetune_links = get_finetune_rules( + finetune_model, + self.config_zbl["model"], + ) + trainer_finetune_zbl = get_trainer( + deepcopy(self.config_zbl), + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict() + for state_key in state_dict_finetuned_zbl: + if "out_bias" not in state_key and "out_std" not in state_key: + original_key = state_key + if ".models.0." in state_key: + original_key = state_key.replace(".models.0.", ".") + if ".models.1." not in state_key: + torch.testing.assert_close( + state_dict_trained[original_key], + state_dict_finetuned_zbl[state_key], + ) + # check running + trainer_finetune_zbl.run() + # check running trainer_finetune.run() trainer_finetune_empty.run() @@ -222,6 +252,18 @@ def setUp(self) -> None: self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + self.test_zbl_from_standard = True + + input_json_zbl = str(Path(__file__).parent / "water/zbl.json") + with open(input_json_zbl) as f: + self.config_zbl = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config_zbl["training"]["training_data"]["systems"] = data_file + self.config_zbl["training"]["validation_data"]["systems"] = data_file + self.config_zbl["model"] = deepcopy(model_zbl) + self.config_zbl["training"]["numb_steps"] = 1 + self.config_zbl["training"]["save_freq"] = 1 + def tearDown(self) -> None: DPTrainTest.tearDown(self) From 777042802dac79d913dbb6d78cab5276c1924ec8 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:52:39 +0800 Subject: [PATCH 6/8] Update training.py --- deepmd/pt/train/training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 59b5c7a7e5..9b2b896d02 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -513,23 +513,23 @@ def collect_single_finetune_params( new_key = item_key.replace( f".{_model_key}.", f".{_model_key_from}." ) - use_random_state = _new_fitting and ( + use_random_initialization = _new_fitting and ( ".descriptor." not in item_key ) if ( - not use_random_state + not use_random_initialization and new_key not in _origin_state_dict ): # for ZBL models finetuning from standard models if ".models.0." in new_key: new_key = new_key.replace(".models.0.", ".") elif ".models.1." in new_key: - use_random_state = True + use_random_initialization = True else: raise KeyError( f"Key {new_key} not found in pretrained model." ) - if use_random_state: + if use_random_initialization: # print(f'Keep {item_key} in old model!') _new_state_dict[item_key] = ( _random_state_dict[item_key].clone().detach() From 2bb190f0dd934475148c0c4118140b6d2d934892 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Aug 2025 13:25:53 +0800 Subject: [PATCH 7/8] fix docstr --- deepmd/pt/model/atomic_model/base_atomic_model.py | 7 +++++-- deepmd/pt/model/atomic_model/dp_atomic_model.py | 7 +++++-- deepmd/pt/model/atomic_model/linear_atomic_model.py | 7 +++++-- deepmd/pt/model/atomic_model/pairtab_atomic_model.py | 9 ++++++--- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index f1802eb0f6..8099470ad2 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -363,10 +363,10 @@ def compute_or_load_stat( self, merged: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, - compute_out_stat: bool = True, + compute_or_load_out_stat: bool = True, ) -> NoReturn: """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + Compute the input and output statistics (e.g. energy bias) for the model from packed data. Parameters ---------- @@ -379,6 +379,9 @@ def compute_or_load_stat( the lazy function helps by only sampling once. stat_file_path : Optional[DPPath] The path to the stat file. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ raise NotImplementedError diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 0fe9fdae74..62c7d78d75 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -285,7 +285,7 @@ def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, - compute_out_stat: bool = True, + compute_or_load_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -301,6 +301,9 @@ def compute_or_load_stat( The lazy sampled function to get data frames from different data systems. stat_file_path The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ if stat_file_path is not None and self.type_map is not None: # descriptors and fitting net with different type_map @@ -324,7 +327,7 @@ def wrapped_sampler(): self.fitting_net.compute_input_stats( wrapped_sampler, protection=self.data_stat_protect ) - if compute_out_stat: + if compute_or_load_out_stat: self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 7bbb4d410c..9b20d80516 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -472,7 +472,7 @@ def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, - compute_out_stat: bool = True, + compute_or_load_out_stat: bool = True, ) -> None: """ Compute or load the statistics parameters of the model, @@ -488,10 +488,13 @@ def compute_or_load_stat( The lazy sampled function to get data frames from different data systems. stat_file_path The dictionary of paths to the statistics files. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ for md in self.models: md.compute_or_load_stat( - sampled_func, stat_file_path, compute_out_stat=False + sampled_func, stat_file_path, compute_or_load_out_stat=False ) if stat_file_path is not None and self.type_map is not None: diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 55b7e43917..b1826e10f3 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -226,10 +226,10 @@ def compute_or_load_stat( self, merged: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, - compute_out_stat: bool = True, + compute_or_load_out_stat: bool = True, ) -> None: """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + Compute the input and output statistics (e.g. energy bias) for the model from packed data. Parameters ---------- @@ -242,9 +242,12 @@ def compute_or_load_stat( the lazy function helps by only sampling once. stat_file_path : Optional[DPPath] The path to the stat file. + compute_or_load_out_stat : bool + Whether to compute the output statistics. + If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ - if compute_out_stat: + if compute_or_load_out_stat: self.compute_or_load_out_stat(merged, stat_file_path) def forward_atomic( From 404b915d8211e35e3043c3f976eeaaff66cf69e6 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Aug 2025 13:32:10 +0800 Subject: [PATCH 8/8] Update docstrings for compute_or_load_stat methods Revised and clarified the docstrings for compute_or_load_stat in both BaseAtomicModel and PairTabAtomicModel to better describe the function parameters and behavior. Updated parameter names and descriptions for improved consistency and readability. --- .../model/atomic_model/base_atomic_model.py | 20 ++++++++-------- .../atomic_model/pairtab_atomic_model.py | 24 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 8099470ad2..a2cbef3eee 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -366,19 +366,19 @@ def compute_or_load_stat( compute_or_load_out_stat: bool = True, ) -> NoReturn: """ - Compute the input and output statistics (e.g. energy bias) for the model from packed data. + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. Parameters ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. + merged + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. compute_or_load_out_stat : bool Whether to compute the output statistics. If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index b1826e10f3..8f73d81d76 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -224,31 +224,31 @@ def deserialize(cls, data) -> "PairTabAtomicModel": def compute_or_load_stat( self, - merged: Union[Callable[[], list[dict]], list[dict]], + sampled_func: Union[Callable[[], list[dict]], list[dict]], stat_file_path: Optional[DPPath] = None, compute_or_load_out_stat: bool = True, ) -> None: """ - Compute the input and output statistics (e.g. energy bias) for the model from packed data. + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. Parameters ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. + sampled_func + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. compute_or_load_out_stat : bool Whether to compute the output statistics. If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ if compute_or_load_out_stat: - self.compute_or_load_out_stat(merged, stat_file_path) + self.compute_or_load_out_stat(sampled_func, stat_file_path) def forward_atomic( self,