
Commit 1c29fe4

iProzd, anyangml, and pre-commit-ci[bot] authored
feat(pt): support zbl finetune (#4849)
## Summary by CodeRabbit

* **New Features**
  * Added an option to control whether output statistics are computed or loaded across atomic models.
* **Bug Fixes**
  * More robust parameter transfer during fine-tuning to handle renamed branches and missing pretrained keys.
* **Refactor**
  * Revised output-statistics workflow and refined per-type output bias application in composite models.
* **Tests**
  * Simplified linear-model bias checks and added a ZBL finetuning test path.

Co-authored-by: anyangml <anyangpeng.ca@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cefce47 commit 1c29fe4

File tree: 7 files changed, +134 −61 lines changed

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 14 additions & 10 deletions
@@ -363,21 +363,25 @@ def compute_or_load_stat(
         self,
         merged: Union[Callable[[], list[dict]], list[dict]],
         stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
     ) -> NoReturn:
         """
-        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+        and saved in the `stat_file_path`(s).
+        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+        and load the calculated statistics parameters.

         Parameters
         ----------
-        merged : Union[Callable[[], list[dict]], list[dict]]
-            - list[dict]: A list of data samples from various data systems.
-                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
-                originating from the `i`-th data system.
-            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
-                only when needed. Since the sampling process can be slow and memory-intensive,
-                the lazy function helps by only sampling once.
-        stat_file_path : Optional[DPPath]
-            The path to the stat file.
+        merged
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The dictionary of paths to the statistics files.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

         """
         raise NotImplementedError

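The new keyword only extends the abstract signature here, so the intended contract is easier to see in a standalone sketch. The toy class below is illustrative only and is not deepmd-kit code: input statistics are always handled, while output statistics are gated behind the flag so a composite model can claim them for itself.

```python
# Illustrative sketch only -- ToyAtomicModel and its helpers are hypothetical,
# not part of deepmd-kit. It mirrors the contract added above.
from typing import Callable, Optional


class ToyAtomicModel:
    def compute_or_load_stat(
        self,
        sampled_func: Callable[[], list[dict]],
        stat_file_path: Optional[str] = None,
        compute_or_load_out_stat: bool = True,
    ) -> None:
        # Input statistics (e.g. descriptor mean/std) are always computed.
        self._compute_input_stats(sampled_func)
        # Output statistics (e.g. the energy bias) only when requested; a
        # composite model passes False here and applies one shared bias itself.
        if compute_or_load_out_stat:
            self._compute_or_load_out_stat(sampled_func, stat_file_path)

    def _compute_input_stats(self, sampled_func) -> None:
        ...

    def _compute_or_load_out_stat(self, sampled_func, stat_file_path) -> None:
        ...
```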
deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 6 additions & 1 deletion
@@ -285,6 +285,7 @@ def compute_or_load_stat(
         self,
         sampled_func,
         stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
     ) -> None:
         """
         Compute or load the statistics parameters of the model,
@@ -300,6 +301,9 @@ def compute_or_load_stat(
             The lazy sampled function to get data frames from different data systems.
         stat_file_path
             The dictionary of paths to the statistics files.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
         """
         if stat_file_path is not None and self.type_map is not None:
             # descriptors and fitting net with different type_map
@@ -323,7 +327,8 @@ def wrapped_sampler():
         self.fitting_net.compute_input_stats(
             wrapped_sampler, protection=self.data_stat_protect
         )
-        self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
+        if compute_or_load_out_stat:
+            self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this atomic model."""

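The last hunk above sits next to the cached `wrapped_sampler` closure that the surrounding method defines: drawing frames from every data system is expensive, so the zero-argument sampler is memoized and each statistics routine reuses the same draw. A minimal, self-contained sketch of that pattern (illustrative; `expensive_sampler` is a placeholder):

```python
# Minimal sketch of the memoized-sampler pattern; `expensive_sampler` is a placeholder.
import functools


def make_cached_sampler(sampled_func):
    """Wrap a zero-argument sampler so the expensive draw happens only once."""

    @functools.lru_cache
    def wrapped_sampler():
        # The first call performs the real sampling; later calls return the cached list.
        return sampled_func()

    return wrapped_sampler


def expensive_sampler() -> list[dict]:
    print("sampling data systems ...")
    return [{"energy": 1.0}, {"energy": 2.0}]


sampler = make_cached_sampler(expensive_sampler)
sampler()  # prints once
sampler()  # served from the cache
```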
deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 32 additions & 26 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import functools
 from typing import (
-    Callable,
     Optional,
     Union,
 )
@@ -319,6 +319,10 @@ def apply_out_stat(
             The atom types. nf x nloc

         """
+        out_bias, out_std = self._fetch_out_stat(self.bias_keys)
+        for kk in self.bias_keys:
+            # nf x nloc x odims, out_bias: ntypes x odims
+            ret[kk] = ret[kk] + out_bias[kk][atype]
         return ret

     @staticmethod
@@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
         """
         return False

-    def compute_or_load_out_stat(
-        self,
-        merged: Union[Callable[[], list[dict]], list[dict]],
-        stat_file_path: Optional[DPPath] = None,
-    ) -> None:
-        """
-        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
-
-        Parameters
-        ----------
-        merged : Union[Callable[[], list[dict]], list[dict]]
-            - list[dict]: A list of data samples from various data systems.
-                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
-                originating from the `i`-th data system.
-            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
-                only when needed. Since the sampling process can be slow and memory-intensive,
-                the lazy function helps by only sampling once.
-        stat_file_path : Optional[DPPath]
-            The path to the stat file.
-
-        """
-        for md in self.models:
-            md.compute_or_load_out_stat(merged, stat_file_path)
-
     def compute_or_load_stat(
         self,
         sampled_func,
         stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
     ) -> None:
         """
         Compute or load the statistics parameters of the model,
@@ -507,9 +488,34 @@ def compute_or_load_stat(
             The lazy sampled function to get data frames from different data systems.
         stat_file_path
             The dictionary of paths to the statistics files.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
         """
         for md in self.models:
-            md.compute_or_load_stat(sampled_func, stat_file_path)
+            md.compute_or_load_stat(
+                sampled_func, stat_file_path, compute_or_load_out_stat=False
+            )
+
+        if stat_file_path is not None and self.type_map is not None:
+            # descriptors and fitting net with different type_map
+            # should not share the same parameters
+            stat_file_path /= " ".join(self.type_map)
+
+        @functools.lru_cache
+        def wrapped_sampler():
+            sampled = sampled_func()
+            if self.pair_excl is not None:
+                pair_exclude_types = self.pair_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["pair_exclude_types"] = list(pair_exclude_types)
+            if self.atom_excl is not None:
+                atom_exclude_types = self.atom_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["atom_exclude_types"] = list(atom_exclude_types)
+            return sampled
+
+        self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)


 class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):

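The `out_bias[kk][atype]` indexing in `apply_out_stat` is the key line: a per-type bias of shape `(ntypes, odims)` indexed by an integer type map of shape `(nf, nloc)` broadcasts to a per-atom bias of shape `(nf, nloc, odims)`. A self-contained torch illustration with made-up shapes and values:

```python
# Self-contained illustration of the per-type bias indexing; shapes/values are made up.
import torch

nf, nloc, ntypes, odims = 2, 3, 2, 1
out_bias = torch.tensor([[1.0], [3.0]])       # (ntypes, odims): one bias per atom type
atype = torch.tensor([[0, 1, 1], [1, 0, 0]])  # (nf, nloc): type index of each atom
energy = torch.zeros(nf, nloc, odims)         # (nf, nloc, odims): raw atomic output

# Advanced indexing maps each atom's type to its bias: the result is (nf, nloc, odims).
biased = energy + out_bias[atype]
print(biased.shape)     # torch.Size([2, 3, 1])
print(biased[0, :, 0])  # tensor([1., 3., 3.])
```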
deepmd/pt/model/atomic_model/pairtab_atomic_model.py

Lines changed: 17 additions & 12 deletions
@@ -224,26 +224,31 @@ def deserialize(cls, data) -> "PairTabAtomicModel":

     def compute_or_load_stat(
         self,
-        merged: Union[Callable[[], list[dict]], list[dict]],
+        sampled_func: Union[Callable[[], list[dict]], list[dict]],
         stat_file_path: Optional[DPPath] = None,
+        compute_or_load_out_stat: bool = True,
     ) -> None:
         """
-        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+        and saved in the `stat_file_path`(s).
+        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
+        and load the calculated statistics parameters.

         Parameters
         ----------
-        merged : Union[Callable[[], list[dict]], list[dict]]
-            - list[dict]: A list of data samples from various data systems.
-                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
-                originating from the `i`-th data system.
-            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
-                only when needed. Since the sampling process can be slow and memory-intensive,
-                the lazy function helps by only sampling once.
-        stat_file_path : Optional[DPPath]
-            The path to the stat file.
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The dictionary of paths to the statistics files.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

         """
-        self.compute_or_load_out_stat(merged, stat_file_path)
+        if compute_or_load_out_stat:
+            self.compute_or_load_out_stat(sampled_func, stat_file_path)

     def forward_atomic(
         self,

deepmd/pt/train/training.py

Lines changed: 20 additions & 4 deletions
@@ -510,15 +510,31 @@ def collect_single_finetune_params(
                if i != "_extra_state" and f".{_model_key}." in i
            ]
            for item_key in target_keys:
-                if _new_fitting and (".descriptor." not in item_key):
+                new_key = item_key.replace(
+                    f".{_model_key}.", f".{_model_key_from}."
+                )
+                use_random_initialization = _new_fitting and (
+                    ".descriptor." not in item_key
+                )
+                if (
+                    not use_random_initialization
+                    and new_key not in _origin_state_dict
+                ):
+                    # for ZBL models finetuning from standard models
+                    if ".models.0." in new_key:
+                        new_key = new_key.replace(".models.0.", ".")
+                    elif ".models.1." in new_key:
+                        use_random_initialization = True
+                    else:
+                        raise KeyError(
+                            f"Key {new_key} not found in pretrained model."
+                        )
+                if use_random_initialization:
                    # print(f'Keep {item_key} in old model!')
                    _new_state_dict[item_key] = (
                        _random_state_dict[item_key].clone().detach()
                    )
                else:
-                    new_key = item_key.replace(
-                        f".{_model_key}.", f".{_model_key_from}."
-                    )
                    # print(f'Replace {item_key} with {new_key} in pretrained_model!')
                    _new_state_dict[item_key] = (
                        _origin_state_dict[new_key].clone().detach()

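The new branch above maps state-dict keys when a ZBL composite model is fine-tuned from a standard (single-branch) checkpoint: keys under `.models.0.` (the DP branch) are looked up in the pretrained model with that prefix stripped, keys under `.models.1.` (the ZBL/pair-table branch) fall back to random initialization, and any other missing key raises. A toy re-statement of just that mapping rule, with invented example keys:

```python
# Toy re-statement of the key-remapping rule above; the example keys are invented.
from typing import Optional


def resolve_pretrained_key(new_key: str, origin_state_dict: dict) -> Optional[str]:
    """Return the pretrained key to copy from, or None for random initialization."""
    if new_key in origin_state_dict:
        return new_key
    # ZBL composite finetuned from a standard model:
    if ".models.0." in new_key:   # DP branch -> strip the composite prefix
        return new_key.replace(".models.0.", ".")
    if ".models.1." in new_key:   # ZBL/pair-table branch -> random initialization
        return None
    raise KeyError(f"Key {new_key} not found in pretrained model.")


pretrained = {"model.Default.atomic_model.descriptor.w": 0}
print(resolve_pretrained_key(
    "model.Default.atomic_model.models.0.descriptor.w", pretrained
))  # -> model.Default.atomic_model.descriptor.w
```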
source/tests/pt/model/test_linear_atomic_model_stat.py

Lines changed: 3 additions & 8 deletions
@@ -233,16 +233,11 @@ def test_linear_atomic_model_stat_with_bias(self) -> None:
         linear_model.compute_or_load_out_stat(
             self.merged_output_stat, stat_file_path=self.stat_file_path
         )
-        # bias applied to sub atomic models.
         ener_bias = np.array([1.0, 3.0]).reshape(2, 1)
-        linear_ret = []
-        for idx, md in enumerate(linear_model.models):
-            ret = md.forward_common_atomic(*args)
-            ret = to_numpy_array(ret["energy"])
-            linear_ret.append(ret_no_bias[idx] + ener_bias[at])
-            np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret)
+        ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"])
+        np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret)

         # linear model not adding bias again
         ret1 = linear_model.forward_common_atomic(*args)
         ret1 = to_numpy_array(ret1["energy"])
-        np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1)
+        np.testing.assert_almost_equal(ret, ret1)

source/tests/pt/test_training.py

Lines changed: 42 additions & 0 deletions
@@ -30,6 +30,8 @@


 class DPTrainTest:
+    test_zbl_from_standard: bool = False
+
     def test_dp_train(self) -> None:
         # test training from scratch
         trainer = get_trainer(deepcopy(self.config))
@@ -95,6 +97,34 @@ def test_dp_train(self) -> None:
                state_dict_finetuned_random[state_key],
            )

+        if self.test_zbl_from_standard:
+            # test fine-tuning using zbl from standard model
+            finetune_model = (
+                self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
+            )
+            self.config_zbl["model"], finetune_links = get_finetune_rules(
+                finetune_model,
+                self.config_zbl["model"],
+            )
+            trainer_finetune_zbl = get_trainer(
+                deepcopy(self.config_zbl),
+                finetune_model=finetune_model,
+                finetune_links=finetune_links,
+            )
+            state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict()
+            for state_key in state_dict_finetuned_zbl:
+                if "out_bias" not in state_key and "out_std" not in state_key:
+                    original_key = state_key
+                    if ".models.0." in state_key:
+                        original_key = state_key.replace(".models.0.", ".")
+                    if ".models.1." not in state_key:
+                        torch.testing.assert_close(
+                            state_dict_trained[original_key],
+                            state_dict_finetuned_zbl[state_key],
+                        )
+            # check running
+            trainer_finetune_zbl.run()
+
         # check running
         trainer_finetune.run()
         trainer_finetune_empty.run()
@@ -222,6 +252,18 @@ def setUp(self) -> None:
         self.config["training"]["numb_steps"] = 1
         self.config["training"]["save_freq"] = 1

+        self.test_zbl_from_standard = True
+
+        input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
+        with open(input_json_zbl) as f:
+            self.config_zbl = json.load(f)
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.config_zbl["training"]["training_data"]["systems"] = data_file
+        self.config_zbl["training"]["validation_data"]["systems"] = data_file
+        self.config_zbl["model"] = deepcopy(model_zbl)
+        self.config_zbl["training"]["numb_steps"] = 1
+        self.config_zbl["training"]["save_freq"] = 1
+
     def tearDown(self) -> None:
         DPTrainTest.tearDown(self)

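Condensed from the test above, the ZBL-from-standard finetune flow can be read as a short helper. This is a sketch, not part of the codebase; it reuses only names that already appear in the diff (`get_finetune_rules`, `get_trainer`, imported as in the test module), and the checkpoint path and config dict are placeholders.

```python
# Condensed sketch of the flow exercised by the test; paths and configs are placeholders.
# `get_finetune_rules` and `get_trainer` are the same helpers the test module imports.
from copy import deepcopy


def finetune_zbl_from_standard(config_zbl: dict, pretrained_ckpt: str):
    # Rewrite the ZBL model section so its branches line up with the pretrained model.
    config_zbl["model"], finetune_links = get_finetune_rules(
        pretrained_ckpt,
        config_zbl["model"],
    )
    trainer = get_trainer(
        deepcopy(config_zbl),
        finetune_model=pretrained_ckpt,
        finetune_links=finetune_links,
    )
    trainer.run()
    return trainer
```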