Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
seed: Optional[Union[int, list[int]]] = None,
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: Optional[list[str]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -153,13 +154,15 @@ def init_subclass_params(sub_data, sub_class):
use_exp_switch=self.repflow_args.use_exp_switch,
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
use_loc_mapping=use_loc_mapping,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
seed=child_seed(seed, 1),
)

self.use_econf_tebd = use_econf_tebd
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
self.tebd_dim = self.repflow_args.n_dim
Expand Down Expand Up @@ -469,12 +472,16 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
parrallel_mode = comm_dict is not None
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3

node_ebd_ext = self.type_embedding(extended_atype)
if not parrallel_mode and self.use_loc_mapping:
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
else:
node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def optim_edge_update_dynamic(

def forward(
self,
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parrallel_mode
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
h2: torch.Tensor, # nf x nloc x nnei x 3
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim
Expand Down
41 changes: 28 additions & 13 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(
use_exp_switch: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
use_loc_mapping: bool = True,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
Expand Down Expand Up @@ -416,9 +417,9 @@ def forward(
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
parrallel_mode = comm_dict is not None
if not parrallel_mode:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
Expand Down Expand Up @@ -470,12 +471,9 @@ def forward(

# get node embedding
# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
else:
atype_embd = extended_atype_embd
assert extended_atype_embd is not None
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
assert isinstance(atype_embd, torch.Tensor) # for jit
node_ebd = self.act(atype_embd)
n_dim = node_ebd.shape[-1]
Expand All @@ -494,10 +492,22 @@ def forward(
cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)

if not parrallel_mode and self.use_loc_mapping:
assert mapping is not None
# convert nlist from nall to nloc index
nlist = torch.gather(
mapping,
1,
index=nlist.reshape(nframes, -1),
).reshape(nlist.shape)
if self.use_dynamic_sel:
# get graph index
edge_index, angle_index = get_graph_index(
nlist, nlist_mask, a_nlist_mask, nall
nlist,
nlist_mask,
a_nlist_mask,
nall,
use_loc_mapping=self.use_loc_mapping,
)
# flat all the tensors
# n_edge x 1
Expand All @@ -524,18 +534,23 @@ def forward(
angle_ebd = self.angle_embd(angle_input)

# nb x nall x n_dim
if comm_dict is None:
if not parrallel_mode:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
)
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
if comm_dict is None:
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
if not parrallel_mode:
assert mapping is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
node_ebd_ext = (
torch.gather(node_ebd, 1, mapping)
if not self.use_loc_mapping
else node_ebd
)
else:
assert comm_dict is not None
has_spin = "has_spin" in comm_dict
if not has_spin:
n_padding = nall - nloc
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_graph_index(
nlist_mask: torch.Tensor,
a_nlist_mask: torch.Tensor,
nall: int,
use_loc_mapping: bool = True,
):
"""
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
Expand Down Expand Up @@ -100,7 +101,9 @@ def get_graph_index(
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]

# node_ext(j) to edge(ij) index_select
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * nall
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * (
nall if not use_loc_mapping else nloc
)
shifted_nlist = nlist + frame_shift[:, None, None]
# n_edge
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
Expand Down
8 changes: 8 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,14 @@ def descrpt_dpa3_args():
default=False,
doc=doc_use_tebd_bias,
),
Argument(
"use_loc_mapping",
bool,
optional=True,
default=True,
doc="Whether to use local atom index mapping in non-parallel inference. "
"When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.",
),
]


Expand Down
Loading
Loading