Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
Whether to use electronic configuration type embedding.
use_tebd_bias : bool, Optional
Whether to use bias in the type embedding layer.
use_loc_mapping : bool, Optional
Whether to use local atom index mapping in non-parallel inference.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
"""
Expand All @@ -290,6 +292,7 @@ def __init__(
seed: Optional[Union[int, list[int]]] = None,
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: Optional[list[str]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -335,6 +338,7 @@ def init_subclass_params(sub_data, sub_class):
use_exp_switch=self.repflow_args.use_exp_switch,
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
use_loc_mapping=use_loc_mapping,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand All @@ -343,6 +347,7 @@ def init_subclass_params(sub_data, sub_class):

self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.use_loc_mapping = use_loc_mapping
self.type_map = type_map
self.tebd_dim = self.repflow_args.n_dim
self.type_embedding = TypeEmbedNet(
Expand Down Expand Up @@ -541,10 +546,16 @@ def call(
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3

type_embedding = self.type_embedding.call()
node_ebd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
if self.use_loc_mapping:
node_ebd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
(nframes, nloc, self.tebd_dim),
)
else:
node_ebd_ext = xp.reshape(
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand Down
20 changes: 18 additions & 2 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
accommodating larger selection numbers.
use_loc_mapping : bool, optional
Whether to use local atom index mapping in non-parallel inference.
ntypes : int
Number of element types
activation_function : str, optional
Expand Down Expand Up @@ -196,6 +198,7 @@ def __init__(
use_exp_switch: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
use_loc_mapping: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -229,6 +232,7 @@ def __init__(
self.smooth_edge_update = smooth_edge_update
self.use_exp_switch = use_exp_switch
self.use_dynamic_sel = use_dynamic_sel
self.use_loc_mapping = use_loc_mapping
self.sel_reduce_factor = sel_reduce_factor
if self.use_dynamic_sel and not self.smooth_edge_update:
raise NotImplementedError(
Expand Down Expand Up @@ -527,10 +531,18 @@ def call(
cosine_ij, (nframes, nloc, self.a_sel, self.a_sel, 1)
) / (xp.pi**0.5)

if self.use_loc_mapping:
assert mapping is not None
flat_map = xp.reshape(mapping, (nframes, -1))
nlist = xp.reshape(
xp_take_along_axis(flat_map, xp.reshape(nlist, (nframes, -1)), axis=1),
nlist.shape,
)

if self.use_dynamic_sel:
# get graph index
edge_index, angle_index = get_graph_index(
nlist, nlist_mask, a_nlist_mask, nall
nlist, nlist_mask, a_nlist_mask, nall, use_loc_mapping=self.use_loc_mapping
)
# flat all the tensors
# n_edge x 1
Expand Down Expand Up @@ -561,7 +573,11 @@ def call(
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
node_ebd_ext = xp_take_along_axis(node_ebd, mapping, axis=1)
node_ebd_ext = (
node_ebd
if self.use_loc_mapping
else xp_take_along_axis(node_ebd, mapping, axis=1)
)
node_ebd, edge_ebd, angle_ebd = ll.call(
node_ebd_ext,
edge_ebd,
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,7 @@ def get_graph_index(
nlist_mask: np.ndarray,
a_nlist_mask: np.ndarray,
nall: int,
use_loc_mapping: bool = True,
):
"""
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
Expand All @@ -1020,6 +1021,8 @@ def get_graph_index(
Masks of the neighbor list for angle. real nei 1 otherwise 0
nall
The number of extended atoms.
use_loc_mapping
Whether to use local atom index mapping in non-parallel inference.

Returns
-------
Expand Down Expand Up @@ -1060,7 +1063,7 @@ def get_graph_index(
n2e_index = n2e_index[xp.astype(nlist_mask, xp.bool)]

# node_ext(j) to edge(ij) index_select
frame_shift = xp.arange(nf, dtype=nlist.dtype) * nall
frame_shift = xp.arange(nf, dtype=nlist.dtype) * (nall if not use_loc_mapping else nloc)
shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis]
# n_edge
n_ext2e_index = shifted_nlist[xp.astype(nlist_mask, xp.bool)]
Expand Down
11 changes: 10 additions & 1 deletion deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
Whether to use electronic configuration type embedding.
use_tebd_bias : bool, Optional
Whether to use bias in the type embedding layer.
use_loc_mapping : bool, Optional
Whether to use local atom index mapping in non-parallel inference.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
"""
Expand All @@ -108,6 +110,7 @@ def __init__(
seed: Optional[Union[int, list[int]]] = None,
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: Optional[list[str]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -153,13 +156,15 @@ def init_subclass_params(sub_data, sub_class):
use_exp_switch=self.repflow_args.use_exp_switch,
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
use_loc_mapping=use_loc_mapping,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
seed=child_seed(seed, 1),
)

self.use_econf_tebd = use_econf_tebd
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
self.tebd_dim = self.repflow_args.n_dim
Expand Down Expand Up @@ -469,12 +474,16 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
parallel_mode = comm_dict is not None
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3

node_ebd_ext = self.type_embedding(extended_atype)
if not parallel_mode and self.use_loc_mapping:
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
else:
node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def optim_edge_update_dynamic(

def forward(
self,
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
h2: torch.Tensor, # nf x nloc x nnei x 3
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim
Expand Down
42 changes: 29 additions & 13 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(
use_exp_switch: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
use_loc_mapping: bool = True,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
Expand Down Expand Up @@ -239,6 +240,7 @@ def __init__(
self.fix_stat_std = fix_stat_std
self.set_stddev_constant = fix_stat_std != 0.0
self.a_compress_use_split = a_compress_use_split
self.use_loc_mapping = use_loc_mapping
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_exp_switch = use_exp_switch
Expand Down Expand Up @@ -416,9 +418,9 @@ def forward(
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
parallel_mode = comm_dict is not None
if not parallel_mode:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
Expand Down Expand Up @@ -470,12 +472,9 @@ def forward(

# get node embedding
# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
else:
atype_embd = extended_atype_embd
assert extended_atype_embd is not None
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
assert isinstance(atype_embd, torch.Tensor) # for jit
node_ebd = self.act(atype_embd)
n_dim = node_ebd.shape[-1]
Expand All @@ -494,10 +493,22 @@ def forward(
cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)

if not parallel_mode and self.use_loc_mapping:
assert mapping is not None
# convert nlist from nall to nloc index
nlist = torch.gather(
mapping,
1,
index=nlist.reshape(nframes, -1),
).reshape(nlist.shape)
if self.use_dynamic_sel:
# get graph index
edge_index, angle_index = get_graph_index(
nlist, nlist_mask, a_nlist_mask, nall
nlist,
nlist_mask,
a_nlist_mask,
nall,
use_loc_mapping=self.use_loc_mapping,
)
# flat all the tensors
# n_edge x 1
Expand All @@ -524,18 +535,23 @@ def forward(
angle_ebd = self.angle_embd(angle_input)

# nb x nall x n_dim
if comm_dict is None:
if not parallel_mode:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
)
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
if comm_dict is None:
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
if not parallel_mode:
assert mapping is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
node_ebd_ext = (
torch.gather(node_ebd, 1, mapping)
if not self.use_loc_mapping
else node_ebd
)
else:
assert comm_dict is not None
has_spin = "has_spin" in comm_dict
if not has_spin:
n_padding = nall - nloc
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_graph_index(
nlist_mask: torch.Tensor,
a_nlist_mask: torch.Tensor,
nall: int,
use_loc_mapping: bool = True,
):
"""
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
Expand Down Expand Up @@ -100,7 +101,9 @@ def get_graph_index(
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]

# node_ext(j) to edge(ij) index_select
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * nall
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * (
nall if not use_loc_mapping else nloc
)
shifted_nlist = nlist + frame_shift[:, None, None]
# n_edge
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
Expand Down
8 changes: 8 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,14 @@ def descrpt_dpa3_args():
default=False,
doc=doc_use_tebd_bias,
),
Argument(
"use_loc_mapping",
bool,
optional=True,
default=True,
doc="Whether to use local atom index mapping in non-parallel inference. "
"When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.",
),
]


Expand Down
Loading
Loading