1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -363,6 +363,7 @@ def compute_or_load_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
compute_out_stat: bool = True,
) -> NoReturn:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
4 changes: 3 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -285,6 +285,7 @@ def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
compute_out_stat: bool = True,
) -> None:
"""
Compute or load the statistics parameters of the model,
@@ -323,7 +324,8 @@ def wrapped_sampler():
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
if compute_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
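The new compute_out_stat flag lets a wrapping model (the linear/ZBL model below) ask a DPAtomicModel for descriptor and fitting input statistics while suppressing the per-sub-model output bias. A minimal call sketch, assuming a constructed DPAtomicModel named model and a sampling callable sampled_func (illustrative names, not part of this diff):

# Input statistics only; the wrapping model computes and stores the output bias itself.
model.compute_or_load_stat(sampled_func, stat_file_path=None, compute_out_stat=False)

# The default keeps the previous behaviour: input statistics plus the output bias.
model.compute_or_load_stat(sampled_func)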
55 changes: 29 additions & 26 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from typing import (
Callable,
Optional,
Union,
)
@@ -319,6 +319,10 @@ def apply_out_stat(
The atom types. nf x nloc

"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret

@staticmethod
@@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
"""
return False

def compute_or_load_out_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
for md in self.models:
md.compute_or_load_out_stat(merged, stat_file_path)

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
compute_out_stat: bool = True,
) -> None:
"""
Compute or load the statistics parameters of the model,
@@ -509,7 +490,29 @@ def compute_or_load_stat(
The dictionary of paths to the statistics files.
"""
for md in self.models:
md.compute_or_load_stat(sampled_func, stat_file_path)
md.compute_or_load_stat(
sampled_func, stat_file_path, compute_out_stat=False
)

if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

@functools.lru_cache
def wrapped_sampler():
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
return sampled

self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)


class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
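With the compute_or_load_out_stat override removed, the linear model keeps a single output bias of its own and adds it per atom in apply_out_stat, while its sub-models only accumulate input statistics (compute_out_stat=False). A small sketch of the per-atom bias lookup used above, with illustrative shapes rather than the real model tensors:

import torch

out_bias = torch.tensor([[1.0], [3.0]])             # ntypes x odims
atype = torch.tensor([[0, 0, 1]])                    # nf x nloc
per_atom_bias = out_bias[atype]                      # nf x nloc x odims
ret_energy = torch.zeros(1, 3, 1) + per_atom_bias    # biased per-atom output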
4 changes: 3 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -226,6 +226,7 @@ def compute_or_load_stat(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
compute_out_stat: bool = True,
) -> None:
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
@@ -243,7 +244,8 @@ def compute_or_load_stat(
The path to the stat file.

"""
self.compute_or_load_out_stat(merged, stat_file_path)
if compute_out_stat:
self.compute_or_load_out_stat(merged, stat_file_path)

def forward_atomic(
self,
24 changes: 20 additions & 4 deletions deepmd/pt/train/training.py
@@ -510,15 +510,31 @@ def collect_single_finetune_params(
if i != "_extra_state" and f".{_model_key}." in i
]
for item_key in target_keys:
if _new_fitting and (".descriptor." not in item_key):
new_key = item_key.replace(
f".{_model_key}.", f".{_model_key_from}."
)
use_random_initialization = _new_fitting and (
".descriptor." not in item_key
)
if (
not use_random_initialization
and new_key not in _origin_state_dict
):
# for ZBL models finetuning from standard models
if ".models.0." in new_key:
new_key = new_key.replace(".models.0.", ".")
elif ".models.1." in new_key:
use_random_initialization = True
else:
raise KeyError(
f"Key {new_key} not found in pretrained model."
)
if use_random_initialization:
# print(f'Keep {item_key} in old model!')
_new_state_dict[item_key] = (
_random_state_dict[item_key].clone().detach()
)
else:
new_key = item_key.replace(
f".{_model_key}.", f".{_model_key_from}."
)
# print(f'Replace {item_key} with {new_key} in pretrained_model!')
_new_state_dict[item_key] = (
_origin_state_dict[new_key].clone().detach()
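The branch added above reads as a three-way rule: keep keys that already exist in the pretrained state dict, map the ZBL model's first sub-model (".models.0.") onto the standard model by dropping that segment, and fall back to random initialization for the second sub-model (".models.1."). A stand-alone sketch of that rule; the helper name and simplified signature are illustrative, not the trainer's actual code:

def map_finetune_key(new_key: str, origin_state_dict: dict) -> tuple[str, bool]:
    # Returns (key to copy from, whether to keep the randomly initialized value).
    if new_key in origin_state_dict:
        return new_key, False
    if ".models.0." in new_key:
        # The DP sub-model of a ZBL model maps onto the standard pretrained model.
        return new_key.replace(".models.0.", "."), False
    if ".models.1." in new_key:
        # The pairtab/ZBL sub-model has no pretrained counterpart: random init.
        return new_key, True
    raise KeyError(f"Key {new_key} not found in pretrained model.")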
11 changes: 3 additions & 8 deletions source/tests/pt/model/test_linear_atomic_model_stat.py
@@ -233,16 +233,11 @@ def test_linear_atomic_model_stat_with_bias(self) -> None:
linear_model.compute_or_load_out_stat(
self.merged_output_stat, stat_file_path=self.stat_file_path
)
# bias applied to sub atomic models.
ener_bias = np.array([1.0, 3.0]).reshape(2, 1)
linear_ret = []
for idx, md in enumerate(linear_model.models):
ret = md.forward_common_atomic(*args)
ret = to_numpy_array(ret["energy"])
linear_ret.append(ret_no_bias[idx] + ener_bias[at])
np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret)
ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"])
np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret)

# linear model not adding bias again
ret1 = linear_model.forward_common_atomic(*args)
ret1 = to_numpy_array(ret1["energy"])
np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1)
np.testing.assert_almost_equal(ret, ret1)
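After this change the per-type bias lives only on the linear model: sub-model outputs stay unbiased and the bias is added exactly once after the sub-models are mixed, so a second forward pass returns the same values. A small numpy illustration with made-up numbers and equal mixing weights (assumed here for simplicity):

import numpy as np

ener_bias = np.array([1.0, 3.0]).reshape(2, 1)        # ntypes x 1, held by the linear model
atype = np.array([[0, 0, 1]])                          # nf x nloc
sub_out = [np.zeros((1, 3, 1)), np.ones((1, 3, 1))]    # unbiased sub-model energies

# Mix the sub-models first, then add the per-atom bias once.
linear_out = np.mean(np.stack(sub_out), axis=0) + ener_bias[atype]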
42 changes: 42 additions & 0 deletions source/tests/pt/test_training.py
@@ -30,6 +30,8 @@


class DPTrainTest:
test_zbl_from_standard: bool = False

def test_dp_train(self) -> None:
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
@@ -95,6 +97,34 @@ def test_dp_train(self) -> None:
state_dict_finetuned_random[state_key],
)

if self.test_zbl_from_standard:
# test fine-tuning using zbl from standard model
finetune_model = (
self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
)
self.config_zbl["model"], finetune_links = get_finetune_rules(
finetune_model,
self.config_zbl["model"],
)
trainer_finetune_zbl = get_trainer(
deepcopy(self.config_zbl),
finetune_model=finetune_model,
finetune_links=finetune_links,
)
state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict()
for state_key in state_dict_finetuned_zbl:
if "out_bias" not in state_key and "out_std" not in state_key:
original_key = state_key
if ".models.0." in state_key:
original_key = state_key.replace(".models.0.", ".")
if ".models.1." not in state_key:
torch.testing.assert_close(
state_dict_trained[original_key],
state_dict_finetuned_zbl[state_key],
)
# check running
trainer_finetune_zbl.run()

# check running
trainer_finetune.run()
trainer_finetune_empty.run()
@@ -222,6 +252,18 @@ def setUp(self) -> None:
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

self.test_zbl_from_standard = True

input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
with open(input_json_zbl) as f:
self.config_zbl = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config_zbl["training"]["training_data"]["systems"] = data_file
self.config_zbl["training"]["validation_data"]["systems"] = data_file
self.config_zbl["model"] = deepcopy(model_zbl)
self.config_zbl["training"]["numb_steps"] = 1
self.config_zbl["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)
