
Commit dc1b1a3

Authored by iProzd, Chengqian-Zhang, pre-commit-ci[bot], ChiahsinChu, and HydrogenSulfate
Refactor property (#37)
* change property.npy to any name
* Init branch
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* change | to Union
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* change sub_var_name default to []
* Solve pre-commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* solve scanning github
* fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* delete useless file
* Solve some UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Solve pre-commit
* solve pre-commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Solve dptest UT, dpatomicmodel UT, code scanning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* delete param and
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Solve UT failures caused by task_dim and property_name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix UT
* Fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix permutation error
* Add property bias UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* recover rcond doc
* recover blank
* Change code according to coderabbitai
* solve pre-commit
* Fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* change apply_bias doc
* update the version compatibility
* feat (tf/pt): add atomic weights to tensor loss (deepmodeling#4466) — a sketch of the idea follows this log.

  Interfaces are of particular interest in many studies. However, the configurations representing the interface in the training set normally also include large parts of the bulk material. As a result, the final model prefers the bulk information, while the interfacial information is learnt less well. It is difficult to simply raise the proportion of interfaces in the configurations, since the electronic structure of the interface may only be reasonable with a certain thickness of bulk material. Therefore, I wonder whether it is possible to define weights for atomic quantities in loss functions. This lets us put higher weights on the atomic information in the regions of interest and probably makes the model "more focused" on those regions.

  In this PR, I add the keyword `enable_atomic_weight` to the loss function of the tensor model. In principle, it could be generalised to any atomic quantity, e.g., atomic forces. I would like to know the developers' comments/suggestions about this feature. I can add support for other loss functions and finish unit tests once we agree on it.

  Summary by CodeRabbit:
  - New Features: introduced an optional parameter for atomic weights in loss calculations, enhancing flexibility in the `TensorLoss` class; added a suite of unit tests for the `TensorLoss` functionality, ensuring consistency between the TensorFlow and PyTorch implementations.
  - Bug Fixes: updated the local loss calculation to apply atomic weights according to user input.
  - Documentation: improved clarity of several function arguments, including the new atomic-weight argument.

* delete sub_var_name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* recover to property key
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix conflict
* Fix UT
* Add document of property fitting
* Delete checkpoint
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Add get_property_name to DeepEvalBackend
* pd: fix learning rate setting when resume (deepmodeling#4480)

  When resuming training, there is no need to add `self.start_step` to the step count, because Paddle uses `lr_sche.last_epoch` as the input for `step`, which already records the `start_step` steps. The learning rate is correct after this fix (screenshot: https://github.com/user-attachments/assets/1ad0ce71-6e1c-4de5-87dc-0daca1f6f038).

  Summary by CodeRabbit:
  - New Features: enhanced the training process with improved optimizer configuration and learning rate adjustments; refined logging of training and validation results; improved model saving to preserve the latest state during interruptions; enhanced tensorboard logging of training metrics.
  - Bug Fixes: corrected the learning rate scheduler lambda to reference warmup steps accurately.
  - Chores: streamlined data loading and handling for efficient training across tasks.

* docs: update deepmd-gnn URL (deepmodeling#4482)

  Summary by CodeRabbit:
  - Documentation: updated guidelines for creating and integrating new models in the DeePMD-kit framework; added sections on descriptors, fitting networks, and model requirements; expanded the unit-testing section with instructions for regression tests; updated the DeePMD-GNN plugin URL to its new repository location.

  Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>

* docs: update DPA-2 citation (deepmodeling#4483)

  Summary by CodeRabbit:
  - New Features: updated the bibliography for the DPA-2 model with a new 2024 article entry; added a reference for an attention-based descriptor.
  - Bug Fixes: corrected reference links in the documentation to point to DOI links instead of arXiv.
  - Documentation: revised the credits and model documentation to reflect the latest citations; clarified the fine-tuning documentation for the TensorFlow and PyTorch implementations.

  Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>

* docs: fix a minor typo in the title of `install-from-c-library.md` (deepmodeling#4484)

  Summary by CodeRabbit:
  - Documentation: updated the formatting of the installation guide for the pre-compiled C library; the TensorFlow and JAX icons are now displayed together in the header; all installation instructions and compatibility notes are retained.

  Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>

* fix: print dlerror if dlopen fails (deepmodeling#4485) (xref: deepmodeling/deepmd-gnn#44)

  Summary by CodeRabbit:
  - New Features: enhanced error messages for library loading failures on non-Windows platforms; updated thread-management environment-variable checks for improved compatibility; added support for mixed types in tensor input handling.
  - Bug Fixes: improved error reporting for dynamic library loading issues.

  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* change doc to py
* Add out_bias out_std doc
* change bias method to compute_stats_do_not_distinguish_types
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* change var_name to property_name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* change logic of extensive bias
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add doc for newly added parameter
* change doc for compute_stats_do_not_distinguish_types
* try to fix dptest
* change all property to property_name
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Delete key 'property' completely
* Fix UT
* Fix dptest UT
* pd: fix oom error (deepmodeling#4493)

  Paddle raises `MemoryError` rather than the `RuntimeError` used in PyTorch; with this fix I can test DPA-1 and DPA-2 on a 16G V100 (screenshot: https://github.com/user-attachments/assets/42ead773-bf26-4195-8f67-404b151371de).

  Summary by CodeRabbit:
  - Bug Fixes: improved detection of out-of-memory (OOM) errors to enhance application stability; cached memory is now cleared upon OOM errors, preventing potential memory leaks.

* pd: add missing `dp.eval()` in pd backend (deepmodeling#4488)

  Switch to eval mode when evaluating the model; otherwise `self.training` stays `True`, the backward graph is created, and this causes OOM.

  Summary by CodeRabbit:
  - New Features: enhanced model-evaluation state management to ensure correct behavior during evaluation.
  - Bug Fixes: improved type consistency in the `normalize_coord` function for better computational accuracy.

* [pre-commit.ci] pre-commit autoupdate (deepmodeling#4497): github.com/astral-sh/ruff-pre-commit v0.8.3 → v0.8.4

  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Delete attribute
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Solve comment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Solve error
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* delete property_name in serialize

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: root <2000011006@stu.pku.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chenqqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com>
Co-authored-by: Jia-Xin Zhu <53895049+ChiahsinChu@users.noreply.github.com>
Co-authored-by: HydrogenSulfate <490868991@qq.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
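As promised in the deepmodeling#4466 entry above, here is a minimal sketch of the atomic-weighting idea in plain NumPy, not the actual `TensorLoss` code; the function name, shapes, and weight values are illustrative assumptions.

    import numpy as np

    def weighted_local_tensor_loss(pred, label, atomic_weight=None):
        """Per-atom squared error on an atomic tensor, optionally reweighted.

        pred, label: (natoms, tensor_dim) predicted/reference atomic tensors.
        atomic_weight: optional (natoms,) array; larger values make the loss
        focus on those atoms (e.g. interfacial atoms). Names are illustrative.
        """
        sq_err = np.sum((pred - label) ** 2, axis=1)  # (natoms,)
        if atomic_weight is not None:
            sq_err = sq_err * atomic_weight           # emphasize selected atoms
        return np.mean(sq_err)

    # Example: weight two "interfacial" atoms 5x more than two "bulk" atoms.
    rng = np.random.default_rng(0)
    pred = rng.normal(size=(4, 3))
    label = rng.normal(size=(4, 3))
    weights = np.array([1.0, 1.0, 5.0, 5.0])
    print(weighted_local_tensor_loss(pred, label, weights))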
1 parent 76f28e9 commit dc1b1a3


63 files changed: +1225 −186 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ repos:
     exclude: ^source/3rdparty
 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
-  rev: v0.8.3
+  rev: v0.8.4
   hooks:
   - id: ruff
     args: ["--fix"]

CITATIONS.bib

Lines changed: 16 additions & 16 deletions
@@ -128,26 +128,26 @@ @article{Zhang_NpjComputMater_2024_v10_p94
   doi = {10.1038/s41524-024-01278-7},
 }
 
-@misc{Zhang_2023_DPA2,
+@article{Zhang_npjComputMater_2024_v10_p293,
   annote = {DPA-2},
   author = {
     Duo Zhang and Xinzijian Liu and Xiangyu Zhang and Chengqian Zhang and Chun
-    Cai and Hangrui Bi and Yiming Du and Xuejian Qin and Jiameng Huang and
-    Bowen Li and Yifan Shan and Jinzhe Zeng and Yuzhi Zhang and Siyuan Liu and
-    Yifan Li and Junhan Chang and Xinyan Wang and Shuo Zhou and Jianchuan Liu
-    and Xiaoshan Luo and Zhenyu Wang and Wanrun Jiang and Jing Wu and Yudi Yang
-    and Jiyuan Yang and Manyi Yang and Fu-Qiang Gong and Linshuang Zhang and
-    Mengchao Shi and Fu-Zhi Dai and Darrin M. York and Shi Liu and Tong Zhu and
-    Zhicheng Zhong and Jian Lv and Jun Cheng and Weile Jia and Mohan Chen and
-    Guolin Ke and Weinan E and Linfeng Zhang and Han Wang
+    Cai and Hangrui Bi and Yiming Du and Xuejian Qin and Anyang Peng and
+    Jiameng Huang and Bowen Li and Yifan Shan and Jinzhe Zeng and Yuzhi Zhang
+    and Siyuan Liu and Yifan Li and Junhan Chang and Xinyan Wang and Shuo Zhou
+    and Jianchuan Liu and Xiaoshan Luo and Zhenyu Wang and Wanrun Jiang and
+    Jing Wu and Yudi Yang and Jiyuan Yang and Manyi Yang and Fu-Qiang Gong and
+    Linshuang Zhang and Mengchao Shi and Fu-Zhi Dai and Darrin M. York and Shi
+    Liu and Tong Zhu and Zhicheng Zhong and Jian Lv and Jun Cheng and Weile Jia
+    and Mohan Chen and Guolin Ke and Weinan E and Linfeng Zhang and Han Wang
   },
-  title = {
-    {DPA-2: Towards a universal large atomic model for molecular and material
-    simulation}
-  },
-  publisher = {arXiv},
-  year = 2023,
-  doi = {10.48550/arXiv.2312.15492},
+  title = {{DPA-2: a large atomic model as a multi-task learner}},
+  journal = {npj Comput. Mater},
+  year = 2024,
+  volume = 10,
+  number = 1,
+  pages = 293,
+  doi = {10.1038/s41524-024-01493-2},
 }
 
 @article{Zhang_PhysPlasmas_2020_v27_p122704,

deepmd/dpmodel/atomic_model/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,9 @@
 from .polar_atomic_model import (
     DPPolarAtomicModel,
 )
+from .property_atomic_model import (
+    DPPropertyAtomicModel,
+)
 
 __all__ = [
     "BaseAtomicModel",
@@ -50,6 +53,7 @@
     "DPDipoleAtomicModel",
     "DPEnergyAtomicModel",
     "DPPolarAtomicModel",
+    "DPPropertyAtomicModel",
     "DPZBLLinearEnergyAtomicModel",
     "LinearEnergyAtomicModel",
     "PairTabAtomicModel",

deepmd/dpmodel/atomic_model/property_atomic_model.py

Lines changed: 24 additions & 0 deletions
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import numpy as np
+
 from deepmd.dpmodel.fitting.property_fitting import (
     PropertyFittingNet,
 )
@@ -15,3 +17,25 @@ def __init__(self, descriptor, fitting, type_map, **kwargs):
                 "fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel"
             )
         super().__init__(descriptor, fitting, type_map, **kwargs)
+
+    def apply_out_stat(
+        self,
+        ret: dict[str, np.ndarray],
+        atype: np.ndarray,
+    ):
+        """Apply the stat to each atomic output.
+
+        In property fitting, each output will be multiplied by label std and then plus the label average value.
+
+        Parameters
+        ----------
+        ret
+            The returned dict by the forward_atomic method
+        atype
+            The atom types. nf x nloc. It is useless in property fitting.
+
+        """
+        out_bias, out_std = self._fetch_out_stat(self.bias_keys)
+        for kk in self.bias_keys:
+            ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0]
+        return ret
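The de-normalization applied by `apply_out_stat` is just an affine map with the label statistics. A toy NumPy illustration of what it does to one output key; the statistics values here are invented for the example.

    import numpy as np

    # The fitting net emits a normalized atomic property; apply_out_stat scales
    # it by the label standard deviation and shifts it by the label mean.
    out_std = 2.5    # std of the training labels (invented)
    out_bias = -1.0  # mean of the training labels (invented)

    raw = np.array([[0.2, -0.4, 1.1]])       # nf x nloc raw atomic outputs
    denormalized = raw * out_std + out_bias  # what apply_out_stat returns per key
    print(denormalized)                      # [[-0.5  -2.    1.75]]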

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 6 additions & 1 deletion
@@ -387,7 +387,7 @@ def __init__(
         use_tebd_bias: bool = False,
         type_map: Optional[list[str]] = None,
     ) -> None:
-        r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492.
+        r"""The DPA-2 descriptor[1]_.
 
         Parameters
         ----------
@@ -434,6 +434,11 @@ def __init__(
         sw: torch.Tensor
             The switch function for decaying inverse distance.
 
+        References
+        ----------
+        .. [1] Zhang, D., Liu, X., Zhang, X. et al. DPA-2: a
+           large atomic model as a multi-task learner. npj
+           Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2
         """
 
         def init_subclass_params(sub_data, sub_class):

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 9 additions & 9 deletions
@@ -41,10 +41,9 @@ class PropertyFittingNet(InvarFitting):
         this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable.
     intensive
         Whether the fitting property is intensive.
-    bias_method
-        The method of applying the bias to each atomic output, user can select 'normal' or 'no_bias'.
-        If 'normal' is used, the computed bias will be added to the atomic output.
-        If 'no_bias' is used, no bias will be added to the atomic output.
+    property_name:
+        The name of fitting property, which should be consistent with the property name in the dataset.
+        If the data file is named `humo.npy`, this parameter should be "humo".
     resnet_dt
         Time-step `dt` in the resnet construction:
         :math:`y = x + dt * \phi (Wx + b)`
@@ -74,7 +73,7 @@ def __init__(
         rcond: Optional[float] = None,
         trainable: Union[bool, list[bool]] = True,
         intensive: bool = False,
-        bias_method: str = "normal",
+        property_name: str = "property",
         resnet_dt: bool = True,
         numb_fparam: int = 0,
         numb_aparam: int = 0,
@@ -89,9 +88,8 @@ def __init__(
     ) -> None:
         self.task_dim = task_dim
         self.intensive = intensive
-        self.bias_method = bias_method
         super().__init__(
-            var_name="property",
+            var_name=property_name,
             ntypes=ntypes,
             dim_descrpt=dim_descrpt,
             dim_out=task_dim,
@@ -113,9 +111,9 @@ def __init__(
     @classmethod
     def deserialize(cls, data: dict) -> "PropertyFittingNet":
         data = data.copy()
-        check_version_compatibility(data.pop("@version"), 3, 1)
+        check_version_compatibility(data.pop("@version"), 4, 1)
         data.pop("dim_out")
-        data.pop("var_name")
+        data["property_name"] = data.pop("var_name")
         data.pop("tot_ener_zero")
        data.pop("layer_name")
         data.pop("use_aparam_as_mask", None)
@@ -131,6 +129,8 @@ def serialize(self) -> dict:
             **InvarFitting.serialize(self),
             "type": "property",
             "task_dim": self.task_dim,
+            "intensive": self.intensive,
         }
+        dd["@version"] = 4
 
         return dd
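As a concrete illustration of the new keyword, here is a hedged construction-and-round-trip sketch. Only `ntypes`, `dim_descrpt`, `task_dim`, `intensive`, and `property_name` are taken from the diff above; all other constructor defaults are assumed.

    from deepmd.dpmodel.fitting.property_fitting import PropertyFittingNet

    # Fit a scalar property whose labels live in `humo.npy` (the docstring's
    # own example), so property_name must be "humo".
    fitting = PropertyFittingNet(
        ntypes=2,              # number of atom types (illustrative)
        dim_descrpt=128,       # descriptor feature dimension (illustrative)
        task_dim=1,            # scalar property
        property_name="humo",  # must match the dataset file stem
        intensive=True,
    )

    # serialize() stores the name under "var_name" and stamps "@version" 4;
    # deserialize() maps it back to the property_name argument, as in the diff.
    restored = PropertyFittingNet.deserialize(fitting.serialize())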

deepmd/dpmodel/model/property_model.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from deepmd.dpmodel.atomic_model.dp_atomic_model import (
-    DPAtomicModel,
+from deepmd.dpmodel.atomic_model import (
+    DPPropertyAtomicModel,
 )
 from deepmd.dpmodel.model.base_model import (
     BaseModel,
@@ -13,7 +13,7 @@
     make_model,
 )
 
-DPPropertyModel_ = make_model(DPAtomicModel)
+DPPropertyModel_ = make_model(DPPropertyAtomicModel)
 
 
 @BaseModel.register("property")
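Since `make_model(DPPropertyAtomicModel)` is registered under the "property" key, a training input reaches it through the fitting-net type. A hedged sketch of the relevant model section; the descriptor entry is a placeholder and the exact accepted keys are assumptions, not a validated input.

    # Illustrative model section only; "type": "property" and the
    # property_name/task_dim/intensive keys mirror this commit, while the
    # descriptor entry is a placeholder that omits required DPA-2 settings.
    model_config = {
        "type_map": ["H", "C", "O"],
        "descriptor": {"type": "dpa2"},
        "fitting_net": {
            "type": "property",
            "property_name": "humo",
            "task_dim": 1,
            "intensive": True,
        },
    }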

deepmd/entrypoints/test.py

Lines changed: 14 additions & 6 deletions
@@ -779,9 +779,17 @@ def test_property(
     tuple[list[np.ndarray], list[int]]
         arrays with results and their shapes
     """
-    data.add("property", dp.task_dim, atomic=False, must=True, high_prec=True)
+    var_name = dp.get_var_name()
+    assert isinstance(var_name, str)
+    data.add(var_name, dp.task_dim, atomic=False, must=True, high_prec=True)
     if has_atom_property:
-        data.add("atom_property", dp.task_dim, atomic=True, must=False, high_prec=True)
+        data.add(
+            f"atom_{var_name}",
+            dp.task_dim,
+            atomic=True,
+            must=False,
+            high_prec=True,
+        )
 
     if dp.get_dim_fparam() > 0:
         data.add(
@@ -832,12 +840,12 @@ def test_property(
     aproperty = ret[1]
     aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim])
 
-    diff_property = property - test_data["property"][:numb_test]
+    diff_property = property - test_data[var_name][:numb_test]
     mae_property = mae(diff_property)
     rmse_property = rmse(diff_property)
 
     if has_atom_property:
-        diff_aproperty = aproperty - test_data["atom_property"][:numb_test]
+        diff_aproperty = aproperty - test_data[f"atom_{var_name}"][:numb_test]
         mae_aproperty = mae(diff_aproperty)
         rmse_aproperty = rmse(diff_aproperty)
 
@@ -854,7 +862,7 @@ def test_property(
     detail_path = Path(detail_file)
 
     for ii in range(numb_test):
-        test_out = test_data["property"][ii].reshape(-1, 1)
+        test_out = test_data[var_name][ii].reshape(-1, 1)
         pred_out = property[ii].reshape(-1, 1)
 
         frame_output = np.hstack((test_out, pred_out))
@@ -868,7 +876,7 @@ def test_property(
 
     if has_atom_property:
         for ii in range(numb_test):
-            test_out = test_data["atom_property"][ii].reshape(-1, 1)
+            test_out = test_data[f"atom_{var_name}"][ii].reshape(-1, 1)
             pred_out = aproperty[ii].reshape(-1, 1)
 
             frame_output = np.hstack((test_out, pred_out))
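Given the `data.add` calls above, the label files in each system are now expected to carry the property's name. A hypothetical layout for `property_name="humo"` with `task_dim` 1; file names beyond the two label files follow the usual deepmd-kit convention and the shapes are hedged from the `atomic`/`must` flags above.

    system/
    ├── type.raw
    ├── type_map.raw
    └── set.000/
        ├── coord.npy      # nframes x (natoms * 3)
        ├── box.npy        # nframes x 9
        ├── humo.npy       # required frame-level labels, nframes x task_dim
        └── atom_humo.npy  # optional per-atom labels, nframes x (natoms * task_dim)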

deepmd/infer/deep_eval.py

Lines changed: 4 additions & 2 deletions
@@ -70,8 +70,6 @@ class DeepEvalBackend(ABC):
         "dipole_derv_c_redu": "virial",
         "dos": "atom_dos",
         "dos_redu": "dos",
-        "property": "atom_property",
-        "property_redu": "property",
         "mask_mag": "mask_mag",
         "mask": "mask",
         # old models in v1
@@ -276,6 +274,10 @@ def get_has_spin(self) -> bool:
         """Check if the model has spin atom types."""
         return False
 
+    def get_var_name(self) -> str:
+        """Get the name of the fitting property."""
+        raise NotImplementedError
+
     @abstractmethod
     def get_ntypes_spin(self) -> int:
         """Get the number of spin atom types of this model. Only used in old implement."""

deepmd/infer/deep_property.py

Lines changed: 33 additions & 11 deletions
@@ -37,25 +37,41 @@ class DeepProperty(DeepEval):
         Keyword arguments.
     """
 
-    @property
     def output_def(self) -> ModelOutputDef:
-        """Get the output definition of this model."""
-        return ModelOutputDef(
+        """
+        Get the output definition of this model.
+        But in property_fitting, the output definition is not known until the model is loaded.
+        So we need to rewrite the output definition after the model is loaded.
+        See detail in change_output_def.
+        """
+        pass
+
+    def change_output_def(self) -> None:
+        """
+        Change the output definition of this model.
+        In property_fitting, the output definition is known after the model is loaded.
+        We need to rewrite the output definition and related information.
+        """
+        self.output_def = ModelOutputDef(
             FittingOutputDef(
                 [
                     OutputVariableDef(
-                        "property",
-                        shape=[-1],
+                        self.get_var_name(),
+                        shape=[self.get_task_dim()],
                         reducible=True,
                         atomic=True,
+                        intensive=self.get_intensive(),
                     ),
                 ]
             )
         )
-
-    def change_output_def(self) -> None:
-        self.output_def["property"].shape = self.task_dim
-        self.output_def["property"].intensive = self.get_intensive()
+        self.deep_eval.output_def = self.output_def
+        self.deep_eval._OUTDEF_DP2BACKEND[self.get_var_name()] = (
+            f"atom_{self.get_var_name()}"
+        )
+        self.deep_eval._OUTDEF_DP2BACKEND[f"{self.get_var_name()}_redu"] = (
+            self.get_var_name()
+        )
 
     @property
     def task_dim(self) -> int:
@@ -120,10 +136,12 @@ def eval(
             aparam=aparam,
             **kwargs,
         )
-        atomic_property = results["property"].reshape(
+        atomic_property = results[self.get_var_name()].reshape(
             nframes, natoms, self.get_task_dim()
         )
-        property = results["property_redu"].reshape(nframes, self.get_task_dim())
+        property = results[f"{self.get_var_name()}_redu"].reshape(
+            nframes, self.get_task_dim()
+        )
 
         if atomic:
             return (
@@ -141,5 +159,9 @@ def get_intensive(self) -> bool:
         """Get whether the property is intensive."""
         return self.deep_eval.get_intensive()
 
+    def get_var_name(self) -> str:
+        """Get the name of the fitting property."""
+        return self.deep_eval.get_var_name()
+
 
 __all__ = ["DeepProperty"]
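Putting the pieces together, a minimal inference sketch for the renamed outputs; the model file name, system size, and the return order with `atomic=True` are assumptions based on the code above.

    import numpy as np

    from deepmd.infer.deep_property import DeepProperty

    dp = DeepProperty("model.pth")  # hypothetical frozen property model

    coords = np.random.rand(1, 8, 3)          # 1 frame, 8 atoms
    cells = (10.0 * np.eye(3)).reshape(1, 9)  # cubic box
    atom_types = [0] * 8

    # After loading, change_output_def() rebuilds the output definition around
    # get_var_name()/get_task_dim(), so eval() looks up the renamed result keys.
    prop, atom_prop = dp.eval(coords, cells, atom_types, atomic=True)
    print(prop.shape)  # expected (1, task_dim)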
