7 changes: 7 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
@@ -150,6 +150,7 @@ def __init__(
skip_stat: bool = False,
optim_update: bool = True,
smooth_edge_update: bool = False,
use_loc_mapping: bool = True,
) -> None:
self.n_dim = n_dim
self.e_dim = e_dim
@@ -176,6 +177,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_loc_mapping = use_loc_mapping

def __getitem__(self, key):
if hasattr(self, key):
@@ -207,6 +209,7 @@ def serialize(self) -> dict:
"fix_stat_std": self.fix_stat_std,
"optim_update": self.optim_update,
"smooth_edge_update": self.smooth_edge_update,
"use_loc_mapping": self.use_loc_mapping,
}

@classmethod
@@ -262,6 +265,7 @@ def __init__(
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[list[str]] = None,
use_loc_mapping: bool = True,
) -> None:
super().__init__()

@@ -275,6 +279,7 @@ def init_subclass_params(sub_data, sub_class):
f"Input args must be a {sub_class.__name__} class or a dict!"
)

self.use_loc_mapping = use_loc_mapping
self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
self.activation_function = activation_function

@@ -307,6 +312,7 @@ def init_subclass_params(sub_data, sub_class):
env_protection=env_protection,
precision=precision,
seed=child_seed(seed, 1),
use_loc_mapping=use_loc_mapping,
)

self.use_econf_tebd = use_econf_tebd
@@ -544,6 +550,7 @@ def serialize(self) -> dict:
"use_tebd_bias": self.use_tebd_bias,
"type_map": self.type_map,
"type_embedding": self.type_embedding.serialize(),
"use_loc_mapping": self.use_loc_mapping,
}
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
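Editor's note: the pattern above recurs in every file of this PR. The new use_loc_mapping flag enters the constructor with a default of True, is stored on self, and is written into serialize() so saved models round-trip the setting. A minimal stand-alone sketch of that round trip (hypothetical class name, not the deepmd source):

# Editor's sketch: a constructor flag that survives serialize()/deserialize().
class RepFlowArgsSketch:
    def __init__(self, n_dim: int = 128, use_loc_mapping: bool = True) -> None:
        self.n_dim = n_dim
        self.use_loc_mapping = use_loc_mapping

    def serialize(self) -> dict:
        return {"n_dim": self.n_dim, "use_loc_mapping": self.use_loc_mapping}

    @classmethod
    def deserialize(cls, data: dict) -> "RepFlowArgsSketch":
        return cls(**data)

restored = RepFlowArgsSketch.deserialize(
    RepFlowArgsSketch(use_loc_mapping=False).serialize()
)
assert restored.use_loc_mapping is False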
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/repflows.py
@@ -171,6 +171,7 @@ def __init__(
optim_update: bool = True,
smooth_edge_update: bool = False,
seed: Optional[Union[int, list[int]]] = None,
use_loc_mapping: bool = True,
) -> None:
super().__init__()
self.e_rcut = float(e_rcut)
@@ -201,6 +202,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_loc_mapping = use_loc_mapping

self.n_dim = n_dim
self.e_dim = e_dim
@@ -583,6 +585,7 @@ def serialize(self):
"repflow_layers": [layer.serialize() for layer in self.layers],
"env_mat_edge": self.env_mat_edge.serialize(),
"env_mat_angle": self.env_mat_angle.serialize(),
"use_loc_mapping": self.use_loc_mapping,
"@variables": {
"davg": to_numpy_array(self["davg"]),
"dstd": to_numpy_array(self["dstd"]),
22 changes: 16 additions & 6 deletions deepmd/pt/model/descriptor/dpa3.py
@@ -91,6 +91,8 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
Whether to use bias in the type embedding layer.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
use_loc_mapping : bool
Whether to map neighbor-list indices onto local atom indices when no communication dict is passed (training and single-rank inference).
"""

def __init__(
@@ -109,6 +111,7 @@ def __init__(
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[list[str]] = None,
use_loc_mapping: bool = True,
) -> None:
super().__init__()

@@ -122,6 +125,7 @@ def init_subclass_params(sub_data, sub_class):
f"Input args must be a {sub_class.__name__} class or a dict!"
)

self.use_loc_mapping = use_loc_mapping
self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
self.activation_function = activation_function

@@ -150,13 +154,15 @@ def init_subclass_params(sub_data, sub_class):
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
use_loc_mapping=use_loc_mapping,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
seed=child_seed(seed, 1),
)

self.use_econf_tebd = use_econf_tebd
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
self.tebd_dim = self.repflow_args.n_dim
@@ -375,6 +381,7 @@ def serialize(self) -> dict:
"use_tebd_bias": self.use_tebd_bias,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
"use_loc_mapping": self.use_loc_mapping,
}
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
@@ -466,19 +473,22 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
parallel_mode = comm_dict is not None
# cast the input to internal precision
extended_coord = extended_coord.to(dtype=self.prec)
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3

node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# nall = extended_coord.view(nframes, -1).shape[1] // 3
if parallel_mode:
atype = extended_atype
else:
atype = extended_atype[:, :nloc]
node_ebd_inp = self.type_embedding(atype)
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward(
nlist,
extended_coord,
extended_atype,
node_ebd_ext,
node_ebd_inp,
mapping,
comm_dict=comm_dict,
)
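Editor's note: this forward() change is the core of the feature. When comm_dict is None (no multi-rank communication), the type embedding is computed for the nloc local atoms only, instead of for all nall extended atoms as before. A small shape sketch of that branch, with hypothetical sizes and standard torch calls only:

import torch

# Editor's sketch of the parallel_mode branch (hypothetical sizes).
nframes, nloc, nall, tebd_dim, ntypes = 2, 5, 9, 8, 4
extended_atype = torch.randint(0, ntypes, (nframes, nall))
type_embedding = torch.nn.Embedding(ntypes, tebd_dim)

comm_dict = None                       # stand-in for the optional communication dict
parallel_mode = comm_dict is not None
atype = extended_atype if parallel_mode else extended_atype[:, :nloc]
node_ebd_inp = type_embedding(atype)   # nf x nloc (or nall) x tebd_dim
assert node_ebd_inp.shape[1] == (nall if parallel_mode else nloc)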
26 changes: 11 additions & 15 deletions deepmd/pt/model/descriptor/repflow_layer.py
Collaborator:
Maybe you need to abort the modifications in this file, i.e. keep exactly the modifications in commit iProzd@28803d9; with those I passed all the UTs.

Member (Author):
> Maybe you need to abort modifications in this file

Did you mean you'll submit another PR to this branch?
@@ -435,9 +435,8 @@ def optim_angle_update(
def optim_edge_update(
self,
node_ebd: torch.Tensor,
node_ebd_ext: torch.Tensor,
nei_node_ebd: torch.Tensor,
edge_ebd: torch.Tensor,
nlist: torch.Tensor,
feat: str = "node",
) -> torch.Tensor:
if feat == "node":
@@ -455,10 +454,8 @@ def optim_edge_update(

# nf * nloc * node/edge_dim
sub_node_update = torch.matmul(node_ebd, node)
# nf * nall * node/edge_dim
sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
# nf * nloc * nnei * node/edge_dim
sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
# nf * nloc * nnei * node/edge_dim
sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext)
Collaborator:
You missed a _make_nei_g1 here. But when you use local mapping, it's more efficient to keep the old implementation.

# nf * nloc * nnei * node/edge_dim
sub_edge_update = torch.matmul(edge_ebd, edge)

@@ -469,7 +466,8 @@ def forward(

def forward(
self,
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
node_ebd: torch.Tensor, # nf x nloc x n_dim
node_ebd_ext: Optional[torch.Tensor], # nf x nall x n_dim
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
h2: torch.Tensor, # nf x nloc x nnei x 3
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim
@@ -514,8 +512,6 @@ def forward(
Updated angle embedding.
"""
nb, nloc, nnei, _ = edge_ebd.shape
nall = node_ebd_ext.shape[1]
node_ebd = node_ebd_ext[:, :nloc, :]
assert (nb, nloc) == node_ebd.shape[:2]
assert (nb, nloc, nnei) == h2.shape[:3]
del a_nlist # may be used in the future
@@ -527,8 +523,10 @@ def forward(
# node self mlp
node_self_mlp = self.act(self.node_self_mlp(node_ebd))
n_update_list.append(node_self_mlp)

nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
if node_ebd_ext is not None:
nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
else:
nei_node_ebd = _make_nei_g1(node_ebd, nlist)

# node sym (grrg + drrd)
node_sym_list: list[torch.Tensor] = []
@@ -577,9 +575,8 @@ def forward(
node_edge_update = self.act(
self.optim_edge_update(
node_ebd,
node_ebd_ext,
nei_node_ebd,
edge_ebd,
nlist,
"node",
)
) * sw.unsqueeze(-1)
@@ -605,9 +602,8 @@ def forward(
edge_self_update = self.act(
self.optim_edge_update(
node_ebd,
node_ebd_ext,
nei_node_ebd,
edge_ebd,
nlist,
"edge",
)
)
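Editor's note: both review comments in this file hinge on _make_nei_g1, which turns per-atom node embeddings into per-neighbor ones by indexing with the neighbor list. A reconstruction of that gather under the shapes given in the inline comments (an assumed implementation for illustration; the real helper is the one this file imports):

import torch

def make_nei_g1_sketch(node_ebd: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor:
    # node_ebd: nf x nall x d, nlist: nf x nloc x nnei -> nf x nloc x nnei x d
    nf, nloc, nnei = nlist.shape
    d = node_ebd.shape[-1]
    index = nlist.reshape(nf, nloc * nnei, 1).expand(-1, -1, d)
    return torch.gather(node_ebd, 1, index).reshape(nf, nloc, nnei, d)

node_ebd = torch.randn(1, 6, 4)
nlist = torch.tensor([[[3, 5], [0, 4], [1, 3]]])
assert make_nei_g1_sketch(node_ebd, nlist).shape == (1, 3, 2, 4)

On the efficiency remark: projecting first runs the matmul over nall (or, with local mapping, nloc) rows and gathers afterwards, while gathering first projects nloc * nnei rows, which is usually far more work; that is presumably why the reviewer prefers keeping the old project-then-gather order once local mapping makes the extended tensor unnecessary.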
32 changes: 18 additions & 14 deletions deepmd/pt/model/descriptor/repflows.py
@@ -183,6 +183,7 @@ def __init__(
precision: str = "float64",
fix_stat_std: float = 0.3,
smooth_edge_update: bool = False,
use_loc_mapping: bool = True,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
@@ -215,6 +216,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_loc_mapping = use_loc_mapping

self.n_dim = n_dim
self.e_dim = e_dim
@@ -373,13 +375,10 @@ def forward(
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
extended_atype_embd: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
if comm_dict is None:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
@@ -402,14 +401,9 @@ def forward(
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)

atype_embd = extended_atype_embd
# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
node_ebd = self.act(atype_embd)
n_dim = node_ebd.shape[-1]
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
@@ -459,17 +453,26 @@ def forward(
# whether a neighbor is real is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x n_dim
if comm_dict is None:
if comm_dict is None or self.use_loc_mapping:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
)
node_ebd_ext = None
nlist = torch.gather(
mapping.reshape(nframes, -1),
1,
nlist.reshape(nframes, -1),
).reshape(nlist.shape)
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
if comm_dict is None:
assert mapping is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
if self.use_loc_mapping:
node_ebd_ext = None
else:
assert mapping is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
else:
has_spin = "has_spin" in comm_dict
if not has_spin:
@@ -528,6 +531,7 @@ def forward(
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
)
node_ebd, edge_ebd, angle_ebd = ll.forward(
node_ebd,
node_ebd_ext,
edge_ebd,
h2,
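Editor's note: when use_loc_mapping applies, the block above rewrites nlist in place. Every extended-atom index is passed through mapping so the list points directly at local atoms, which is what lets node_ebd_ext stay None in the layer loop. A worked example of that gather with made-up indices:

import torch

# mapping[f, j] gives the local atom that extended atom j is a copy of.
nframes = 1
mapping = torch.tensor([[0, 1, 2, 0, 1, 2]])      # nf x nall
nlist = torch.tensor([[[3, 5], [0, 4], [1, 3]]])  # nf x nloc x nnei, extended indices

nlist_loc = torch.gather(
    mapping.reshape(nframes, -1),
    1,
    nlist.reshape(nframes, -1),
).reshape(nlist.shape)
print(nlist_loc)  # tensor([[[0, 2], [0, 1], [1, 0]]]) -- local indices now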
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
@@ -1421,6 +1421,12 @@ def descrpt_dpa3_args():
default=False,
doc=doc_use_tebd_bias,
),
Argument(
"use_loc_mapping",
bool,
optional=True,
default=True,
Collaborator:
Add a docstring for use_loc_mapping.

),
]


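Editor's note, following up on that review comment: a documented version of the argument would mirror the doc_use_tebd_bias pattern visible above. A sketch; the doc wording is the editor's assumption about the flag's semantics, not text from the PR:

# Assumed doc text -- not from the PR; mirrors the doc_* convention in argcheck.py.
doc_use_loc_mapping = (
    "Whether to map neighbor-list indices onto local atom indices when no "
    "communication dict is passed (training and single-rank inference)."
)

Argument(
    "use_loc_mapping",
    bool,
    optional=True,
    default=True,
    doc=doc_use_loc_mapping,
),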