Skip to content

Commit aa83a41

Browse files
authored
feat(pd): add se_atten_v2 (#4558)
Has been used for benchmark <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new descriptor type, `DescrptSeAttenV2`, with advanced embedding and attention mechanisms. - Enhanced support for serialization and deserialization of descriptor states. - **Tests** - Added comprehensive test coverage for the new descriptor implementation. - Expanded testing framework to support multiple backend configurations, including conditional testing for the `pd` backend. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 0c6c568 commit aa83a41

File tree

3 files changed

+325
-0
lines changed

3 files changed

+325
-0
lines changed

deepmd/pd/model/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
DescrptBlockSeA,
2323
DescrptSeA,
2424
)
25+
from .se_atten_v2 import (
26+
DescrptSeAttenV2,
27+
)
2528
from .se_t_tebd import (
2629
DescrptBlockSeTTebd,
2730
DescrptSeTTebd,
@@ -37,6 +40,7 @@
3740
"DescrptDPA1",
3841
"DescrptDPA2",
3942
"DescrptSeA",
43+
"DescrptSeAttenV2",
4044
"DescrptSeTTebd",
4145
"prod_env_mat",
4246
]
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Optional,
4+
Union,
5+
)
6+
7+
import paddle
8+
9+
from deepmd.dpmodel.utils import EnvMat as DPEnvMat
10+
from deepmd.pd.model.descriptor.dpa1 import (
11+
DescrptDPA1,
12+
)
13+
from deepmd.pd.model.network.mlp import (
14+
NetworkCollection,
15+
)
16+
from deepmd.pd.model.network.network import (
17+
TypeEmbedNetConsistent,
18+
)
19+
from deepmd.pd.utils import (
20+
env,
21+
)
22+
from deepmd.pd.utils.env import (
23+
RESERVED_PRECISION_DICT,
24+
)
25+
from deepmd.utils.version import (
26+
check_version_compatibility,
27+
)
28+
29+
from .base_descriptor import (
30+
BaseDescriptor,
31+
)
32+
from .se_atten import (
33+
NeighborGatedAttention,
34+
)
35+
36+
37+
@BaseDescriptor.register("se_atten_v2")
38+
class DescrptSeAttenV2(DescrptDPA1):
39+
def __init__(
40+
self,
41+
rcut: float,
42+
rcut_smth: float,
43+
sel: Union[list[int], int],
44+
ntypes: int,
45+
neuron: list = [25, 50, 100],
46+
axis_neuron: int = 16,
47+
tebd_dim: int = 8,
48+
set_davg_zero: bool = True,
49+
attn: int = 128,
50+
attn_layer: int = 2,
51+
attn_dotr: bool = True,
52+
attn_mask: bool = False,
53+
activation_function: str = "tanh",
54+
precision: str = "float64",
55+
resnet_dt: bool = False,
56+
exclude_types: list[tuple[int, int]] = [],
57+
env_protection: float = 0.0,
58+
scaling_factor: int = 1.0,
59+
normalize=True,
60+
temperature=None,
61+
concat_output_tebd: bool = True,
62+
trainable: bool = True,
63+
trainable_ln: bool = True,
64+
ln_eps: Optional[float] = 1e-5,
65+
type_one_side: bool = False,
66+
stripped_type_embedding: Optional[bool] = None,
67+
seed: Optional[Union[int, list[int]]] = None,
68+
use_econf_tebd: bool = False,
69+
use_tebd_bias: bool = False,
70+
type_map: Optional[list[str]] = None,
71+
# not implemented
72+
spin=None,
73+
type: Optional[str] = None,
74+
) -> None:
75+
r"""Construct smooth version of embedding net of type `se_atten_v2`.
76+
77+
Parameters
78+
----------
79+
rcut : float
80+
The cut-off radius :math:`r_c`
81+
rcut_smth : float
82+
From where the environment matrix should be smoothed :math:`r_s`
83+
sel : list[int], int
84+
list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius
85+
int: the total maxmum number of atoms in the cut-off radius
86+
ntypes : int
87+
Number of element types
88+
neuron : list[int]
89+
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
90+
axis_neuron : int
91+
Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
92+
tebd_dim : int
93+
Dimension of the type embedding
94+
set_davg_zero : bool
95+
Set the shift of embedding net input to zero.
96+
attn : int
97+
Hidden dimension of the attention vectors
98+
attn_layer : int
99+
Number of attention layers
100+
attn_dotr : bool
101+
If dot the angular gate to the attention weights
102+
attn_mask : bool
103+
(Only support False to keep consistent with other backend references.)
104+
(Not used in this version.)
105+
If mask the diagonal of attention weights
106+
activation_function : str
107+
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
108+
precision : str
109+
The precision of the embedding net parameters. Supported options are |PRECISION|
110+
resnet_dt : bool
111+
Time-step `dt` in the resnet construction:
112+
y = x + dt * \phi (Wx + b)
113+
exclude_types : list[list[int]]
114+
The excluded pairs of types which have no interaction with each other.
115+
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
116+
env_protection : float
117+
Protection parameter to prevent division by zero errors during environment matrix calculations.
118+
scaling_factor : float
119+
The scaling factor of normalization in calculations of attention weights.
120+
If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
121+
normalize : bool
122+
Whether to normalize the hidden vectors in attention weights calculation.
123+
temperature : float
124+
If not None, the scaling of attention weights is `temperature` itself.
125+
concat_output_tebd : bool
126+
Whether to concat type embedding at the output of the descriptor.
127+
trainable : bool
128+
If the weights of this descriptors are trainable.
129+
trainable_ln : bool
130+
Whether to use trainable shift and scale weights in layer normalization.
131+
ln_eps : float, Optional
132+
The epsilon value for layer normalization.
133+
type_one_side : bool
134+
If 'False', type embeddings of both neighbor and central atoms are considered.
135+
If 'True', only type embeddings of neighbor atoms are considered.
136+
Default is 'False'.
137+
stripped_type_embedding : bool, Optional
138+
(Deprecated, kept only for compatibility.)
139+
Whether to strip the type embedding into a separate embedding network.
140+
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
141+
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
142+
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
143+
seed : int, Optional
144+
Random seed for parameter initialization.
145+
use_econf_tebd : bool, Optional
146+
Whether to use electronic configuration type embedding.
147+
use_tebd_bias : bool, Optional
148+
Whether to use bias in the type embedding layer.
149+
type_map : list[str], Optional
150+
A list of strings. Give the name to each type of atoms.
151+
spin
152+
(Only support None to keep consistent with other backend references.)
153+
(Not used in this version. Not-none option is not implemented.)
154+
The old implementation of deepspin.
155+
"""
156+
DescrptDPA1.__init__(
157+
self,
158+
rcut,
159+
rcut_smth,
160+
sel,
161+
ntypes,
162+
neuron=neuron,
163+
axis_neuron=axis_neuron,
164+
tebd_dim=tebd_dim,
165+
tebd_input_mode="strip",
166+
set_davg_zero=set_davg_zero,
167+
attn=attn,
168+
attn_layer=attn_layer,
169+
attn_dotr=attn_dotr,
170+
attn_mask=attn_mask,
171+
activation_function=activation_function,
172+
precision=precision,
173+
resnet_dt=resnet_dt,
174+
exclude_types=exclude_types,
175+
env_protection=env_protection,
176+
scaling_factor=scaling_factor,
177+
normalize=normalize,
178+
temperature=temperature,
179+
concat_output_tebd=concat_output_tebd,
180+
trainable=trainable,
181+
trainable_ln=trainable_ln,
182+
ln_eps=ln_eps,
183+
smooth_type_embedding=True,
184+
type_one_side=type_one_side,
185+
stripped_type_embedding=stripped_type_embedding,
186+
seed=seed,
187+
use_econf_tebd=use_econf_tebd,
188+
use_tebd_bias=use_tebd_bias,
189+
type_map=type_map,
190+
# not implemented
191+
spin=spin,
192+
type=type,
193+
)
194+
195+
def serialize(self) -> dict:
196+
obj = self.se_atten
197+
data = {
198+
"@class": "Descriptor",
199+
"type": "se_atten_v2",
200+
"@version": 2,
201+
"rcut": obj.rcut,
202+
"rcut_smth": obj.rcut_smth,
203+
"sel": obj.sel,
204+
"ntypes": obj.ntypes,
205+
"neuron": obj.neuron,
206+
"axis_neuron": obj.axis_neuron,
207+
"tebd_dim": obj.tebd_dim,
208+
"set_davg_zero": obj.set_davg_zero,
209+
"attn": obj.attn_dim,
210+
"attn_layer": obj.attn_layer,
211+
"attn_dotr": obj.attn_dotr,
212+
"attn_mask": False,
213+
"activation_function": obj.activation_function,
214+
"resnet_dt": obj.resnet_dt,
215+
"scaling_factor": obj.scaling_factor,
216+
"normalize": obj.normalize,
217+
"temperature": obj.temperature,
218+
"trainable_ln": obj.trainable_ln,
219+
"ln_eps": obj.ln_eps,
220+
"type_one_side": obj.type_one_side,
221+
"concat_output_tebd": self.concat_output_tebd,
222+
"use_econf_tebd": self.use_econf_tebd,
223+
"use_tebd_bias": self.use_tebd_bias,
224+
"type_map": self.type_map,
225+
# make deterministic
226+
"precision": RESERVED_PRECISION_DICT[obj.prec],
227+
"embeddings": obj.filter_layers.serialize(),
228+
"embeddings_strip": obj.filter_layers_strip.serialize(),
229+
"attention_layers": obj.dpa1_attention.serialize(),
230+
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
231+
"type_embedding": self.type_embedding.embedding.serialize(),
232+
"exclude_types": obj.exclude_types,
233+
"env_protection": obj.env_protection,
234+
"@variables": {
235+
"davg": obj["davg"].detach().cpu().numpy(),
236+
"dstd": obj["dstd"].detach().cpu().numpy(),
237+
},
238+
"trainable": self.trainable,
239+
"spin": None,
240+
}
241+
return data
242+
243+
@classmethod
244+
def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
245+
data = data.copy()
246+
check_version_compatibility(data.pop("@version"), 2, 1)
247+
data.pop("@class")
248+
data.pop("type")
249+
variables = data.pop("@variables")
250+
embeddings = data.pop("embeddings")
251+
type_embedding = data.pop("type_embedding")
252+
attention_layers = data.pop("attention_layers")
253+
data.pop("env_mat")
254+
embeddings_strip = data.pop("embeddings_strip")
255+
# compat with version 1
256+
if "use_tebd_bias" not in data:
257+
data["use_tebd_bias"] = True
258+
obj = cls(**data)
259+
260+
def t_cvt(xx):
261+
return paddle.to_tensor(xx, dtype=obj.se_atten.prec, place=env.DEVICE)
262+
263+
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
264+
type_embedding
265+
)
266+
obj.se_atten["davg"] = t_cvt(variables["davg"])
267+
obj.se_atten["dstd"] = t_cvt(variables["dstd"])
268+
obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
269+
obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
270+
embeddings_strip
271+
)
272+
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
273+
attention_layers
274+
)
275+
return obj

source/tests/consistent/descriptor/test_se_atten_v2.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..common import (
1919
INSTALLED_ARRAY_API_STRICT,
2020
INSTALLED_JAX,
21+
INSTALLED_PD,
2122
INSTALLED_PT,
2223
CommonTest,
2324
parameterized,
@@ -44,6 +45,12 @@
4445
)
4546
else:
4647
DescrptSeAttenV2Strict = None
48+
if INSTALLED_PD:
49+
from deepmd.pd.model.descriptor.se_atten_v2 import (
50+
DescrptSeAttenV2 as DescrptSeAttenV2PD,
51+
)
52+
else:
53+
DescrptSeAttenV2PD = None
4754
DescrptSeAttenV2TF = None
4855
from deepmd.utils.argcheck import (
4956
descrpt_se_atten_args,
@@ -248,11 +255,40 @@ def skip_array_api_strict(self) -> bool:
248255
)
249256
)
250257

258+
@property
259+
def skip_pd(self) -> bool:
260+
(
261+
tebd_dim,
262+
resnet_dt,
263+
type_one_side,
264+
attn,
265+
attn_layer,
266+
attn_dotr,
267+
excluded_types,
268+
env_protection,
269+
set_davg_zero,
270+
scaling_factor,
271+
normalize,
272+
temperature,
273+
ln_eps,
274+
concat_output_tebd,
275+
precision,
276+
use_econf_tebd,
277+
use_tebd_bias,
278+
) = self.param
279+
return not INSTALLED_PD or self.is_meaningless_zero_attention_layer_tests(
280+
attn_layer,
281+
attn_dotr,
282+
normalize,
283+
temperature,
284+
)
285+
251286
tf_class = DescrptSeAttenV2TF
252287
dp_class = DescrptSeAttenV2DP
253288
pt_class = DescrptSeAttenV2PT
254289
jax_class = DescrptSeAttenV2JAX
255290
array_api_strict_class = DescrptSeAttenV2Strict
291+
pd_class = DescrptSeAttenV2PD
256292
args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))
257293

258294
def setUp(self) -> None:
@@ -339,6 +375,16 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
339375
mixed_types=True,
340376
)
341377

378+
def eval_pd(self, pd_obj: Any) -> Any:
379+
return self.eval_pd_descriptor(
380+
pd_obj,
381+
self.natoms,
382+
self.coords,
383+
self.atype,
384+
self.box,
385+
mixed_types=True,
386+
)
387+
342388
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
343389
return (ret[0], ret[1])
344390

0 commit comments

Comments
 (0)