Skip to content

Commit c0cadf5

Browse files
authored
feat(pt): add trainable to property fitting (#4599)
Add keyword `trainable` to property fitting. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a configurable "trainable" option that allows users to control whether network parameters are updated individually or collectively. - **Documentation** - Enhanced descriptions to clearly explain how the new trainability setting works. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 176c746 commit c0cadf5

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

deepmd/pt/model/task/property.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from typing import (
44
Optional,
5+
Union,
56
)
67

78
import torch
@@ -88,6 +89,7 @@ def __init__(
8889
activation_function: str = "tanh",
8990
precision: str = DEFAULT_PRECISION,
9091
mixed_types: bool = True,
92+
trainable: Union[bool, list[bool]] = True,
9193
seed: Optional[int] = None,
9294
**kwargs,
9395
) -> None:
@@ -107,6 +109,7 @@ def __init__(
107109
activation_function=activation_function,
108110
precision=precision,
109111
mixed_types=mixed_types,
112+
trainable=trainable,
110113
seed=seed,
111114
**kwargs,
112115
)

deepmd/utils/argcheck.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,9 @@ def fitting_property():
15801580
doc_task_dim = "The dimension of outputs of fitting net"
15811581
doc_intensive = "Whether the fitting property is intensive"
15821582
doc_property_name = "The names of fitting property, which should be consistent with the property name in the dataset."
1583+
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
1584+
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
1585+
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
15831586
return [
15841587
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
15851588
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
@@ -1616,6 +1619,13 @@ def fitting_property():
16161619
optional=False,
16171620
doc=doc_property_name,
16181621
),
1622+
Argument(
1623+
"trainable",
1624+
[list[bool], bool],
1625+
optional=True,
1626+
default=True,
1627+
doc=doc_trainable,
1628+
),
16191629
]
16201630

16211631

0 commit comments

Comments
 (0)