Skip to content

Commit 527cb85

Browse files
committed
rename and add uts
1 parent e23dc5f commit 527cb85

File tree

9 files changed

+603
-581
lines changed

9 files changed

+603
-581
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def __init__(
1515
a_rcut_smth: float = 3.5,
1616
a_sel: int = 20,
1717
axis_neuron: int = 4,
18-
node_has_conv: bool = False,
1918
update_angle: bool = True,
2019
update_style: str = "res_residual",
2120
update_residual: float = 0.1,
@@ -73,7 +72,6 @@ def __init__(
7372
self.a_rcut_smth = a_rcut_smth
7473
self.a_sel = a_sel
7574
self.axis_neuron = axis_neuron
76-
self.node_has_conv = node_has_conv # tmp
7775
self.update_angle = update_angle
7876
self.update_style = update_style
7977
self.update_residual = update_residual
@@ -98,7 +96,6 @@ def serialize(self) -> dict:
9896
"a_rcut_smth": self.a_rcut_smth,
9997
"a_sel": self.a_sel,
10098
"axis_neuron": self.axis_neuron,
101-
"node_has_conv": self.node_has_conv, # tmp
10299
"update_angle": self.update_angle,
103100
"update_style": self.update_style,
104101
"update_residual": self.update_residual,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def init_subclass_params(sub_data, sub_class):
152152
e_dim=self.repflow_args.e_dim,
153153
a_dim=self.repflow_args.a_dim,
154154
axis_neuron=self.repflow_args.axis_neuron,
155-
node_has_conv=self.repflow_args.node_has_conv,
156155
update_angle=self.repflow_args.update_angle,
157156
activation_function=self.activation_function,
158157
update_style=self.repflow_args.update_style,
@@ -299,7 +298,7 @@ def change_type_map(
299298
extend_descrpt_stat(
300299
repflow,
301300
type_map,
302-
des_with_stat=model_with_new_type_stat.repflow
301+
des_with_stat=model_with_new_type_stat.repflows
303302
if model_with_new_type_stat is not None
304303
else None,
305304
)
@@ -380,6 +379,7 @@ def serialize(self) -> dict:
380379
}
381380
repflow_variable = {
382381
"edge_embd": repflows.edge_embd.serialize(),
382+
"angle_embd": repflows.angle_embd.serialize(),
383383
"repflow_layers": [layer.serialize() for layer in repflows.layers],
384384
"env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(),
385385
"@variables": {
@@ -417,6 +417,9 @@ def t_cvt(xx):
417417
env_mat = repflow_variable.pop("env_mat")
418418
repflow_layers = repflow_variable.pop("repflow_layers")
419419
obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd"))
420+
obj.repflows.angle_embd = MLPLayer.deserialize(
421+
repflow_variable.pop("angle_embd")
422+
)
420423
obj.repflows["davg"] = t_cvt(statistic_repflows["davg"])
421424
obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"])
422425
obj.repflows.layers = torch.nn.ModuleList(
@@ -449,12 +452,12 @@ def forward(
449452
450453
Returns
451454
-------
452-
node_embd
455+
node_ebd
453456
The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim)
454457
rot_mat
455458
The rotationally equivariant and permutationally invariant single particle
456459
representation. shape: nf x nloc x e_dim x 3
457-
edge_embd
460+
edge_ebd
458461
The edge embedding.
459462
shape: nf x nloc x nnei x e_dim
460463
h2
@@ -469,23 +472,23 @@ def forward(
469472
nframes, nloc, nnei = nlist.shape
470473
nall = extended_coord.view(nframes, -1).shape[1] // 3
471474

472-
node_embd_ext = self.type_embedding(extended_atype)
473-
node_embd_inp = node_embd_ext[:, :nloc, :]
475+
node_ebd_ext = self.type_embedding(extended_atype)
476+
node_ebd_inp = node_ebd_ext[:, :nloc, :]
474477
# repflows
475-
node_embd, edge_embd, h2, rot_mat, sw = self.repflows(
478+
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
476479
nlist,
477480
extended_coord,
478481
extended_atype,
479-
node_embd_ext,
482+
node_ebd_ext,
480483
mapping,
481484
comm_dict=comm_dict,
482485
)
483486
if self.concat_output_tebd:
484-
node_embd = torch.cat([node_embd, node_embd_inp], dim=-1)
487+
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
485488
return (
486-
node_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
489+
node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
487490
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
488-
edge_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
491+
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
489492
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
490493
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
491494
)

0 commit comments

Comments
 (0)