Skip to content

Commit e1b7a9f

Browse files
iProzd and njzjz authored
fix(pt/dp): fix non-smooth edge update in DPA3 (#4675)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced an optional "smooth edge update" setting that offers a refined update process by altering the default padding behavior. This enhancement provides users with increased control over descriptor computations.
- **Tests**
  - Expanded and updated test cases to validate the new smooth edge update functionality under various scenarios, ensuring consistent and reliable performance.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 78b8f01 commit e1b7a9f

File tree

13 files changed

+123
-45
lines changed

13 files changed

+123
-45
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,9 @@ class RepFlowArgs:
121121
optim_update : bool, optional
122122
Whether to enable the optimized update method.
123123
Uses a more efficient process when enabled. Defaults to True
124+
smooth_edge_update : bool, optional
125+
Whether to make edge update smooth.
126+
If True, the edge update from angle message will not use self as padding.
124127
"""
125128

126129
def __init__(
@@ -147,6 +150,7 @@ def __init__(
147150
fix_stat_std: float = 0.3,
148151
skip_stat: bool = False,
149152
optim_update: bool = True,
153+
smooth_edge_update: bool = False,
150154
) -> None:
151155
self.n_dim = n_dim
152156
self.e_dim = e_dim
@@ -172,6 +176,7 @@ def __init__(
172176
self.a_compress_e_rate = a_compress_e_rate
173177
self.a_compress_use_split = a_compress_use_split
174178
self.optim_update = optim_update
179+
self.smooth_edge_update = smooth_edge_update
175180

176181
def __getitem__(self, key):
177182
if hasattr(self, key):
@@ -202,6 +207,7 @@ def serialize(self) -> dict:
202207
"update_residual_init": self.update_residual_init,
203208
"fix_stat_std": self.fix_stat_std,
204209
"optim_update": self.optim_update,
210+
"smooth_edge_update": self.smooth_edge_update,
205211
}
206212

207213
@classmethod
@@ -297,6 +303,7 @@ def init_subclass_params(sub_data, sub_class):
297303
update_residual_init=self.repflow_args.update_residual_init,
298304
fix_stat_std=self.repflow_args.fix_stat_std,
299305
optim_update=self.repflow_args.optim_update,
306+
smooth_edge_update=self.repflow_args.smooth_edge_update,
300307
exclude_types=exclude_types,
301308
env_protection=env_protection,
302309
precision=precision,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 27 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
114114
optim_update : bool, optional
115115
Whether to enable the optimized update method.
116116
Uses a more efficient process when enabled. Defaults to True
117+
smooth_edge_update : bool, optional
118+
Whether to make edge update smooth.
119+
If True, the edge update from angle message will not use self as padding.
117120
ntypes : int
118121
Number of element types
119122
activation_function : str, optional
@@ -161,6 +164,7 @@ def __init__(
161164
precision: str = "float64",
162165
fix_stat_std: float = 0.3,
163166
optim_update: bool = True,
167+
smooth_edge_update: bool = False,
164168
seed: Optional[Union[int, list[int]]] = None,
165169
) -> None:
166170
super().__init__()
@@ -191,6 +195,7 @@ def __init__(
191195
self.set_stddev_constant = fix_stat_std != 0.0
192196
self.a_compress_use_split = a_compress_use_split
193197
self.optim_update = optim_update
198+
self.smooth_edge_update = smooth_edge_update
194199

195200
self.n_dim = n_dim
196201
self.e_dim = e_dim
@@ -243,6 +248,7 @@ def __init__(
243248
update_residual_init=self.update_residual_init,
244249
precision=precision,
245250
optim_update=self.optim_update,
251+
smooth_edge_update=self.smooth_edge_update,
246252
seed=child_seed(child_seed(seed, 1), ii),
247253
)
248254
)
@@ -563,6 +569,7 @@ def __init__(
563569
axis_neuron: int = 4,
564570
update_angle: bool = True,
565571
optim_update: bool = True,
572+
smooth_edge_update: bool = False,
566573
activation_function: str = "silu",
567574
update_style: str = "res_residual",
568575
update_residual: float = 0.1,
@@ -607,6 +614,7 @@ def __init__(
607614
self.seed = seed
608615
self.prec = PRECISION_DICT[precision]
609616
self.optim_update = optim_update
617+
self.smooth_edge_update = smooth_edge_update
610618

611619
assert update_residual_init in [
612620
"norm",
@@ -1136,19 +1144,23 @@ def call(
11361144
],
11371145
axis=2,
11381146
)
1139-
full_mask = xp.concat(
1140-
[
1141-
a_nlist_mask,
1142-
xp.zeros(
1143-
(nb, nloc, self.nnei - self.a_sel),
1144-
dtype=a_nlist_mask.dtype,
1145-
),
1146-
],
1147-
axis=-1,
1148-
)
1149-
padding_edge_angle_update = xp.where(
1150-
xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, edge_ebd
1151-
)
1147+
if not self.smooth_edge_update:
1148+
# will be deprecated in the future
1149+
full_mask = xp.concat(
1150+
[
1151+
a_nlist_mask,
1152+
xp.zeros(
1153+
(nb, nloc, self.nnei - self.a_sel),
1154+
dtype=a_nlist_mask.dtype,
1155+
),
1156+
],
1157+
axis=-1,
1158+
)
1159+
padding_edge_angle_update = xp.where(
1160+
xp.expand_dims(full_mask, axis=-1),
1161+
padding_edge_angle_update,
1162+
edge_ebd,
1163+
)
11521164
e_update_list.append(
11531165
self.act(self.edge_angle_linear2(padding_edge_angle_update))
11541166
)
@@ -1235,7 +1247,7 @@ def serialize(self) -> dict:
12351247
The serialized networks.
12361248
"""
12371249
data = {
1238-
"@class": "RepformerLayer",
1250+
"@class": "RepFlowLayer",
12391251
"@version": 1,
12401252
"e_rcut": self.e_rcut,
12411253
"e_rcut_smth": self.e_rcut_smth,
@@ -1259,6 +1271,7 @@ def serialize(self) -> dict:
12591271
"update_residual_init": self.update_residual_init,
12601272
"precision": self.precision,
12611273
"optim_update": self.optim_update,
1274+
"smooth_edge_update": self.smooth_edge_update,
12621275
"node_self_mlp": self.node_self_mlp.serialize(),
12631276
"node_sym_linear": self.node_sym_linear.serialize(),
12641277
"node_edge_linear": self.node_edge_linear.serialize(),

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,7 @@ def init_subclass_params(sub_data, sub_class):
149149
update_residual_init=self.repflow_args.update_residual_init,
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152+
smooth_edge_update=self.repflow_args.smooth_edge_update,
152153
exclude_types=exclude_types,
153154
env_protection=env_protection,
154155
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ def __init__(
5252
axis_neuron: int = 4,
5353
update_angle: bool = True,
5454
optim_update: bool = True,
55+
smooth_edge_update: bool = False,
5556
activation_function: str = "silu",
5657
update_style: str = "res_residual",
5758
update_residual: float = 0.1,
@@ -96,6 +97,7 @@ def __init__(
9697
self.seed = seed
9798
self.prec = PRECISION_DICT[precision]
9899
self.optim_update = optim_update
100+
self.smooth_edge_update = smooth_edge_update
99101

100102
assert update_residual_init in [
101103
"norm",
@@ -718,20 +720,22 @@ def forward(
718720
],
719721
dim=2,
720722
)
721-
full_mask = torch.concat(
722-
[
723-
a_nlist_mask,
724-
torch.zeros(
725-
[nb, nloc, self.nnei - self.a_sel],
726-
dtype=a_nlist_mask.dtype,
727-
device=a_nlist_mask.device,
728-
),
729-
],
730-
dim=-1,
731-
)
732-
padding_edge_angle_update = torch.where(
733-
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
734-
)
723+
if not self.smooth_edge_update:
724+
# will be deprecated in the future
725+
full_mask = torch.concat(
726+
[
727+
a_nlist_mask,
728+
torch.zeros(
729+
[nb, nloc, self.nnei - self.a_sel],
730+
dtype=a_nlist_mask.dtype,
731+
device=a_nlist_mask.device,
732+
),
733+
],
734+
dim=-1,
735+
)
736+
padding_edge_angle_update = torch.where(
737+
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
738+
)
735739
e_update_list.append(
736740
self.act(self.edge_angle_linear2(padding_edge_angle_update))
737741
)
@@ -823,7 +827,7 @@ def serialize(self) -> dict:
823827
The serialized networks.
824828
"""
825829
data = {
826-
"@class": "RepformerLayer",
830+
"@class": "RepFlowLayer",
827831
"@version": 1,
828832
"e_rcut": self.e_rcut,
829833
"e_rcut_smth": self.e_rcut_smth,
@@ -847,6 +851,7 @@ def serialize(self) -> dict:
847851
"update_residual_init": self.update_residual_init,
848852
"precision": self.precision,
849853
"optim_update": self.optim_update,
854+
"smooth_edge_update": self.smooth_edge_update,
850855
"node_self_mlp": self.node_self_mlp.serialize(),
851856
"node_sym_linear": self.node_sym_linear.serialize(),
852857
"node_edge_linear": self.node_edge_linear.serialize(),

deepmd/pt/model/descriptor/repflows.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,9 @@ class DescrptBlockRepflows(DescriptorBlock):
130130
fix_stat_std : float, optional
131131
If non-zero (default is 0.3), use this constant as the normalization standard deviation
132132
instead of computing it from data statistics.
133+
smooth_edge_update : bool, optional
134+
Whether to make edge update smooth.
135+
If True, the edge update from angle message will not use self as padding.
133136
optim_update : bool, optional
134137
Whether to enable the optimized update method.
135138
Uses a more efficient process when enabled. Defaults to True
@@ -179,6 +182,7 @@ def __init__(
179182
env_protection: float = 0.0,
180183
precision: str = "float64",
181184
fix_stat_std: float = 0.3,
185+
smooth_edge_update: bool = False,
182186
optim_update: bool = True,
183187
seed: Optional[Union[int, list[int]]] = None,
184188
) -> None:
@@ -210,6 +214,7 @@ def __init__(
210214
self.set_stddev_constant = fix_stat_std != 0.0
211215
self.a_compress_use_split = a_compress_use_split
212216
self.optim_update = optim_update
217+
self.smooth_edge_update = smooth_edge_update
213218

214219
self.n_dim = n_dim
215220
self.e_dim = e_dim
@@ -262,6 +267,7 @@ def __init__(
262267
update_residual_init=self.update_residual_init,
263268
precision=precision,
264269
optim_update=self.optim_update,
270+
smooth_edge_update=self.smooth_edge_update,
265271
seed=child_seed(child_seed(seed, 1), ii),
266272
)
267273
)

deepmd/utils/argcheck.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1493,6 +1493,10 @@ def dpa3_repflow_args():
14931493
"Whether to enable the optimized update method. "
14941494
"Uses a more efficient process when enabled. Defaults to True"
14951495
)
1496+
doc_smooth_edge_update = (
1497+
"Whether to make edge update smooth. "
1498+
"If True, the edge update from angle message will not use self as padding."
1499+
)
14961500

14971501
return [
14981502
# repflow args
@@ -1586,6 +1590,13 @@ def dpa3_repflow_args():
15861590
default=True,
15871591
doc=doc_optim_update,
15881592
),
1593+
Argument(
1594+
"smooth_edge_update",
1595+
bool,
1596+
optional=True,
1597+
default=False, # For compatibility. This will be True in the future
1598+
doc=doc_smooth_edge_update,
1599+
),
15891600
]
15901601

15911602

source/tests/pt/model/test_dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,7 @@ def test_consistency(
9090
update_angle=ua,
9191
update_style=rus,
9292
update_residual_init=ruri,
93+
smooth_edge_update=True,
9394
)
9495

9596
# dpa3 new impl
@@ -190,6 +191,7 @@ def test_jit(
190191
update_angle=ua,
191192
update_style=rus,
192193
update_residual_init=ruri,
194+
smooth_edge_update=True,
193195
)
194196

195197
# dpa3 new impl

source/tests/pt/model/test_permutation.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -177,10 +177,10 @@
177177
"a_dim": 8,
178178
"nlayers": 6,
179179
"e_rcut": 6.0,
180-
"e_rcut_smth": 5.0,
180+
"e_rcut_smth": 3.0,
181181
"e_sel": 20,
182182
"a_rcut": 4.0,
183-
"a_rcut_smth": 3.5,
183+
"a_rcut_smth": 2.0,
184184
"a_sel": 10,
185185
"axis_neuron": 4,
186186
"a_compress_rate": 1,
@@ -190,8 +190,9 @@
190190
"update_style": "res_residual",
191191
"update_residual": 0.1,
192192
"update_residual_init": "const",
193+
"smooth_edge_update": True,
193194
},
194-
"activation_function": "silu",
195+
"activation_function": "silut:10.0",
195196
"use_tebd_bias": False,
196197
"precision": "float32",
197198
"concat_output_tebd": False,
@@ -200,7 +201,7 @@
200201
"neuron": [24, 24],
201202
"resnet_dt": True,
202203
"precision": "float32",
203-
"activation_function": "silu",
204+
"activation_function": "silut:10.0",
204205
"seed": 1,
205206
},
206207
}

source/tests/pt/model/test_smooth.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
model_dos,
2222
model_dpa1,
2323
model_dpa2,
24+
model_dpa3,
2425
model_hybrid,
2526
model_se_e2_a,
2627
model_spin,
@@ -59,10 +60,16 @@ def test(
5960
0.0,
6061
4.0 - 0.5 * epsilon,
6162
0.0,
63+
6.0 - 0.5 * epsilon,
64+
0.0,
65+
0.0,
66+
0.0,
67+
6.0 - 0.5 * epsilon,
68+
0.0,
6269
],
6370
dtype=dtype,
6471
device=env.DEVICE,
65-
).view([-1, 3])
72+
).view([-1, 3]) # to test descriptors with two rcuts, e.g. DPA2/3
6673
coord1 = torch.rand(
6774
[natoms - coord0.shape[0], 3],
6875
dtype=dtype,
@@ -77,11 +84,15 @@ def test(
7784
coord0 = torch.clone(coord)
7885
coord1 = torch.clone(coord)
7986
coord1[1][0] += epsilon
87+
coord1[3][0] += epsilon
8088
coord2 = torch.clone(coord)
8189
coord2[2][1] += epsilon
90+
coord2[4][1] += epsilon
8291
coord3 = torch.clone(coord)
8392
coord3[1][0] += epsilon
93+
coord1[3][0] += epsilon
8494
coord3[2][1] += epsilon
95+
coord2[4][1] += epsilon
8596
test_spin = getattr(self, "test_spin", False)
8697
if not test_spin:
8798
test_keys = ["energy", "force", "virial"]
@@ -226,6 +237,17 @@ def setUp(self) -> None:
226237
self.epsilon, self.aprec = None, None
227238

228239

240+
class TestEnergyModelDPA3(unittest.TestCase, SmoothTest):
241+
def setUp(self) -> None:
242+
model_params = copy.deepcopy(model_dpa3)
243+
self.type_split = True
244+
self.model = get_model(model_params).to(env.DEVICE)
245+
# less degree of smoothness,
246+
# error can be systematically removed by reducing epsilon
247+
self.epsilon = 1e-5
248+
self.aprec = 1e-5
249+
250+
229251
class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
230252
def setUp(self) -> None:
231253
model_params = copy.deepcopy(model_hybrid)

0 commit comments

Comments
 (0)