2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/repflows.py
@@ -171,6 +171,7 @@ def __init__(
optim_update: bool = True,
smooth_edge_update: bool = False,
seed: Optional[Union[int, list[int]]] = None,
use_ext_ebd: bool = False,
) -> None:
super().__init__()
self.e_rcut = float(e_rcut)
@@ -201,6 +202,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_ext_ebd = use_ext_ebd

self.n_dim = n_dim
self.e_dim = e_dim
21 changes: 14 additions & 7 deletions deepmd/pt/model/descriptor/dpa3.py
@@ -91,6 +91,8 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
Whether to use bias in the type embedding layer.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
use_ext_ebd : bool
Whether to build and use the extended node embedding (covering both local and ghost atoms) in the repflow layers.
"""

def __init__(
@@ -109,6 +111,7 @@ def __init__(
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[list[str]] = None,
use_ext_ebd: bool = False,
) -> None:
super().__init__()

@@ -121,7 +124,7 @@ def init_subclass_params(sub_data, sub_class):
raise ValueError(
f"Input args must be a {sub_class.__name__} class or a dict!"
)

self.use_ext_ebd = use_ext_ebd
self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
self.activation_function = activation_function

@@ -154,6 +157,7 @@ def init_subclass_params(sub_data, sub_class):
env_protection=env_protection,
precision=precision,
seed=child_seed(seed, 1),
use_ext_ebd=use_ext_ebd,
)

self.use_econf_tebd = use_econf_tebd
@@ -375,6 +379,7 @@ def serialize(self) -> dict:
"use_tebd_bias": self.use_tebd_bias,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
"use_ext_ebd": self.use_ext_ebd,
}
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
@@ -469,16 +474,18 @@ def forward(
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3

node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# nall = extended_coord.view(nframes, -1).shape[1] // 3
if comm_dict is None or self.use_ext_ebd:
atype = extended_atype[:, :nloc]
else:
atype = extended_atype
node_ebd_inp = self.type_embedding(atype)
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward(
nlist,
extended_coord,
extended_atype,
node_ebd_ext,
node_ebd_inp,
mapping,
comm_dict=comm_dict,
)
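For orientation, here is a minimal standalone sketch of the branch added to `forward` above. All sizes and the two-type embedding table are invented for illustration: with no communication dict, or with `use_ext_ebd` enabled, only the `nloc` local atoms go through the type embedding; otherwise all `nall` extended atoms are embedded.

```python
# Toy sketch of the local-vs-extended type-embedding selection; nframes,
# nloc, nall, n_dim and the 2-type embedding table are made up.
import torch

nframes, nloc, nall, n_dim = 2, 3, 5, 4
extended_atype = torch.randint(0, 2, (nframes, nall))
type_embedding = torch.nn.Embedding(2, n_dim)

comm_dict = None       # stand-in for the optional communication dict
use_ext_ebd = False

if comm_dict is None or use_ext_ebd:
    atype = extended_atype[:, :nloc]   # embed local atoms only
else:
    atype = extended_atype             # embed local + ghost atoms
node_ebd_inp = type_embedding(atype)   # (nframes, nloc, n_dim) in this toy case
```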
26 changes: 11 additions & 15 deletions deepmd/pt/model/descriptor/repflow_layer.py
Collaborator:
Maybe you need to abort the modifications in this file, which means keeping exactly the modifications from commit iProzd@28803d9; I passed all the UTs with those.

Member Author:
> Maybe you need to abort modifications in this file

Did you mean you'll submit another PR to this branch?

@@ -435,9 +435,8 @@ def optim_angle_update(
def optim_edge_update(
self,
node_ebd: torch.Tensor,
node_ebd_ext: torch.Tensor,
nei_node_ebd: torch.Tensor,
edge_ebd: torch.Tensor,
nlist: torch.Tensor,
feat: str = "node",
) -> torch.Tensor:
if feat == "node":
@@ -455,10 +454,8 @@ def optim_edge_update(

# nf * nloc * node/edge_dim
sub_node_update = torch.matmul(node_ebd, node)
# nf * nall * node/edge_dim
sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
# nf * nloc * nnei * node/edge_dim
sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
# nf * nloc * node/edge_dim
sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext)
Collaborator:
You missed a _make_nei_g1 here. But when you use local mapping, it's more efficient to keep the old implementation.

# nf * nloc * nnei * node/edge_dim
sub_edge_update = torch.matmul(edge_ebd, edge)

@@ -469,7 +466,8 @@ def forward(

def forward(
self,
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
node_ebd: torch.Tensor, # nf x nloc x n_dim
node_ebd_ext: Optional[torch.Tensor], # nf x nall x n_dim
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
h2: torch.Tensor, # nf x nloc x nnei x 3
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim
@@ -514,8 +512,6 @@ def forward(
Updated angle embedding.
"""
nb, nloc, nnei, _ = edge_ebd.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
assert (nb, nloc) == node_ebd.shape[:2]
assert (nb, nloc, nnei) == h2.shape[:3]
del a_nlist # may be used in the future
@@ -527,8 +523,10 @@ def forward(
# node self mlp
node_self_mlp = self.act(self.node_self_mlp(node_ebd))
n_update_list.append(node_self_mlp)

nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
if node_ebd_ext is not None:
nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
else:
nei_node_ebd = _make_nei_g1(node_ebd, nlist)

# node sym (grrg + drrd)
node_sym_list: list[torch.Tensor] = []
@@ -577,9 +575,8 @@ def forward(
node_edge_update = self.act(
self.optim_edge_update(
node_ebd,
node_ebd_ext,
nei_node_ebd,
edge_ebd,
nlist,
"node",
)
) * sw.unsqueeze(-1)
@@ -605,9 +602,8 @@ def forward(
edge_self_update = self.act(
self.optim_edge_update(
node_ebd,
node_ebd_ext,
nei_node_ebd,
edge_ebd,
nlist,
"edge",
)
)
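On the `_make_nei_g1` remark above: judging from the shape comments in this hunk, that helper gathers per-neighbor node features out of an atom-indexed embedding using `nlist`, turning an (nf, nall, dim) tensor into (nf, nloc, nnei, dim). A rough standalone sketch of that gather, with invented sizes:

```python
# Sketch of the neighbor gather that _make_nei_g1 appears to perform
# (inferred from the shape comments in this diff); sizes below are made up.
import torch

def gather_nei_node(node_ebd_ext: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor:
    # node_ebd_ext: (nf, nall, dim); nlist: (nf, nloc, nnei), indices in [0, nall)
    nf, nloc, nnei = nlist.shape
    dim = node_ebd_ext.shape[-1]
    index = nlist.reshape(nf, nloc * nnei, 1).expand(-1, -1, dim)
    # each neighbor slot receives the embedding of the atom it points to
    return torch.gather(node_ebd_ext, 1, index).view(nf, nloc, nnei, dim)

node_ebd_ext = torch.randn(1, 6, 8)
nlist = torch.randint(0, 6, (1, 2, 4))
nei_node_ebd = gather_nei_node(node_ebd_ext, nlist)   # (1, 2, 4, 8)
```

If that reading is right, the efficiency point follows: the old path multiplies the (nf, nall, dim) embedding by the weight and gathers afterwards, touching nall rows, whereas gathering first makes the matmul run over nloc * nnei rows, which is typically much larger.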
31 changes: 18 additions & 13 deletions deepmd/pt/model/descriptor/repflows.py
@@ -185,6 +185,7 @@ def __init__(
smooth_edge_update: bool = False,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
use_ext_ebd: bool = False,
) -> None:
super().__init__()
self.e_rcut = float(e_rcut)
@@ -215,6 +216,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_ext_ebd = use_ext_ebd

self.n_dim = n_dim
self.e_dim = e_dim
@@ -373,13 +375,10 @@ def forward(
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
extended_atype_embd: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
@@ -402,14 +401,9 @@ def forward(
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)

atype_embd = extended_atype_embd
# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
node_ebd = self.act(atype_embd)
n_dim = node_ebd.shape[-1]
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
@@ -459,18 +453,28 @@ def forward(
# if the a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x n_dim
if comm_dict is None:
if self.use_ext_ebd:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
)
if comm_dict is None or self.use_ext_ebd:
assert mapping is not None
node_ebd_ext = None
nlist = torch.gather(
mapping,
1,
nlist.reshape(nframes, -1),
).reshape(nlist.shape)
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
if comm_dict is None:
if self.use_ext_ebd or comm_dict is not None:
assert mapping is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
else:
node_ebd_ext = None
if comm_dict is not None:
has_spin = "has_spin" in comm_dict
if not has_spin:
n_padding = nall - nloc
@@ -528,6 +532,7 @@ def forward(
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
)
node_ebd, edge_ebd, angle_ebd = ll.forward(
node_ebd,
node_ebd_ext,
edge_ebd,
h2,
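A toy illustration of the `mapping`-based gather used in the layer loop above (sizes and values invented): `mapping` has shape (nframes, nall) and sends every extended atom, ghosts included, to the index of the local atom it is an image of, so gathering the local node embedding along dimension 1 reproduces an nall-sized embedding without a communication step.

```python
# Toy example of rebuilding node_ebd_ext from node_ebd via mapping;
# all sizes and values are made up.
import torch

nf, nloc, nall, n_dim = 1, 2, 4, 3
node_ebd = torch.randn(nf, nloc, n_dim)      # embeddings of the local atoms
mapping = torch.tensor([[0, 1, 0, 1]])       # ghosts 2 and 3 are images of locals 0 and 1
mapping_exp = mapping.unsqueeze(-1).expand(-1, -1, n_dim)  # (nf, nall, n_dim)
node_ebd_ext = torch.gather(node_ebd, 1, mapping_exp)      # (nf, nall, n_dim)
assert torch.equal(node_ebd_ext[:, 2], node_ebd[:, 0])     # ghost 2 copies local 0
```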
1 change: 1 addition & 0 deletions deepmd/utils/argcheck.py
@@ -1421,6 +1421,7 @@ def descrpt_dpa3_args():
default=False,
doc=doc_use_tebd_bias,
),
Argument("use_ext_ebd", bool, optional=True, default=False),
]


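For completeness, a hypothetical input fragment showing where the new flag would sit; only `use_ext_ebd` comes from this PR, and the other keys are placeholders standing in for the existing dpa3 descriptor arguments rather than a complete configuration.

```python
# Hypothetical descriptor section of a model input; keys other than
# "use_ext_ebd" are placeholders for the existing dpa3 arguments.
descriptor = {
    "type": "dpa3",
    "repflow": {},            # RepFlowArgs sub-section, unchanged by this PR
    "use_tebd_bias": False,
    "use_ext_ebd": False,     # new optional flag, defaults to False
}
```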
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_dpa3.py
@@ -55,6 +55,7 @@ def test_consistency(
nme,
prec,
ect,
use_ext_ebd,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
@@ -65,6 +66,7 @@ def test_consistency(
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
[False, True], # use_ext_ebd
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
@@ -103,6 +105,7 @@ def test_consistency(
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
seed=GLOBAL_SEED,
use_ext_ebd=use_ext_ebd,
).to(env.DEVICE)

dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)