Commit bad74ab

fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable
1. For dpmodel, pt, and pd, pass the `trainable` parameter to the layer (not actually used in this PR).
2. For JAX, support the `trainable` parameter in the layer.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent ab6e300 commit bad74ab

48 files changed (+968 -529 lines)
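
Only the JAX backend acts on the flag in this commit; for dpmodel, pt, and pd it is merely threaded through for later use. As a rough, standalone illustration of how a layer can honor `trainable` in JAX, the sketch below freezes a dense layer's parameters with `jax.lax.stop_gradient` when `trainable=False`. The helper names (`init_dense`, `dense_apply`) are hypothetical and do not reproduce the `NativeLayer` implementation touched here.

    # Minimal sketch (assumption: not the actual deepmd-kit NativeLayer code).
    # A dense layer whose parameters are frozen by stopping gradients on them
    # when trainable=False, so an optimizer step leaves them unchanged.
    import jax
    import jax.numpy as jnp


    def init_dense(key, n_in, n_out):
        kw, _ = jax.random.split(key)
        return {
            "w": jax.random.normal(kw, (n_in, n_out)) / jnp.sqrt(n_in),
            "b": jnp.zeros((n_out,)),
        }


    def dense_apply(params, x, trainable=True):
        w, b = params["w"], params["b"]
        if not trainable:
            # Zero out parameter gradients; the forward value is unchanged.
            w = jax.lax.stop_gradient(w)
            b = jax.lax.stop_gradient(b)
        return x @ w + b


    params = init_dense(jax.random.PRNGKey(0), 4, 3)
    x = jnp.ones((2, 4))
    loss = lambda p: jnp.sum(dense_apply(p, x, trainable=False) ** 2)
    grads = jax.grad(loss)(params)
    print(grads["w"])  # all zeros: the frozen layer would receive no updates

An equally valid approach is to exclude frozen leaves from the parameter pytree handed to the optimizer; the stop-gradient form is simply the shortest self-contained demonstration.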

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 13 additions & 0 deletions
@@ -319,6 +319,7 @@ def __init__(
            trainable_ln=trainable_ln,
            ln_eps=ln_eps,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.use_econf_tebd = use_econf_tebd
        self.use_tebd_bias = use_tebd_bias
@@ -333,6 +334,7 @@ def __init__(
            use_tebd_bias=use_tebd_bias,
            type_map=type_map,
            seed=child_seed(seed, 1),
+            trainable=trainable,
        )
        self.tebd_dim = tebd_dim
        self.concat_output_tebd = concat_output_tebd
@@ -691,6 +693,7 @@ def __init__(
        ln_eps: Optional[float] = 1e-5,
        smooth: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        self.rcut = rcut
        self.rcut_smth = rcut_smth
@@ -741,6 +744,7 @@ def __init__(
            self.resnet_dt,
            self.precision,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.embeddings = embeddings
        if self.tebd_input_mode in ["strip"]:
@@ -756,6 +760,7 @@ def __init__(
                self.resnet_dt,
                self.precision,
                seed=child_seed(seed, 1),
+                trainable=trainable,
            )
            self.embeddings_strip = embeddings_strip
        else:
@@ -774,6 +779,7 @@ def __init__(
            smooth=self.smooth,
            precision=self.precision,
            seed=child_seed(seed, 2),
+            trainable=trainable,
        )

        wanted_shape = (self.ntypes, self.nnei, 4)
@@ -1186,6 +1192,7 @@ def __init__(
        smooth: bool = True,
        precision: str = DEFAULT_PRECISION,
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        """Construct a neighbor-wise attention net."""
        super().__init__()
@@ -1219,6 +1226,7 @@ def __init__(
                smooth=smooth,
                precision=precision,
                seed=child_seed(seed, ii),
+                trainable=trainable,
            )
            for ii in range(layer_num)
        ]
@@ -1314,6 +1322,7 @@ def __init__(
        smooth: bool = True,
        precision: str = DEFAULT_PRECISION,
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        """Construct a neighbor-wise attention layer."""
        super().__init__()
@@ -1340,6 +1349,7 @@ def __init__(
            smooth=smooth,
            precision=precision,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.attn_layer_norm = LayerNorm(
            self.embed_dim,
@@ -1420,6 +1430,7 @@ def __init__(
        smooth: bool = True,
        precision: str = DEFAULT_PRECISION,
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        """Construct a multi-head neighbor-wise attention net."""
        super().__init__()
@@ -1449,6 +1460,7 @@ def __init__(
            use_timestep=False,
            precision=precision,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.out_proj = NativeLayer(
            hidden_dim,
@@ -1457,6 +1469,7 @@ def __init__(
            use_timestep=False,
            precision=precision,
            seed=child_seed(seed, 1),
+            trainable=trainable,
        )

    def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0):
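
Every hunk in this file follows the same mechanical pattern: a constructor that already accepts `trainable` now forwards it to each sub-layer it builds. A condensed, hypothetical sketch of that pass-through (`OuterBlock` and `InnerLayer` are placeholder names, not deepmd-kit classes):

    # Hypothetical pass-through sketch; OuterBlock and InnerLayer are placeholders.
    class InnerLayer:
        def __init__(self, n_in, n_out, trainable=True):
            # For dpmodel/pt/pd the flag is only stored for now; backends can
            # later use it to exclude this layer's parameters from training.
            self.trainable = trainable


    class OuterBlock:
        def __init__(self, n_in, n_hidden, trainable=True):
            # The flag is threaded down unchanged, mirroring the
            # `trainable=trainable` lines added throughout this commit.
            self.embed = InnerLayer(n_in, n_hidden, trainable=trainable)
            self.proj = InnerLayer(n_hidden, n_hidden, trainable=trainable)


    block = OuterBlock(4, 8, trainable=False)
    assert not block.embed.trainable and not block.proj.trainable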

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 6 additions & 0 deletions
@@ -474,6 +474,7 @@ def init_subclass_params(sub_data, sub_class):
            smooth=smooth,
            type_one_side=self.repinit_args.type_one_side,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.use_three_body = self.repinit_args.use_three_body
        if self.use_three_body:
@@ -493,6 +494,7 @@ def init_subclass_params(sub_data, sub_class):
                resnet_dt=self.repinit_args.resnet_dt,
                smooth=smooth,
                seed=child_seed(seed, 5),
+                trainable=trainable,
            )
        else:
            self.repinit_three_body = None
@@ -533,6 +535,7 @@ def init_subclass_params(sub_data, sub_class):
            g1_out_mlp=self.repformer_args.g1_out_mlp,
            ln_eps=self.repformer_args.ln_eps,
            seed=child_seed(seed, 1),
+            trainable=trainable,
        )
        self.rcsl_list = [
            (self.repformers.get_rcut(), self.repformers.get_nsel()),
@@ -562,6 +565,7 @@ def init_subclass_params(sub_data, sub_class):
            use_tebd_bias=use_tebd_bias,
            type_map=type_map,
            seed=child_seed(seed, 2),
+            trainable=trainable,
        )
        self.concat_output_tebd = concat_output_tebd
        self.precision = precision
@@ -585,6 +589,7 @@ def init_subclass_params(sub_data, sub_class):
            bias=False,
            precision=precision,
            seed=child_seed(seed, 3),
+            trainable=trainable,
        )
        self.tebd_transform = None
        if self.add_tebd_to_repinit_out:
@@ -594,6 +599,7 @@ def init_subclass_params(sub_data, sub_class):
                bias=False,
                precision=precision,
                seed=child_seed(seed, 4),
+                trainable=trainable,
            )
        assert self.repinit.rcut > self.repformers.rcut
        assert self.repinit.sel[0] > self.repformers.sel[0]

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
@@ -357,6 +357,7 @@ def init_subclass_params(sub_data, sub_class):
            env_protection=env_protection,
            precision=precision,
            seed=child_seed(seed, 1),
+            trainable=trainable,
        )

        self.use_econf_tebd = use_econf_tebd
@@ -374,6 +375,7 @@ def init_subclass_params(sub_data, sub_class):
            use_tebd_bias=use_tebd_bias,
            type_map=type_map,
            seed=child_seed(seed, 2),
+            trainable=trainable,
        )
        self.concat_output_tebd = concat_output_tebd
        self.precision = precision

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 31 additions & 2 deletions
@@ -167,6 +167,8 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
        For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection.
    seed : int, optional
        Random seed for parameter initialization.
+    trainable : bool, default: True
+        Whether the block is trainable
    """

    def __init__(
@@ -205,6 +207,7 @@ def __init__(
        sel_reduce_factor: float = 10.0,
        use_loc_mapping: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        super().__init__()
        self.e_rcut = float(e_rcut)
@@ -269,10 +272,19 @@ def __init__(
        self.seed = seed

        self.edge_embd = NativeLayer(
-            1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
+            1,
+            self.e_dim,
+            precision=precision,
+            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        self.angle_embd = NativeLayer(
-            1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1)
+            1,
+            self.a_dim,
+            precision=precision,
+            bias=False,
+            seed=child_seed(seed, 1),
+            trainable=trainable,
        )
        layers = []
        for ii in range(nlayers):
@@ -304,6 +316,7 @@ def __init__(
                    sel_reduce_factor=self.sel_reduce_factor,
                    smooth_edge_update=self.smooth_edge_update,
                    seed=child_seed(child_seed(seed, 1), ii),
+                    trainable=trainable,
                )
            )
        self.layers = layers
@@ -860,6 +873,7 @@ def __init__(
        update_residual_init: str = "const",
        precision: str = "float64",
        seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
    ) -> None:
        super().__init__()
        self.epsilon = 1e-4  # protection of 1./nnei
@@ -922,6 +936,7 @@ def __init__(
            n_dim,
            precision=precision,
            seed=child_seed(seed, 0),
+            trainable=trainable,
        )
        if self.update_style == "res_residual":
            self.n_residual.append(
@@ -931,6 +946,7 @@ def __init__(
                    self.update_residual_init,
                    precision=precision,
                    seed=child_seed(seed, 1),
+                    trainable=trainable,
                )
            )

@@ -941,6 +957,7 @@ def __init__(
            n_dim,
            precision=precision,
            seed=child_seed(seed, 2),
+            trainable=trainable,
        )
        if self.update_style == "res_residual":
            self.n_residual.append(
@@ -950,6 +967,7 @@ def __init__(
                    self.update_residual_init,
                    precision=precision,
                    seed=child_seed(seed, 3),
+                    trainable=trainable,
                )
            )

@@ -959,6 +977,7 @@ def __init__(
            self.n_multi_edge_message * n_dim,
            precision=precision,
            seed=child_seed(seed, 4),
+            trainable=trainable,
        )
        if self.update_style == "res_residual":
            for head_index in range(self.n_multi_edge_message):
@@ -969,6 +988,7 @@ def __init__(
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(child_seed(seed, 5), head_index),
+                        trainable=trainable,
                    )
                )

@@ -978,6 +998,7 @@ def __init__(
            e_dim,
            precision=precision,
            seed=child_seed(seed, 6),
+            trainable=trainable,
        )
        if self.update_style == "res_residual":
            self.e_residual.append(
@@ -987,6 +1008,7 @@ def __init__(
                    self.update_residual_init,
                    precision=precision,
                    seed=child_seed(seed, 7),
+                    trainable=trainable,
                )
            )

@@ -1015,13 +1037,15 @@ def __init__(
                    precision=precision,
                    bias=False,
                    seed=child_seed(seed, 8),
+                    trainable=trainable,
                )
                self.a_compress_e_linear = NativeLayer(
                    self.e_dim,
                    self.e_a_compress_dim,
                    precision=precision,
                    bias=False,
                    seed=child_seed(seed, 9),
+                    trainable=trainable,
                )
            else:
                self.a_compress_n_linear = None
@@ -1033,12 +1057,14 @@ def __init__(
                self.e_dim,
                precision=precision,
                seed=child_seed(seed, 10),
+                trainable=trainable,
            )
            self.edge_angle_linear2 = NativeLayer(
                self.e_dim,
                self.e_dim,
                precision=precision,
                seed=child_seed(seed, 11),
+                trainable=trainable,
            )
            if self.update_style == "res_residual":
                self.e_residual.append(
@@ -1048,6 +1074,7 @@ def __init__(
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(seed, 12),
+                        trainable=trainable,
                    )
                )

@@ -1057,6 +1084,7 @@ def __init__(
                self.a_dim,
                precision=precision,
                seed=child_seed(seed, 13),
+                trainable=trainable,
            )
            if self.update_style == "res_residual":
                self.a_residual.append(
@@ -1066,6 +1094,7 @@ def __init__(
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(seed, 14),
+                        trainable=trainable,
                    )
                )
        else: