Skip to content

Commit e23dc5f

Browse files
committed
add dpa3 alpha
1 parent e9ed267 commit e23dc5f

File tree

7 files changed

+2321
-0
lines changed

7 files changed

+2321
-0
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
4+
class RepFlowArgs:
5+
def __init__(
6+
self,
7+
n_dim: int = 128,
8+
e_dim: int = 64,
9+
a_dim: int = 64,
10+
nlayers: int = 6,
11+
e_rcut: float = 6.0,
12+
e_rcut_smth: float = 5.0,
13+
e_sel: int = 120,
14+
a_rcut: float = 4.0,
15+
a_rcut_smth: float = 3.5,
16+
a_sel: int = 20,
17+
axis_neuron: int = 4,
18+
node_has_conv: bool = False,
19+
update_angle: bool = True,
20+
update_style: str = "res_residual",
21+
update_residual: float = 0.1,
22+
update_residual_init: str = "const",
23+
) -> None:
24+
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
25+
26+
Parameters
27+
----------
28+
n_dim : int, optional
29+
The dimension of node representation.
30+
e_dim : int, optional
31+
The dimension of edge representation.
32+
a_dim : int, optional
33+
The dimension of angle representation.
34+
nlayers : int, optional
35+
Number of repflow layers.
36+
e_rcut : float, optional
37+
The edge cut-off radius.
38+
e_rcut_smth : float, optional
39+
Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth.
40+
e_sel : int, optional
41+
Maximally possible number of selected edge neighbors.
42+
a_rcut : float, optional
43+
The angle cut-off radius.
44+
a_rcut_smth : float, optional
45+
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
46+
a_sel : int, optional
47+
Maximally possible number of selected angle neighbors.
48+
axis_neuron : int, optional
49+
The number of dimension of submatrix in the symmetrization ops.
50+
update_angle : bool, optional
51+
Where to update the angle rep. If not, only node and edge rep will be used.
52+
update_style : str, optional
53+
Style to update a representation.
54+
Supported options are:
55+
-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
56+
-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
57+
-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n)
58+
where `r1`, `r2` ... `r3` are residual weights defined by `update_residual`
59+
and `update_residual_init`.
60+
update_residual : float, optional
61+
When update using residual mode, the initial std of residual vector weights.
62+
update_residual_init : str, optional
63+
When update using residual mode, the initialization mode of residual vector weights.
64+
"""
65+
self.n_dim = n_dim
66+
self.e_dim = e_dim
67+
self.a_dim = a_dim
68+
self.nlayers = nlayers
69+
self.e_rcut = e_rcut
70+
self.e_rcut_smth = e_rcut_smth
71+
self.e_sel = e_sel
72+
self.a_rcut = a_rcut
73+
self.a_rcut_smth = a_rcut_smth
74+
self.a_sel = a_sel
75+
self.axis_neuron = axis_neuron
76+
self.node_has_conv = node_has_conv # tmp
77+
self.update_angle = update_angle
78+
self.update_style = update_style
79+
self.update_residual = update_residual
80+
self.update_residual_init = update_residual_init
81+
82+
def __getitem__(self, key):
83+
if hasattr(self, key):
84+
return getattr(self, key)
85+
else:
86+
raise KeyError(key)
87+
88+
def serialize(self) -> dict:
89+
return {
90+
"n_dim": self.n_dim,
91+
"e_dim": self.e_dim,
92+
"a_dim": self.a_dim,
93+
"nlayers": self.nlayers,
94+
"e_rcut": self.e_rcut,
95+
"e_rcut_smth": self.e_rcut_smth,
96+
"e_sel": self.e_sel,
97+
"a_rcut": self.a_rcut,
98+
"a_rcut_smth": self.a_rcut_smth,
99+
"a_sel": self.a_sel,
100+
"axis_neuron": self.axis_neuron,
101+
"node_has_conv": self.node_has_conv, # tmp
102+
"update_angle": self.update_angle,
103+
"update_style": self.update_style,
104+
"update_residual": self.update_residual,
105+
"update_residual_init": self.update_residual_init,
106+
}
107+
108+
@classmethod
109+
def deserialize(cls, data: dict) -> "RepFlowArgs":
110+
return cls(**data)

deepmd/pt/model/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from .dpa2 import (
1414
DescrptDPA2,
1515
)
16+
from .dpa3 import (
17+
DescrptDPA3,
18+
)
1619
from .env_mat import (
1720
prod_env_mat,
1821
)
@@ -49,6 +52,7 @@
4952
"DescrptBlockSeTTebd",
5053
"DescrptDPA1",
5154
"DescrptDPA2",
55+
"DescrptDPA3",
5256
"DescrptHybrid",
5357
"DescrptSeA",
5458
"DescrptSeAttenV2",

0 commit comments

Comments
 (0)