Skip to content

Commit 9d9dc8f

Browse files
committed
add n_multi_edge_message
1 parent fec6462 commit 9d9dc8f

File tree

7 files changed

+57
-11
lines changed

7 files changed

+57
-11
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
a_compress_rate: int = 0,
1818
a_compress_e_rate: int = 1,
1919
a_compress_use_split: bool = False,
20+
n_multi_edge_message: int = 1,
2021
axis_neuron: int = 4,
2122
update_angle: bool = True,
2223
update_style: str = "res_residual",
@@ -59,6 +60,9 @@ def __init__(
5960
a_compress_use_split : bool, optional
6061
Whether to split first sub-vectors instead of linear mapping during angular message compression.
6162
The default value is False.
63+
n_multi_edge_message : int, optional
64+
The head number of multiple edge messages to update node feature.
65+
Default is 1, indicating one head edge message.
6266
axis_neuron : int, optional
6367
The number of dimension of submatrix in the symmetrization ops.
6468
update_angle : bool, optional
@@ -87,6 +91,7 @@ def __init__(
8791
self.a_rcut_smth = a_rcut_smth
8892
self.a_sel = a_sel
8993
self.a_compress_rate = a_compress_rate
94+
self.n_multi_edge_message = n_multi_edge_message
9095
self.axis_neuron = axis_neuron
9196
self.update_angle = update_angle
9297
self.update_style = update_style
@@ -117,6 +122,7 @@ def serialize(self) -> dict:
117122
"a_compress_rate": self.a_compress_rate,
118123
"a_compress_e_rate": self.a_compress_e_rate,
119124
"a_compress_use_split": self.a_compress_use_split,
125+
"n_multi_edge_message": self.n_multi_edge_message,
120126
"axis_neuron": self.axis_neuron,
121127
"update_angle": self.update_angle,
122128
"update_style": self.update_style,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def init_subclass_params(sub_data, sub_class):
154154
a_compress_rate=self.repflow_args.a_compress_rate,
155155
a_compress_e_rate=self.repflow_args.a_compress_e_rate,
156156
a_compress_use_split=self.repflow_args.a_compress_use_split,
157+
n_multi_edge_message=self.repflow_args.n_multi_edge_message,
157158
axis_neuron=self.repflow_args.axis_neuron,
158159
update_angle=self.repflow_args.update_angle,
159160
activation_function=self.activation_function,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
a_compress_rate: int = 0,
4949
a_compress_use_split: bool = False,
5050
a_compress_e_rate: int = 1,
51+
n_multi_edge_message: int = 1,
5152
axis_neuron: int = 4,
5253
update_angle: bool = True, # angle
5354
activation_function: str = "silu",
@@ -79,6 +80,8 @@ def __init__(
7980
f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. "
8081
f"Currently, a_dim={a_dim} is not valid."
8182
)
83+
self.n_multi_edge_message = n_multi_edge_message
84+
assert self.n_multi_edge_message >= 1, "n_multi_edge_message must >= 1!"
8285
self.axis_neuron = axis_neuron
8386
self.update_angle = update_angle
8487
self.activation_function = activation_function
@@ -144,20 +147,21 @@ def __init__(
144147
# node edge message
145148
self.node_edge_linear = MLPLayer(
146149
self.edge_info_dim,
147-
n_dim,
150+
self.n_multi_edge_message * n_dim,
148151
precision=precision,
149152
seed=child_seed(seed, 4),
150153
)
151154
if self.update_style == "res_residual":
152-
self.n_residual.append(
153-
get_residual(
154-
n_dim,
155-
self.update_residual,
156-
self.update_residual_init,
157-
precision=precision,
158-
seed=child_seed(seed, 5),
155+
for head_index in range(self.n_multi_edge_message):
156+
self.n_residual.append(
157+
get_residual(
158+
n_dim,
159+
self.update_residual,
160+
self.update_residual_init,
161+
precision=precision,
162+
seed=child_seed(child_seed(seed, 5), head_index),
163+
)
159164
)
160-
)
161165

162166
# edge self message
163167
self.edge_self_linear = MLPLayer(
@@ -479,10 +483,18 @@ def forward(
479483
)
480484

481485
# node edge message
482-
# nb x nloc x nnei x n_dim
486+
# nb x nloc x nnei x (h * n_dim)
483487
node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1)
484488
node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei
485-
n_update_list.append(node_edge_update)
489+
if self.n_multi_edge_message > 1:
490+
# nb x nloc x nnei x h x n_dim
491+
node_edge_update_mul_head = node_edge_update.view(
492+
nb, nloc, self.n_multi_edge_message, self.n_dim
493+
)
494+
for head_index in range(self.n_multi_edge_message):
495+
n_update_list.append(node_edge_update_mul_head[:, :, head_index, :])
496+
else:
497+
n_update_list.append(node_edge_update)
486498
# update node_ebd
487499
n_updated = self.list_update(n_update_list, "node")
488500

@@ -670,6 +682,7 @@ def serialize(self) -> dict:
670682
"a_compress_rate": self.a_compress_rate,
671683
"a_compress_e_rate": self.a_compress_e_rate,
672684
"a_compress_use_split": self.a_compress_use_split,
685+
"n_multi_edge_message": self.n_multi_edge_message,
673686
"axis_neuron": self.axis_neuron,
674687
"activation_function": self.activation_function,
675688
"update_angle": self.update_angle,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
a_compress_rate: int = 0,
8989
a_compress_e_rate: int = 1,
9090
a_compress_use_split: bool = False,
91+
n_multi_edge_message: int = 1,
9192
axis_neuron: int = 4,
9293
update_angle: bool = True,
9394
activation_function: str = "silu",
@@ -137,6 +138,9 @@ def __init__(
137138
a_compress_use_split : bool, optional
138139
Whether to split first sub-vectors instead of linear mapping during angular message compression.
139140
The default value is False.
141+
n_multi_edge_message : int, optional
142+
The head number of multiple edge messages to update node feature.
143+
Default is 1, indicating one head edge message.
140144
axis_neuron : int, optional
141145
The number of dimension of submatrix in the symmetrization ops.
142146
update_angle : bool, optional
@@ -191,6 +195,7 @@ def __init__(
191195
self.split_sel = self.sel
192196
self.a_compress_rate = a_compress_rate
193197
self.a_compress_e_rate = a_compress_e_rate
198+
self.n_multi_edge_message = n_multi_edge_message
194199
self.axis_neuron = axis_neuron
195200
self.set_davg_zero = set_davg_zero
196201
self.skip_stat = skip_stat
@@ -238,6 +243,7 @@ def __init__(
238243
a_compress_rate=self.a_compress_rate,
239244
a_compress_use_split=self.a_compress_use_split,
240245
a_compress_e_rate=self.a_compress_e_rate,
246+
n_multi_edge_message=self.n_multi_edge_message,
241247
axis_neuron=self.axis_neuron,
242248
update_angle=self.update_angle,
243249
activation_function=self.activation_function,

deepmd/utils/argcheck.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,10 @@ def dpa3_repflow_args():
14541454
"Whether to split first sub-vectors instead of linear mapping during angular message compression. "
14551455
"The default value is False."
14561456
)
1457+
doc_n_multi_edge_message = (
1458+
"The head number of multiple edge messages to update node feature. "
1459+
"Default is 1, indicating one head edge message."
1460+
)
14571461
doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops."
14581462
doc_update_angle = (
14591463
"Where to update the angle rep. If not, only node and edge rep will be used."
@@ -1506,6 +1510,13 @@ def dpa3_repflow_args():
15061510
default=False,
15071511
doc=doc_a_compress_use_split,
15081512
),
1513+
Argument(
1514+
"n_multi_edge_message",
1515+
int,
1516+
optional=True,
1517+
default=1,
1518+
doc=doc_n_multi_edge_message,
1519+
),
15091520
Argument(
15101521
"axis_neuron",
15111522
int,

source/tests/pt/model/test_dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def test_consistency(
4949
rus,
5050
ruri,
5151
acr,
52+
nme,
5253
prec,
5354
ect,
5455
) in itertools.product(
5556
[True, False], # update_angle
5657
["res_residual"], # update_style
5758
["norm", "const"], # update_residual_init
5859
[0, 1], # a_compress_rate
60+
[1, 2], # n_multi_edge_message
5961
["float64"], # precision
6062
[False], # use_econf_tebd
6163
):
@@ -76,6 +78,7 @@ def test_consistency(
7678
a_rcut_smth=self.rcut_smth,
7779
a_sel=nnei - 1,
7880
a_compress_rate=acr,
81+
n_multi_edge_message=nme,
7982
axis_neuron=4,
8083
update_angle=ua,
8184
update_style=rus,
@@ -131,13 +134,15 @@ def test_jit(
131134
rus,
132135
ruri,
133136
acr,
137+
nme,
134138
prec,
135139
ect,
136140
) in itertools.product(
137141
[True, False], # update_angle
138142
["res_residual"], # update_style
139143
["norm", "const"], # update_residual_init
140144
[0, 1], # a_compress_rate
145+
[1, 2], # n_multi_edge_message
141146
["float64"], # precision
142147
[False], # use_econf_tebd
143148
):
@@ -156,6 +161,7 @@ def test_jit(
156161
a_rcut_smth=self.rcut_smth,
157162
a_sel=nnei - 1,
158163
a_compress_rate=acr,
164+
n_multi_edge_message=nme,
159165
axis_neuron=4,
160166
update_angle=ua,
161167
update_style=rus,

source/tests/universal/dpmodel/descriptor/test_descriptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def DescriptorParamDPA3(
475475
update_residual=0.1,
476476
update_residual_init="const",
477477
update_angle=True,
478+
n_multi_edge_message=1,
478479
a_compress_rate=0,
479480
precision="float64",
480481
):
@@ -493,6 +494,7 @@ def DescriptorParamDPA3(
493494
"a_rcut_smth": rcut_smth / 2,
494495
"a_sel": sum(sel) // 4,
495496
"a_compress_rate": a_compress_rate,
497+
"n_multi_edge_message": n_multi_edge_message,
496498
"axis_neuron": 4,
497499
"update_angle": update_angle,
498500
"update_style": update_style,
@@ -523,6 +525,7 @@ def DescriptorParamDPA3(
523525
"exclude_types": ([], [[0, 1]]),
524526
"update_angle": (True, False),
525527
"a_compress_rate": (0, 1),
528+
"n_multi_edge_message": (1, 2),
526529
"env_protection": (0.0, 1e-8),
527530
"precision": ("float64",),
528531
}

0 commit comments

Comments
 (0)