Skip to content
24 changes: 14 additions & 10 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,21 +363,25 @@ def compute_or_load_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
compute_or_load_out_stat: bool = True,
) -> NoReturn:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
merged
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

"""
raise NotImplementedError
Expand Down
7 changes: 6 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
compute_or_load_out_stat: bool = True,
) -> None:
"""
Compute or load the statistics parameters of the model,
Expand All @@ -300,6 +301,9 @@ def compute_or_load_stat(
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
"""
if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
Expand All @@ -323,7 +327,8 @@ def wrapped_sampler():
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down
58 changes: 32 additions & 26 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from typing import (
Callable,
Optional,
Union,
)
Expand Down Expand Up @@ -319,6 +319,10 @@ def apply_out_stat(
The atom types. nf x nloc

"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret

@staticmethod
Expand Down Expand Up @@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
"""
return False

def compute_or_load_out_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
# Fan out to the sub-models: every entry of `self.models` receives the same
# `merged` argument and the same `stat_file_path`, both forwarded unchanged.
# Nothing is computed or stored on this object itself in this method; the
# work is done entirely by each sub-model's own `compute_or_load_out_stat`.
for md in self.models:
md.compute_or_load_out_stat(merged, stat_file_path)

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
compute_or_load_out_stat: bool = True,
) -> None:
"""
Compute or load the statistics parameters of the model,
Expand All @@ -507,9 +488,34 @@ def compute_or_load_stat(
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
"""
for md in self.models:
md.compute_or_load_stat(sampled_func, stat_file_path)
md.compute_or_load_stat(
sampled_func, stat_file_path, compute_or_load_out_stat=False
)

if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

# Memoised zero-argument sampler: `sampled_func()` is invoked on the first
# call only; later calls return the same, already annotated, list object.
@functools.lru_cache
def wrapped_sampler():
sampled = sampled_func()
# When a pair-exclusion object is configured, stamp every sample with its
# own list copy of `self.pair_excl.get_exclude_types()`.
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
# Same treatment for the atom-exclusion object, stored under its own key.
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
# The sample dicts are modified in place and the original list is returned.
return sampled

self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)


class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
Expand Down
29 changes: 17 additions & 12 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,26 +224,31 @@ def deserialize(cls, data) -> "PairTabAtomicModel":

def compute_or_load_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
sampled_func: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
compute_or_load_out_stat: bool = True,
) -> None:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
compute_or_load_out_stat : bool
Whether to compute the output statistics.
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

"""
self.compute_or_load_out_stat(merged, stat_file_path)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(sampled_func, stat_file_path)

def forward_atomic(
self,
Expand Down
24 changes: 20 additions & 4 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,31 @@ def collect_single_finetune_params(
if i != "_extra_state" and f".{_model_key}." in i
]
for item_key in target_keys:
if _new_fitting and (".descriptor." not in item_key):
new_key = item_key.replace(
f".{_model_key}.", f".{_model_key_from}."
)
use_random_initialization = _new_fitting and (
".descriptor." not in item_key
)
if (
not use_random_initialization
and new_key not in _origin_state_dict
):
# for ZBL models finetuning from standard models
if ".models.0." in new_key:
new_key = new_key.replace(".models.0.", ".")
elif ".models.1." in new_key:
use_random_initialization = True
else:
raise KeyError(
f"Key {new_key} not found in pretrained model."
)
if use_random_initialization:
# print(f'Keep {item_key} in old model!')
_new_state_dict[item_key] = (
_random_state_dict[item_key].clone().detach()
)
else:
new_key = item_key.replace(
f".{_model_key}.", f".{_model_key_from}."
)
# print(f'Replace {item_key} with {new_key} in pretrained_model!')
_new_state_dict[item_key] = (
_origin_state_dict[new_key].clone().detach()
Expand Down
11 changes: 3 additions & 8 deletions source/tests/pt/model/test_linear_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,11 @@ def test_linear_atomic_model_stat_with_bias(self) -> None:
linear_model.compute_or_load_out_stat(
self.merged_output_stat, stat_file_path=self.stat_file_path
)
# bias applied to sub atomic models.
ener_bias = np.array([1.0, 3.0]).reshape(2, 1)
linear_ret = []
for idx, md in enumerate(linear_model.models):
ret = md.forward_common_atomic(*args)
ret = to_numpy_array(ret["energy"])
linear_ret.append(ret_no_bias[idx] + ener_bias[at])
np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret)
ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"])
np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret)

# linear model not adding bias again
ret1 = linear_model.forward_common_atomic(*args)
ret1 = to_numpy_array(ret1["energy"])
np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1)
np.testing.assert_almost_equal(ret, ret1)
42 changes: 42 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@


class DPTrainTest:
test_zbl_from_standard: bool = False

def test_dp_train(self) -> None:
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
Expand Down Expand Up @@ -95,6 +97,34 @@ def test_dp_train(self) -> None:
state_dict_finetuned_random[state_key],
)

if self.test_zbl_from_standard:
# test fine-tuning using zbl from standard model
finetune_model = (
self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
)
self.config_zbl["model"], finetune_links = get_finetune_rules(
finetune_model,
self.config_zbl["model"],
)
trainer_finetune_zbl = get_trainer(
deepcopy(self.config_zbl),
finetune_model=finetune_model,
finetune_links=finetune_links,
)
state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict()
for state_key in state_dict_finetuned_zbl:
if "out_bias" not in state_key and "out_std" not in state_key:
original_key = state_key
if ".models.0." in state_key:
original_key = state_key.replace(".models.0.", ".")
if ".models.1." not in state_key:
torch.testing.assert_close(
state_dict_trained[original_key],
state_dict_finetuned_zbl[state_key],
)
# check running
trainer_finetune_zbl.run()

# check running
trainer_finetune.run()
trainer_finetune_empty.run()
Expand Down Expand Up @@ -222,6 +252,18 @@ def setUp(self) -> None:
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

self.test_zbl_from_standard = True

input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
with open(input_json_zbl) as f:
self.config_zbl = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config_zbl["training"]["training_data"]["systems"] = data_file
self.config_zbl["training"]["validation_data"]["systems"] = data_file
self.config_zbl["model"] = deepcopy(model_zbl)
self.config_zbl["training"]["numb_steps"] = 1
self.config_zbl["training"]["save_freq"] = 1

def tearDown(self) -> None:
# Cleanup is delegated to DPTrainTest.tearDown, called explicitly on the
# class with this instance rather than through super().
DPTrainTest.tearDown(self)

Expand Down