diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py
index 7d3292dff2..58c0997f64 100644
--- a/deepmd/dpmodel/descriptor/dpa3.py
+++ b/deepmd/dpmodel/descriptor/dpa3.py
@@ -150,6 +150,7 @@ def __init__(
         skip_stat: bool = False,
         optim_update: bool = True,
         smooth_edge_update: bool = False,
+        use_loc_mapping: bool = True,
     ) -> None:
         self.n_dim = n_dim
         self.e_dim = e_dim
@@ -176,6 +177,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.use_loc_mapping = use_loc_mapping
 
     def __getitem__(self, key):
         if hasattr(self, key):
@@ -207,6 +209,7 @@ def serialize(self) -> dict:
             "fix_stat_std": self.fix_stat_std,
             "optim_update": self.optim_update,
             "smooth_edge_update": self.smooth_edge_update,
+            "use_loc_mapping": self.use_loc_mapping,
         }
 
     @classmethod
@@ -262,6 +265,7 @@ def __init__(
         use_econf_tebd: bool = False,
         use_tebd_bias: bool = False,
         type_map: Optional[list[str]] = None,
+        use_loc_mapping: bool = True,
     ) -> None:
         super().__init__()
 
@@ -275,6 +279,7 @@ def init_subclass_params(sub_data, sub_class):
                     f"Input args must be a {sub_class.__name__} class or a dict!"
                 )
 
+        self.use_loc_mapping = use_loc_mapping
         self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
         self.activation_function = activation_function
 
@@ -307,6 +312,7 @@ def init_subclass_params(sub_data, sub_class):
             env_protection=env_protection,
             precision=precision,
             seed=child_seed(seed, 1),
+            use_loc_mapping=use_loc_mapping,
         )
 
         self.use_econf_tebd = use_econf_tebd
@@ -544,6 +550,7 @@ def serialize(self) -> dict:
             "use_tebd_bias": self.use_tebd_bias,
             "type_map": self.type_map,
             "type_embedding": self.type_embedding.serialize(),
+            "use_loc_mapping": self.use_loc_mapping,
         }
         repflow_variable = {
             "edge_embd": repflows.edge_embd.serialize(),
diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py
index 7273ba8ebe..1b583cb30b 100644
--- a/deepmd/dpmodel/descriptor/repflows.py
+++ b/deepmd/dpmodel/descriptor/repflows.py
@@ -171,6 +171,7 @@ def __init__(
         optim_update: bool = True,
         smooth_edge_update: bool = False,
         seed: Optional[Union[int, list[int]]] = None,
+        use_loc_mapping: bool = True,
     ) -> None:
         super().__init__()
         self.e_rcut = float(e_rcut)
@@ -201,6 +202,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.use_loc_mapping = use_loc_mapping
 
         self.n_dim = n_dim
         self.e_dim = e_dim
@@ -583,6 +585,7 @@ def serialize(self):
             "repflow_layers": [layer.serialize() for layer in self.layers],
             "env_mat_edge": self.env_mat_edge.serialize(),
             "env_mat_angle": self.env_mat_angle.serialize(),
+            "use_loc_mapping": self.use_loc_mapping,
             "@variables": {
                 "davg": to_numpy_array(self["davg"]),
                 "dstd": to_numpy_array(self["dstd"]),
diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py
index 545da962e7..794bfcd5ee 100644
--- a/deepmd/pt/model/descriptor/dpa3.py
+++ b/deepmd/pt/model/descriptor/dpa3.py
@@ -91,6 +91,8 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
         Whether to use bias in the type embedding layer.
     type_map : list[str], Optional
         A list of strings. Give the name to each type of atoms.
+    use_loc_mapping : bool
+        Whether to map the neighbor list onto local atom indices when no communication dict is supplied, so that node embeddings are kept at local (nloc) size instead of the extended (nall) size.
     """
 
     def __init__(
@@ -109,6 +111,7 @@ def __init__(
         use_econf_tebd: bool = False,
         use_tebd_bias: bool = False,
         type_map: Optional[list[str]] = None,
+        use_loc_mapping: bool = True,
     ) -> None:
         super().__init__()
 
@@ -122,6 +125,7 @@ def init_subclass_params(sub_data, sub_class):
                     f"Input args must be a {sub_class.__name__} class or a dict!"
                 )
 
+        self.use_loc_mapping = use_loc_mapping
        self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
         self.activation_function = activation_function
 
@@ -150,6 +154,7 @@ def init_subclass_params(sub_data, sub_class):
             fix_stat_std=self.repflow_args.fix_stat_std,
             optim_update=self.repflow_args.optim_update,
             smooth_edge_update=self.repflow_args.smooth_edge_update,
+            use_loc_mapping=use_loc_mapping,
             exclude_types=exclude_types,
             env_protection=env_protection,
             precision=precision,
@@ -157,6 +162,7 @@ def init_subclass_params(sub_data, sub_class):
         )
 
         self.use_econf_tebd = use_econf_tebd
+        self.use_loc_mapping = use_loc_mapping
         self.use_tebd_bias = use_tebd_bias
         self.type_map = type_map
         self.tebd_dim = self.repflow_args.n_dim
@@ -375,6 +381,7 @@ def serialize(self) -> dict:
             "use_tebd_bias": self.use_tebd_bias,
             "type_map": self.type_map,
             "type_embedding": self.type_embedding.embedding.serialize(),
+            "use_loc_mapping": self.use_loc_mapping,
         }
         repflow_variable = {
             "edge_embd": repflows.edge_embd.serialize(),
@@ -466,19 +473,22 @@ def forward(
             The smooth switch function. shape: nf x nloc x nnei
 
         """
+        parallel_mode = comm_dict is not None
         # cast the input to internal precsion
         extended_coord = extended_coord.to(dtype=self.prec)
         nframes, nloc, nnei = nlist.shape
-        nall = extended_coord.view(nframes, -1).shape[1] // 3
-
-        node_ebd_ext = self.type_embedding(extended_atype)
-        node_ebd_inp = node_ebd_ext[:, :nloc, :]
+        # nall = extended_coord.view(nframes, -1).shape[1] // 3
+        if parallel_mode:
+            atype = extended_atype
+        else:
+            atype = extended_atype[:, :nloc]
+        node_ebd_inp = self.type_embedding(atype)
         # repflows
-        node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
+        node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward(
             nlist,
             extended_coord,
             extended_atype,
-            node_ebd_ext,
+            node_ebd_inp,
             mapping,
             comm_dict=comm_dict,
         )
diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py
index f109109cfd..5e8bad7319 100644
--- a/deepmd/pt/model/descriptor/repflow_layer.py
+++ b/deepmd/pt/model/descriptor/repflow_layer.py
@@ -435,9 +435,8 @@ def optim_angle_update(
     def optim_edge_update(
         self,
         node_ebd: torch.Tensor,
-        node_ebd_ext: torch.Tensor,
+        nei_node_ebd: torch.Tensor,
         edge_ebd: torch.Tensor,
-        nlist: torch.Tensor,
         feat: str = "node",
     ) -> torch.Tensor:
         if feat == "node":
@@ -455,10 +454,8 @@
 
         # nf * nloc * node/edge_dim
         sub_node_update = torch.matmul(node_ebd, node)
-        # nf * nall * node/edge_dim
-        sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
-        # nf * nloc * nnei * node/edge_dim
-        sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
+        # nf * nloc * nnei * node/edge_dim
+        sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext)
 
         # nf * nloc * nnei * node/edge_dim
         sub_edge_update = torch.matmul(edge_ebd, edge)
@@ -469,7 +466,8 @@
 
     def forward(
         self,
-        node_ebd_ext: torch.Tensor,  # nf x nall x n_dim
+        node_ebd: torch.Tensor,  # nf x nloc x n_dim
+        node_ebd_ext: Optional[torch.Tensor],  # nf x nall x n_dim
         edge_ebd: torch.Tensor,  # nf x nloc x nnei x e_dim
         h2: torch.Tensor,  # nf x nloc x nnei x 3
         angle_ebd: torch.Tensor,  # nf x nloc x a_nnei x a_nnei x a_dim
@@ -514,8 +512,6 @@ def forward(
             Updated angle embedding.
         """
         nb, nloc, nnei, _ = edge_ebd.shape
-        nall = node_ebd_ext.shape[1]
-        node_ebd = node_ebd_ext[:, :nloc, :]
         assert (nb, nloc) == node_ebd.shape[:2]
         assert (nb, nloc, nnei) == h2.shape[:3]
         del a_nlist  # may be used in the future
@@ -527,8 +523,10 @@ def forward(
         # node self mlp
         node_self_mlp = self.act(self.node_self_mlp(node_ebd))
         n_update_list.append(node_self_mlp)
-
-        nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
+        if node_ebd_ext is not None:
+            nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
+        else:
+            nei_node_ebd = _make_nei_g1(node_ebd, nlist)
 
         # node sym (grrg + drrd)
         node_sym_list: list[torch.Tensor] = []
@@ -577,9 +575,8 @@ def forward(
             node_edge_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
-                    node_ebd_ext,
+                    nei_node_ebd,
                     edge_ebd,
-                    nlist,
                     "node",
                 )
             ) * sw.unsqueeze(-1)
@@ -605,9 +602,8 @@ def forward(
             edge_self_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
-                    node_ebd_ext,
+                    nei_node_ebd,
                     edge_ebd,
-                    nlist,
                     "edge",
                 )
             )
diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index 330336b1de..9a078645b9 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -183,6 +183,7 @@ def __init__(
         precision: str = "float64",
         fix_stat_std: float = 0.3,
         smooth_edge_update: bool = False,
+        use_loc_mapping: bool = True,
         optim_update: bool = True,
         seed: Optional[Union[int, list[int]]] = None,
     ) -> None:
@@ -215,6 +216,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.use_loc_mapping = use_loc_mapping
 
         self.n_dim = n_dim
         self.e_dim = e_dim
@@ -373,13 +375,10 @@ def forward(
         nlist: torch.Tensor,
         extended_coord: torch.Tensor,
         extended_atype: torch.Tensor,
-        extended_atype_embd: Optional[torch.Tensor] = None,
+        extended_atype_embd: torch.Tensor,
         mapping: Optional[torch.Tensor] = None,
         comm_dict: Optional[dict[str, torch.Tensor]] = None,
     ):
-        if comm_dict is None:
-            assert mapping is not None
-            assert extended_atype_embd is not None
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         atype = extended_atype[:, :nloc]
@@ -402,14 +401,9 @@ def forward(
         # beyond the cutoff sw should be 0.0
         sw = sw.masked_fill(~nlist_mask, 0.0)
 
+        atype_embd = extended_atype_embd
         # [nframes, nloc, tebd_dim]
-        if comm_dict is None:
-            assert isinstance(extended_atype_embd, torch.Tensor)  # for jit
-            atype_embd = extended_atype_embd[:, :nloc, :]
-            assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
-        else:
-            atype_embd = extended_atype_embd
-        assert isinstance(atype_embd, torch.Tensor)  # for jit
+        assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
         node_ebd = self.act(atype_embd)
         n_dim = node_ebd.shape[-1]
         # nb x nloc x nnei x 1, nb x nloc x nnei x 3
@@ -459,17 +453,26 @@ def forward(
         # if the a neighbor is real or not is indicated by nlist_mask
         nlist[nlist == -1] = 0
         # nb x nall x n_dim
-        if comm_dict is None:
+        if comm_dict is None or self.use_loc_mapping:
             assert mapping is not None
             mapping = (
                 mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
             )
+            node_ebd_ext = None
+            nlist = torch.gather(
+                mapping.reshape(nframes, -1),
+                1,
+                nlist.reshape(nframes, -1),
+            ).reshape(nlist.shape)
         for idx, ll in enumerate(self.layers):
             # node_ebd: nb x nloc x n_dim
             # node_ebd_ext: nb x nall x n_dim
             if comm_dict is None:
-                assert mapping is not None
-                node_ebd_ext = torch.gather(node_ebd, 1, mapping)
+                if self.use_loc_mapping:
+                    node_ebd_ext = None
+                else:
+                    assert mapping is not None
+                    node_ebd_ext = torch.gather(node_ebd, 1, mapping)
             else:
                 has_spin = "has_spin" in comm_dict
                 if not has_spin:
@@ -528,6 +531,7 @@ def forward(
                     node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
                 )
             node_ebd, edge_ebd, angle_ebd = ll.forward(
+                node_ebd,
                 node_ebd_ext,
                 edge_ebd,
                 h2,
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 0260700165..41f6cd49b9 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1421,6 +1421,12 @@ def descrpt_dpa3_args():
             default=False,
             doc=doc_use_tebd_bias,
         ),
+        Argument(
+            "use_loc_mapping",
+            bool,
+            optional=True,
+            default=True,
+        ),
     ]
 
 
diff --git a/source/tests/pt/model/test_loc_mapping.py b/source/tests/pt/model/test_loc_mapping.py
new file mode 100644
index 0000000000..aaca6430a4
--- /dev/null
+++ b/source/tests/pt/model/test_loc_mapping.py
@@ -0,0 +1,255 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import itertools
+import unittest
+
+import numpy as np
+import torch
+
+from deepmd.dpmodel.descriptor.dpa3 import (
+    RepFlowArgs,
+)
+from deepmd.pt.model.descriptor import (
+    DescrptDPA3,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.env import (
+    PRECISION_DICT,
+)
+
+from ...seed import (
+    GLOBAL_SEED,
+)
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+from .test_mlp import (
+    get_tols,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+
+
+class TestDescrptDPA3LocMapping(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    def setUp(self) -> None:
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def test_consistency(
+        self,
+    ) -> None:
+        rng = np.random.default_rng(100)
+        nf, nloc, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(dstd)
+
+        for (
+            ua,
+            rus,
+            ruri,
+            acr,
+            nme,
+            prec,
+            ect,
+            optim,
+        ) in itertools.product(
+            [True, False],  # update_angle
+            ["res_residual"],  # update_style
+            ["norm", "const"],  # update_residual_init
+            [0, 1],  # a_compress_rate
+            [1, 2],  # n_multi_edge_message
+            ["float64"],  # precision
+            [False],  # use_econf_tebd
+            [True, False],  # optim_update
+        ):
+            dtype = PRECISION_DICT[prec]
+            rtol, atol = get_tols(prec)
+            if prec == "float64":
+                atol = 1e-8  # marginal GPU test cases...
+
+            repflow = RepFlowArgs(
+                n_dim=20,
+                e_dim=10,
+                a_dim=10,
+                nlayers=3,
+                e_rcut=self.rcut,
+                e_rcut_smth=self.rcut_smth,
+                e_sel=nnei,
+                a_rcut=self.rcut - 0.1,
+                a_rcut_smth=self.rcut_smth,
+                a_sel=nnei,
+                a_compress_rate=acr,
+                n_multi_edge_message=nme,
+                axis_neuron=4,
+                update_angle=ua,
+                update_style=rus,
+                update_residual_init=ruri,
+                optim_update=optim,
+                smooth_edge_update=True,
+            )
+
+            # dpa3 without local mapping
+            dd0 = DescrptDPA3(  # type: ignore[call-arg]
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+                use_loc_mapping=False,
+            ).to(env.DEVICE)
+
+            # dpa3 using local mapping
+            dd1 = DescrptDPA3(  # type: ignore[call-arg]
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+                use_loc_mapping=True,
+            ).to(env.DEVICE)
+
+            coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
+            atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
+            nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
+            mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)
+
+            dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd0, _, _, _, _ = dd0.forward(
+                torch.tensor(coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(nlist, dtype=int, device=env.DEVICE),
+                torch.tensor(mapping, dtype=int, device=env.DEVICE),
+            )
+
+            dd1.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd1.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd1, _, _, _, _ = dd1.forward(
+                torch.tensor(coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(nlist, dtype=int, device=env.DEVICE),
+                torch.tensor(mapping, dtype=int, device=env.DEVICE),
+            )
+
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd1.detach().cpu().numpy(),
+                rtol=rtol,
+                atol=atol,
+            )
+
+    def test_consistency_nosel(
+        self,
+    ) -> None:
+        rng = np.random.default_rng(100)
+        nf, nloc, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(dstd)
+
+        for (
+            ua,
+            rus,
+            ruri,
+            acr,
+            nme,
+            prec,
+            ect,
+            optim,
+        ) in itertools.product(
+            [True, False],  # update_angle
+            ["res_residual"],  # update_style
+            ["norm", "const"],  # update_residual_init
+            [0, 1],  # a_compress_rate
+            [1, 2],  # n_multi_edge_message
+            ["float64"],  # precision
+            [False],  # use_econf_tebd
+            [True, False],  # optim_update
+        ):
+            dtype = PRECISION_DICT[prec]
+            rtol, atol = get_tols(prec)
+            if prec == "float64":
+                atol = 1e-8  # marginal GPU test cases...
+
+            repflow = RepFlowArgs(
+                n_dim=20,
+                e_dim=10,
+                a_dim=10,
+                nlayers=3,
+                e_rcut=self.rcut,
+                e_rcut_smth=self.rcut_smth,
+                e_sel=nnei,
+                a_rcut=self.rcut - 0.1,
+                a_rcut_smth=self.rcut_smth,
+                a_sel=nnei,
+                a_compress_rate=acr,
+                n_multi_edge_message=nme,
+                axis_neuron=4,
+                update_angle=ua,
+                update_style=rus,
+                update_residual_init=ruri,
+                optim_update=optim,
+                smooth_edge_update=True,
+            )
+
+            # dpa3 without local mapping
+            dd0 = DescrptDPA3(  # type: ignore[call-arg]
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+                use_loc_mapping=False,
+            ).to(env.DEVICE)
+
+            # dpa3 using local mapping
+            dd1 = DescrptDPA3(  # type: ignore[call-arg]
+                self.nt,
+                repflow=repflow,
+                # kwargs for descriptor
+                exclude_types=[],
+                precision=prec,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                seed=GLOBAL_SEED,
+                use_loc_mapping=True,
+            ).to(env.DEVICE)
+
+            coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
+            atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
+            nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
+            mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)
+
+            dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd0, _, _, _, _ = dd0.forward(
+                torch.tensor(coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(nlist, dtype=int, device=env.DEVICE),
+                torch.tensor(mapping, dtype=int, device=env.DEVICE),
+            )
+
+            dd1.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd1.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd1, _, _, _, _ = dd1.forward(
+                torch.tensor(coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(nlist, dtype=int, device=env.DEVICE),
+                torch.tensor(mapping, dtype=int, device=env.DEVICE),
+            )
+
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd1.detach().cpu().numpy(),
+                rtol=rtol,
+                atol=atol,
+            )
diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py
index 4fa0593419..f37c02d666 100644
--- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py
+++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py
@@ -484,6 +484,7 @@ def DescriptorParamDPA3(
     smooth_edge_update=False,
     fix_stat_std=0.3,
     precision="float64",
+    use_loc_mapping=True,
 ):
     input_dict = {
         # kwargs for repformer
@@ -522,6 +523,7 @@ def DescriptorParamDPA3(
         "trainable": True,
         "use_econf_tebd": False,
         "use_tebd_bias": False,
+        "use_loc_mapping": use_loc_mapping,
         "type_map": type_map,
         "seed": GLOBAL_SEED,
     }
@@ -544,6 +546,7 @@ def DescriptorParamDPA3(
             "n_multi_edge_message": (1, 2),
             "env_protection": (0.0, 1e-8),
             "precision": ("float64",),
+            "use_loc_mapping": (True, False),
         }
     ),
 )
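Reviewer note (not part of the patch): the sketch below illustrates the remapping that `use_loc_mapping` turns on in `repflows.py`. The neighbor list normally indexes into the extended (local + ghost, nall-sized) atom set; gathering it through `mapping` once converts it to local indices, so every repflow layer can look up neighbor embeddings directly in the nloc-sized node embedding instead of first scattering that embedding out to nall size. The gather pattern mirrors the diff above; the tensor sizes and values are toy data made up for illustration.

# Minimal sketch of the equivalence behind use_loc_mapping (toy data, not deepmd code).
import torch

nf, nloc, nall, nnei, n_dim = 1, 3, 5, 2, 4

# mapping[f, i]: for every extended atom i (real + ghost), the local atom it is an
# image of; the first nloc entries are the identity.
mapping = torch.tensor([[0, 1, 2, 1, 0]])          # nf x nall
nlist = torch.tensor([[[3, 4], [2, 3], [0, 4]]])   # nf x nloc x nnei, extended indices
node_ebd = torch.randn(nf, nloc, n_dim)            # local node embedding

# Old path: expand the local embedding to nall size, then gather neighbors from it.
mapping_exp = mapping.unsqueeze(-1).expand(-1, -1, n_dim)
node_ebd_ext = torch.gather(node_ebd, 1, mapping_exp)  # nf x nall x n_dim
nei_old = torch.gather(
    node_ebd_ext, 1, nlist.reshape(nf, -1, 1).expand(-1, -1, n_dim)
).view(nf, nloc, nnei, n_dim)

# New path (use_loc_mapping): remap nlist to local indices once, then gather
# neighbors directly from the nloc-sized embedding in every layer.
nlist_loc = torch.gather(mapping.reshape(nf, -1), 1, nlist.reshape(nf, -1)).reshape(
    nlist.shape
)
nei_new = torch.gather(
    node_ebd, 1, nlist_loc.reshape(nf, -1, 1).expand(-1, -1, n_dim)
).view(nf, nloc, nnei, n_dim)

assert torch.allclose(nei_old, nei_new)

Both paths select node_ebd[f, mapping[f, nlist[f, i, j]]], which is why test_loc_mapping.py expects descriptors built with use_loc_mapping=False and use_loc_mapping=True to produce identical outputs.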