From 8fd9565ce03ded903b04ec6c43b1417c48ce6dad Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:49:40 +0800 Subject: [PATCH 1/8] feat(pt): support spin virial --- deepmd/pt/loss/ener_spin.py | 16 +++++++ deepmd/pt/model/model/make_model.py | 17 +++++++ deepmd/pt/model/model/spin_model.py | 45 +++++++++++++++---- deepmd/pt/model/model/transform_output.py | 7 +++ source/api_c/include/deepmd.hpp | 12 ++--- source/api_c/src/c_api.cc | 10 ++--- source/api_cc/src/DeepSpinPT.cc | 25 +++++------ source/tests/pt/model/test_autodiff.py | 17 ++++++- source/tests/pt/model/test_ener_spin_model.py | 3 +- .../universal/common/cases/model/utils.py | 12 +++-- source/tests/universal/pt/model/test_model.py | 2 + 11 files changed, 126 insertions(+), 40 deletions(-) diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index 6a926f4051..850f66bf1d 100644 --- a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -268,6 +268,22 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): rmse_ae.detach(), find_atom_ener ) + if self.has_v and "virial" in model_pred and "virial" in label: + find_virial = label.get("find_virial", 0.0) + pref_v = pref_v * find_virial + diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) + l2_virial_loss = torch.mean(torch.square(diff_v)) + if not self.inference: + more_loss["l2_virial_loss"] = self.display_if_exist( + l2_virial_loss.detach(), find_virial + ) + loss += atom_norm * (pref_v * l2_virial_loss) + rmse_v = l2_virial_loss.sqrt() * atom_norm + more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial) + if mae: + mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) + if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return model_pred, loss, more_loss diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c32abaa095..2756c66252 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -135,6 +135,7 @@ def forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + coord_corr_for_virial: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Return model prediction. @@ -153,6 +154,9 @@ def forward_common( atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. + coord_corr_for_virial + The coordinates correction of the atoms for virial. + shape: nf x (nloc x 3) Returns ------- @@ -180,6 +184,14 @@ def forward_common( mixed_types=True, box=bb, ) + if coord_corr_for_virial is not None: + coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype) + extended_coord_corr = torch.gather( + coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3) + ) + else: + extended_coord_corr = None + model_predict_lower = self.forward_common_lower( extended_coord, extended_atype, @@ -188,6 +200,7 @@ def forward_common( do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap, + extended_coord_corr=extended_coord_corr, ) model_predict = communicate_extended_output( model_predict_lower, @@ -242,6 +255,7 @@ def forward_common_lower( do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, extra_nlist_sort: bool = False, + extended_coord_corr: Optional[torch.Tensor] = None, ): """Return model prediction. 
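# NOTE (reviewer sketch, not part of the patch): how the coordinate
# correction added in forward_common above is broadcast to the extended
# region. mapping[f, j] gives the local atom that extended atom j is an
# image of; the toy shapes and mapping values below are assumptions for
# illustration only.
import torch

nf, nloc, nall = 1, 3, 5
coord_corr = torch.zeros(nf, nloc, 3)       # 0 for real atoms, -spin_dist for virtual ones
mapping = torch.tensor([[0, 1, 2, 0, 1]])   # nf x nall, hypothetical periodic images
extended_coord_corr = torch.gather(
    coord_corr, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)
)                                           # nf x nall x 3, passed to forward_common_lower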
Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -268,6 +282,8 @@ def forward_common_lower( The data needed for communication for parallel inference. extra_nlist_sort whether to forcibly sort the nlist. + extended_coord_corr + coordinates correction for virial in extended region. nf x (nall x 3) Returns ------- @@ -299,6 +315,7 @@ def forward_common_lower( cc_ext, do_atomic_virial=do_atomic_virial, create_graph=self.training, + extended_coord_corr=extended_coord_corr, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index ac94668039..a847a869ce 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin): coord = coord.reshape(nframes, nloc, 3) spin = spin.reshape(nframes, nloc, 3) atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) - virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[ - atype - ].reshape([nframes, nloc, 1]) + spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape( + [nframes, nloc, 1] + ) + virtual_coord = coord + spin_dist coord_spin = torch.concat([coord, virtual_coord], dim=-2) - return coord_spin, atype_spin + # for spin virial corr + coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2) + return coord_spin, atype_spin, coord_corr def process_spin_input_lower( self, @@ -78,13 +81,18 @@ def process_spin_input_lower( """ nframes, nall = extended_coord.shape[:2] nloc = nlist.shape[1] - virtual_extended_coord = extended_coord + extended_spin * ( + extended_spin_dist = extended_spin * ( self.virtual_scale_mask.to(extended_atype.device) )[extended_atype].reshape([nframes, nall, 1]) + virtual_extended_coord = extended_coord + extended_spin_dist virtual_extended_atype = extended_atype + self.ntypes_real extended_coord_updated = concat_switch_virtual( extended_coord, virtual_extended_coord, nloc ) + # for spin virial corr + extended_coord_corr = concat_switch_virtual( + torch.zeros_like(extended_coord), -extended_spin_dist, nloc + ) extended_atype_updated = concat_switch_virtual( extended_atype, virtual_extended_atype, nloc ) @@ -100,6 +108,7 @@ def process_spin_input_lower( extended_atype_updated, nlist_updated, mapping_updated, + extended_coord_corr, ) def process_spin_output( @@ -367,7 +376,7 @@ def spin_sampled_func(): sampled = sampled_func() spin_sampled = [] for sys in sampled: - coord_updated, atype_updated = self.process_spin_input( + coord_updated, atype_updated, _ = self.process_spin_input( sys["coord"], sys["atype"], sys["spin"] ) tmp_dict = { @@ -398,7 +407,9 @@ def forward_common( do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: nframes, nloc = atype.shape - coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) + coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input( + coord, atype, spin + ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) model_ret = self.backbone_model.forward_common( @@ -408,6 +419,7 @@ def forward_common( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + coord_corr_for_virial=coord_corr_for_virial, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -454,6 +466,7 @@ def forward_common_lower( extended_atype_updated, nlist_updated, mapping_updated, + 
extended_coord_corr_for_virial, ) = self.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) @@ -469,6 +482,7 @@ def forward_common_lower( do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr_for_virial, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -541,6 +555,11 @@ def translated_output_def(self): output_def["force"].squeeze(-2) output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) output_def["force_mag"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-3) return output_def def forward( @@ -569,7 +588,10 @@ def forward( if self.backbone_model.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) - # not support virial by far + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) return model_predict @torch.jit.export @@ -606,5 +628,10 @@ def forward_lower( model_predict["extended_force_mag"] = model_ret[ "energy_derv_r_mag" ].squeeze(-2) - # not support virial by far + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) return model_predict diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index e15eda6a1d..fcd41e075c 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -156,6 +156,7 @@ def fit_output_to_model_output( coord_ext: torch.Tensor, do_atomic_virial: bool = False, create_graph: bool = True, + extended_coord_corr: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. 
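# NOTE (reviewer sketch, not part of the patch): the correction applied in
# the hunk below, spelled out in outer-product form. A virtual spin atom sits
# at r_v = r + delta, so the raw atomic virial term f x r_v is shifted back to
# the real atom position by adding f x corr with corr = -delta (corr = 0 for
# real atoms, as built in spin_model.py). Shapes follow the patch: dr is
# nf x nall x 1 x 3, extended_coord_corr is nf x nall x 3, dc is nf x nall x 1 x 9.
import torch

nf, nall = 2, 5
dr = torch.randn(nf, nall, 1, 3)              # per-atom energy derivative term
extended_coord_corr = torch.randn(nf, nall, 3)
dc = torch.randn(nf, nall, 1, 9)              # raw atomic virial
dc_corr = (
    dr.squeeze(-2).unsqueeze(-1) @ extended_coord_corr.unsqueeze(-2)
).view(nf, nall, 1, 9)                        # outer product f_a * corr_b per atom
dc = dc + dc_corr
virial = dc.sum(dim=1)                        # reduce over atoms -> nf x 1 x 9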
@@ -187,6 +188,12 @@ def fit_output_to_model_output( model_ret[kk_derv_r] = dr if vdef.c_differentiable: assert dc is not None + if extended_coord_corr is not None: + dc_corr = ( + dr.squeeze(-2).unsqueeze(-1) + @ extended_coord_corr.unsqueeze(-2) + ).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005 + dc = dc + dc_corr model_ret[kk_derv_c] = dc model_ret[kk_derv_c + "_redu"] = torch.sum( model_ret[kk_derv_c].to(redu_prec), dim=1 diff --git a/source/api_c/include/deepmd.hpp b/source/api_c/include/deepmd.hpp index 8a3656bfc2..a37fe10fa9 100644 --- a/source/api_c/include/deepmd.hpp +++ b/source/api_c/include/deepmd.hpp @@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi { for (int j = 0; j < natoms * 3; j++) { force_mag[i][j] = force_mag_flat[i * natoms * 3 + j]; } - // for (int j = 0; j < 9; j++) { - // virial[i][j] = virial_flat[i * 9 + j]; - // } + for (int j = 0; j < 9; j++) { + virial[i][j] = virial_flat[i * 9 + j]; + } } }; /** @@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi { for (int j = 0; j < natoms * 3; j++) { force_mag[i][j] = force_mag_flat[i * natoms * 3 + j]; } - // for (int j = 0; j < 9; j++) { - // virial[i][j] = virial_flat[i * 9 + j]; - // } + for (int j = 0; j < 9; j++) { + virial[i][j] = virial_flat[i * 9 + j]; + } for (int j = 0; j < natoms; j++) { atom_energy[i][j] = atom_energy_flat[i * natoms + j]; } diff --git a/source/api_c/src/c_api.cc b/source/api_c/src/c_api.cc index 4a0cff1520..3acb28a002 100644 --- a/source/api_c/src/c_api.cc +++ b/source/api_c/src/c_api.cc @@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp, flatten_vector(fm_flat, fm); std::copy(fm_flat.begin(), fm_flat.end(), force_mag); } - // if (virial) { - // std::vector v_flat; - // flatten_vector(v_flat, v); - // std::copy(v_flat.begin(), v_flat.end(), virial); - // } + if (virial) { + std::vector v_flat; + flatten_vector(v_flat, v); + std::copy(v_flat.begin(), v_flat.end(), virial); + } if (atomic_energy) { std::vector ae_flat; flatten_vector(ae_flat, ae); diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 7421b623db..eb43dbf6d0 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -251,8 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("extended_force"); c10::IValue force_mag_ = outputs.at("extended_force_mag"); - // spin model not suported yet - // c10::IValue virial_ = outputs.at("virial"); + c10::IValue virial_ = outputs.at("virial"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); ener.assign(cpu_energy_.data_ptr(), @@ -267,11 +266,11 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, dforce_mag.assign( cpu_force_mag_.data_ptr(), cpu_force_mag_.data_ptr() + cpu_force_mag_.numel()); - // spin model not suported yet - // torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); - // torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); - // virial.assign(cpu_virial_.data_ptr(), - // cpu_virial_.data_ptr() + cpu_virial_.numel()); + + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); // bkw map force.resize(static_cast(nframes) * fwd_map.size() * 3); @@ -415,8 +414,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, 
c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("force"); c10::IValue force_mag_ = outputs.at("force_mag"); - // spin model not suported yet - // c10::IValue virial_ = outputs.at("virial"); + c10::IValue virial_ = outputs.at("virial"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); ener.assign(cpu_energy_.data_ptr(), @@ -431,11 +429,10 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, force_mag.assign( cpu_force_mag_.data_ptr(), cpu_force_mag_.data_ptr() + cpu_force_mag_.numel()); - // spin model not suported yet - // torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); - // torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); - // virial.assign(cpu_virial_.data_ptr(), - // cpu_virial_.data_ptr() + cpu_virial_.numel()); + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); if (atomic) { // c10::IValue atom_virial_ = outputs.at("atom_virial"); c10::IValue atom_energy_ = outputs.at("atom_energy"); diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index 31e06af751..fab637f0f8 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -141,11 +141,17 @@ def test( cell = (cell) + 5.0 * torch.eye(3, device="cpu") coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) atype = torch.IntTensor([0, 0, 0, 1, 1]) # assumes input to be numpy tensor coord = coord.numpy() + spin = spin.numpy() cell = cell.numpy() - test_keys = ["energy", "force", "virial"] + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] def np_infer( new_cell, @@ -157,6 +163,7 @@ def np_infer( ).unsqueeze(0), torch.tensor(new_cell, device="cpu").unsqueeze(0), atype, + spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -251,3 +258,11 @@ def setUp(self) -> None: self.type_split = False self.test_spin = True self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py index ddea392f33..66bb1082a0 100644 --- a/source/tests/pt/model/test_ener_spin_model.py +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -115,7 +115,7 @@ def test_input_output_process(self) -> None: nframes, nloc = self.coord.shape[:2] self.real_ntypes = self.model.spin.get_ntypes_real() # 1. 
test forward input process - coord_updated, atype_updated = self.model.process_spin_input( + coord_updated, atype_updated, _ = self.model.process_spin_input( self.coord, self.atype, self.spin ) # compare atypes of real and virtual atoms @@ -174,6 +174,7 @@ def test_input_output_process(self) -> None: extended_atype_updated, nlist_updated, mapping_updated, + _, ) = self.model.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index 8fe6a131ef..e2a1b4866a 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -892,7 +892,10 @@ def ff_spin(_spin): fdf.reshape(-1, 3), rff.reshape(-1, 3), decimal=places ) - if not test_spin: + # this option can be removed after other backends support spin virial + test_spin_virial = getattr(self, "test_spin_virial", False) + + if not test_spin or test_spin_virial: def ff_cell(bb): input_dict = { @@ -902,6 +905,8 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } + if test_spin: + input_dict["spin"] = spin return module(**input_dict)["energy"] fdv = ( @@ -921,13 +926,12 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } + if test_spin: + input_dict["spin"] = spin rfv = module(**input_dict)["virial"] np.testing.assert_almost_equal( fdv.reshape(-1, 9), rfv.reshape(-1, 9), decimal=places ) - else: - # not support virial by far - pass @unittest.skipIf(TEST_DEVICE == "cpu" and CI, "Skip test on CPU.") def test_device_consistence(self) -> None: diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 3eb1484c45..ec6cd71782 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -713,6 +713,8 @@ def setUpClass(cls) -> None: cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + # this option can be removed after other backends support spin virial + cls.test_spin_virial = True @parameterized( From cbbce6447b7453d18c8c24b0139a373500461193 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 11 May 2025 00:26:56 +0800 Subject: [PATCH 2/8] fix dpa3 spin lmp --- deepmd/pt/model/descriptor/repflows.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 330336b1de..5f41379dca 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -379,7 +379,6 @@ def forward( ): if comm_dict is None: assert mapping is not None - assert extended_atype_embd is not None nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] @@ -403,13 +402,9 @@ def forward( sw = sw.masked_fill(~nlist_mask, 0.0) # [nframes, nloc, tebd_dim] - if comm_dict is None: - assert isinstance(extended_atype_embd, torch.Tensor) # for jit - atype_embd = extended_atype_embd[:, :nloc, :] - assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] - else: - atype_embd = extended_atype_embd - assert isinstance(atype_embd, torch.Tensor) # for jit + assert extended_atype_embd is not None + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] node_ebd = self.act(atype_embd) n_dim = 
node_ebd.shape[-1] # nb x nloc x nnei x 1, nb x nloc x nnei x 3 From 107cecde848480acb32b37d99c50dfaa8a978f6c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 23 Jun 2025 15:29:39 +0800 Subject: [PATCH 3/8] update dynamic sel --- deepmd/dpmodel/descriptor/dpa3.py | 4 + deepmd/pt/model/descriptor/dpa3.py | 2 + deepmd/pt/model/descriptor/repflow_layer.py | 491 +++++++++++++++++--- deepmd/pt/model/descriptor/repflows.py | 89 +++- deepmd/pt/model/network/utils.py | 135 ++++++ deepmd/utils/argcheck.py | 12 + source/tests/pt/model/test_nosel.py | 205 ++++++++ 7 files changed, 861 insertions(+), 77 deletions(-) create mode 100644 deepmd/pt/model/network/utils.py create mode 100644 source/tests/pt/model/test_nosel.py diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 668ef36043..9c19e7c841 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -151,6 +151,8 @@ def __init__( skip_stat: bool = False, optim_update: bool = True, smooth_edge_update: bool = False, + use_dynamic_sel: bool = False, + sel_reduce_factor: float = 10.0, ) -> None: self.n_dim = n_dim self.e_dim = e_dim @@ -177,6 +179,8 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.use_dynamic_sel = use_dynamic_sel + self.sel_reduce_factor = sel_reduce_factor def __getitem__(self, key): if hasattr(self, key): diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 545da962e7..0d2ba060df 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -150,6 +150,8 @@ def init_subclass_params(sub_data, sub_class): fix_stat_std=self.repflow_args.fix_stat_std, optim_update=self.repflow_args.optim_update, smooth_edge_update=self.repflow_args.smooth_edge_update, + use_dynamic_sel=self.repflow_args.use_dynamic_sel, + sel_reduce_factor=self.repflow_args.sel_reduce_factor, exclude_types=exclude_types, env_protection=env_protection, precision=precision, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index f109109cfd..1252763bbe 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -19,6 +19,9 @@ from deepmd.pt.model.network.mlp import ( MLPLayer, ) +from deepmd.pt.model.network.utils import ( + aggregate, +) from deepmd.pt.utils.env import ( PRECISION_DICT, ) @@ -52,6 +55,8 @@ def __init__( axis_neuron: int = 4, update_angle: bool = True, optim_update: bool = True, + use_dynamic_sel: bool = False, + sel_reduce_factor: float = 10.0, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -98,6 +103,10 @@ def __init__( self.prec = PRECISION_DICT[precision] self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.use_dynamic_sel = use_dynamic_sel + self.sel_reduce_factor = sel_reduce_factor + self.dynamic_e_sel = self.nnei / self.sel_reduce_factor + self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor assert update_residual_init in [ "norm", @@ -318,6 +327,58 @@ def _cal_hg( h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei return h2g2 + @staticmethod + def _cal_hg_dynamic( + flat_edge_ebd: torch.Tensor, + flat_h2: torch.Tensor, + flat_sw: torch.Tensor, + owner: torch.Tensor, + num_owner: int, + nloc: int, + scale_factor: float, + ) -> torch.Tensor: + """ + Calculate the 
transposed rotation matrix. + + Parameters + ---------- + flat_edge_ebd + Flattened neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim. + flat_h2 + Flattened neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3. + flat_sw + Flattened switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape n_edge. + owner + The owner index of the neighbor to reduce on. + num_owner : int + The total number of owners. + nloc : int + The number of local atoms. + scale_factor : float + The scale factor to apply after the reduction. + + Returns + ------- + hg + The transposed rotation matrix, with shape nf x nloc x 3 x e_dim. + """ + n_edge, e_dim = flat_edge_ebd.shape + # n_edge x e_dim + flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1) + # n_edge x 3 x e_dim + flat_h2g2 = (flat_h2[:, :, None] * flat_edge_ebd[:, None, :]).reshape( + -1, 3 * e_dim + ) + # nf x nloc x 3 x e_dim + h2g2 = ( + aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape( + -1, nloc, 3, e_dim + ) + * scale_factor + ) + return h2g2 + @staticmethod def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: """ @@ -390,6 +451,59 @@ def symmetrization_op( g1_13 = self._cal_grrg(h2g2, axis_neuron) return g1_13 + def symmetrization_op_dynamic( + self, + flat_edge_ebd: torch.Tensor, + flat_h2: torch.Tensor, + flat_sw: torch.Tensor, + owner: torch.Tensor, + num_owner: int, + nloc: int, + scale_factor: float, + axis_neuron: int, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + flat_edge_ebd + Flattened neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim. + flat_h2 + Flattened neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3. + flat_sw + Flattened switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape n_edge. + owner + The owner index of the neighbor to reduce on. + num_owner : int + The total number of owners. + nloc : int + The number of local atoms. + scale_factor : float + The scale factor to apply after the reduction. + axis_neuron + Size of the submatrix.
+ + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # nb x nloc x 3 x e_dim + h2g2 = self._cal_hg_dynamic( + flat_edge_ebd, + flat_h2, + flat_sw, + owner, + num_owner, + nloc, + scale_factor, + ) + # nb x nloc x (axis x e_dim) + grrg = self._cal_grrg(h2g2, axis_neuron) + return grrg + def optim_angle_update( self, angle_ebd: torch.Tensor, @@ -432,6 +546,65 @@ def optim_angle_update( ) return result_update + def optim_angle_update_dynamic( + self, + flat_angle_ebd: torch.Tensor, + node_ebd: torch.Tensor, + flat_edge_ebd: torch.Tensor, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + feat: str = "edge", + ) -> torch.Tensor: + nf, nloc, node_dim = node_ebd.shape + angle_dim = flat_angle_ebd.shape[-1] + edge_dim = flat_edge_ebd.shape[-1] + sub_angle_idx = (0, angle_dim) + sub_node_idx = (angle_dim, angle_dim + node_dim) + sub_edge_idx_ik = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) + sub_edge_idx_ij = ( + angle_dim + node_dim + edge_dim, + angle_dim + node_dim + 2 * edge_dim, + ) + + if feat == "edge": + matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias + elif feat == "angle": + matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias + else: + raise NotImplementedError + assert angle_dim + node_dim + 2 * edge_dim == matrix.size()[0] + + # n_angle * angle_dim + sub_angle_update = torch.matmul( + flat_angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1]] + ) + + # nf * nloc * angle_dim + sub_node_update = torch.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] + ) + # n_angle * angle_dim + sub_node_update = torch.index_select( + sub_node_update.reshape(nf * nloc, -1), 0, n2a_index + ) + + # n_edge * angle_dim + sub_edge_update_ik = torch.matmul( + flat_edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1]] + ) + sub_edge_update_ij = torch.matmul( + flat_edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1]] + ) + # n_angle * angle_dim + sub_edge_update_ik = torch.index_select(sub_edge_update_ik, 0, eik2a_index) + sub_edge_update_ij = torch.index_select(sub_edge_update_ij, 0, eij2a_index) + + result_update = ( + sub_angle_update + sub_node_update + sub_edge_update_ik + sub_edge_update_ij + ) + bias + return result_update + def optim_edge_update( self, node_ebd: torch.Tensor, @@ -467,6 +640,56 @@ def optim_edge_update( ) return result_update + def optim_edge_update_dynamic( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + flat_edge_ebd: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + feat: str = "node", + ) -> torch.Tensor: + nf, nall, node_dim = node_ebd_ext.shape + _, nloc, _ = node_ebd.shape + edge_dim = flat_edge_ebd.shape[-1] + sub_node_idx = (0, node_dim) + sub_node_ext_idx = (node_dim, 2 * node_dim) + sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim) + + if feat == "node": + matrix, bias = self.node_edge_linear.matrix, self.node_edge_linear.bias + elif feat == "edge": + matrix, bias = self.edge_self_linear.matrix, self.edge_self_linear.bias + else: + raise NotImplementedError + assert 2 * node_dim + edge_dim == matrix.size()[0] + + # nf * nloc * node/edge_dim + sub_node_update = torch.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] + ) + # n_edge * node/edge_dim + sub_node_update = torch.index_select( + sub_node_update.reshape(nf * nloc, -1), 0, n2e_index + ) + + # nf * nall * node/edge_dim + sub_node_ext_update = torch.matmul( + node_ebd_ext, 
matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1]] + ) + # n_edge * node/edge_dim + sub_node_ext_update = torch.index_select( + sub_node_ext_update.reshape(nf * nall, -1), 0, n_ext2e_index + ) + + # n_edge * node/edge_dim + sub_edge_update = torch.matmul( + flat_edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1]] + ) + + result_update = (sub_edge_update + sub_node_ext_update + sub_node_update) + bias + return result_update + def forward( self, node_ebd_ext: torch.Tensor, # nf x nall x n_dim @@ -479,6 +702,8 @@ def forward( a_nlist: torch.Tensor, # nf x nloc x a_nnei a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + edge_index: torch.Tensor, # n_edge x 2 + angle_index: torch.Tensor, # n_angle x 3 ): """ Parameters @@ -503,6 +728,18 @@ def forward( Masks of the neighbor list for angle. real nei 1 otherwise 0 a_sw : nf x nloc x a_nnei Switch function for angle. + edge_index : Optional for dynamic sel, n_edge x 2 + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : Optional for dynamic sel, n_angle x 3 + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). Returns ------- @@ -513,13 +750,33 @@ def forward( a_updated : nf x nloc x a_nnei x a_nnei x a_dim Updated angle embedding. """ - nb, nloc, nnei, _ = edge_ebd.shape + nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] node_ebd = node_ebd_ext[:, :nloc, :] + n_edge = int(nlist_mask.sum().item()) assert (nb, nloc) == node_ebd.shape[:2] - assert (nb, nloc, nnei) == h2.shape[:3] + if not self.use_dynamic_sel: + assert (nb, nloc, nnei, 3) == h2.shape + else: + assert (n_edge, 3) == h2.shape del a_nlist # may be used in the future + n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[:, 0], + angle_index[:, 1], + angle_index[:, 2], + ) + + # nb x nloc x nnei x n_dim [OR] n_edge x n_dim + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index + ) + ) + n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -528,8 +785,6 @@ def forward( node_self_mlp = self.act(self.node_self_mlp(node_ebd)) n_update_list.append(node_self_mlp) - nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) - # node sym (grrg + drrd) node_sym_list: list[torch.Tensor] = [] node_sym_list.append( @@ -540,6 +795,17 @@ def forward( sw, self.axis_neuron, ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) ) node_sym_list.append( self.symmetrization_op( @@ -549,20 +815,44 @@ def forward( sw, self.axis_neuron, ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) ) node_sym = 
self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) n_update_list.append(node_sym) if not self.optim_update: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) + if not self.use_dynamic_sel: + # nb x nloc x nnei x (n_dim * 2 + e_dim) + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + # n_edge x (n_dim * 2 + e_dim) + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) else: edge_info = None @@ -582,11 +872,32 @@ def forward( nlist, "node", ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) ) * sw.unsqueeze(-1) + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, -1) + / self.dynamic_e_sel + ) + ) - node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei if self.n_multi_edge_message > 1: - # nb x nloc x nnei x h x n_dim + # nb x nloc x h x n_dim node_edge_update_mul_head = node_edge_update.view( nb, nloc, self.n_multi_edge_message, self.n_dim ) @@ -610,6 +921,15 @@ def forward( nlist, "edge", ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) ) e_update_list.append(edge_self_update) @@ -632,40 +952,60 @@ def forward( node_ebd_for_angle = node_ebd edge_ebd_for_angle = edge_ebd - # nb x nloc x a_nnei x e_dim - edge_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_for_angle = torch.where( - a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 - ) + if not self.use_dynamic_sel: + # nb x nloc x a_nnei x e_dim + edge_ebd_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + ) if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim - node_for_angle_info = torch.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - (1, 1, self.a_sel, self.a_sel, 1), + # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) ) - # nb x nloc x (a_nnei) x a_nnei x edge_ebd - edge_for_angle_i = torch.tile( - edge_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + + # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim + edge_for_angle_k = ( + torch.tile( + edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + ) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) ) - # nb x nloc x a_nnei x (a_nnei) x e_dim - edge_for_angle_j = torch.tile( - edge_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim + edge_for_angle_j = ( + torch.tile( + edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + ) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 
0, eij2a_index) ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) edge_for_angle_info = torch.cat( - [edge_for_angle_i, edge_for_angle_j], dim=-1 + [edge_for_angle_k, edge_for_angle_j], dim=-1 ) angle_info_list = [angle_ebd] angle_info_list.append(node_for_angle_info) angle_info_list.append(edge_for_angle_info) # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) + # [OR] + # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) angle_info = torch.cat(angle_info_list, dim=-1) else: angle_info = None # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim + # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim if not self.optim_update: assert angle_info is not None edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) @@ -674,33 +1014,62 @@ def forward( self.optim_angle_update( angle_ebd, node_ebd_for_angle, - edge_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, "edge", ) ) - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = torch.sum( - weighted_edge_angle_update, dim=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = torch.concat( - [ - reduced_edge_angle_update, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, - ), - ], - dim=2, - ) + if not self.use_dynamic_sel: + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + a_sw[:, :, :, None, None] + * a_sw[:, :, None, :, None] + * edge_angle_update + ) + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + else: + # n_angle x e_dim + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + # n_edge x e_dim + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + if not self.smooth_edge_update: # will be deprecated in the future + # not support dynamic index, will pass anyway + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" 
+ ) full_mask = torch.concat( [ a_nlist_mask, @@ -731,7 +1100,17 @@ def forward( self.optim_angle_update( angle_ebd, node_ebd_for_angle, - edge_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, "angle", ) ) diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 5f41379dca..991413180a 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -19,6 +19,9 @@ from deepmd.pt.model.network.mlp import ( MLPLayer, ) +from deepmd.pt.model.network.utils import ( + get_graph_index, +) from deepmd.pt.utils import ( env, ) @@ -183,6 +186,8 @@ def __init__( precision: str = "float64", fix_stat_std: float = 0.3, smooth_edge_update: bool = False, + use_dynamic_sel: bool = False, + sel_reduce_factor: float = 10.0, optim_update: bool = True, seed: Optional[Union[int, list[int]]] = None, ) -> None: @@ -215,6 +220,8 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.use_dynamic_sel = use_dynamic_sel + self.sel_reduce_factor = sel_reduce_factor self.n_dim = n_dim self.e_dim = e_dim @@ -267,6 +274,8 @@ def __init__( update_residual_init=self.update_residual_init, precision=precision, optim_update=self.optim_update, + use_dynamic_sel=self.use_dynamic_sel, + sel_reduce_factor=self.sel_reduce_factor, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), ) @@ -401,17 +410,6 @@ def forward( # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0) - # [nframes, nloc, tebd_dim] - assert extended_atype_embd is not None - atype_embd = extended_atype_embd[:, :nloc, :] - assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] - node_ebd = self.act(atype_embd) - n_dim = node_ebd.shape[-1] - # nb x nloc x nnei x 1, nb x nloc x nnei x 3 - edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) - # nb x nloc x nnei x e_dim - edge_ebd = self.act(self.edge_embd(edge_input)) - # get angle nlist (maybe smaller) a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ :, :, : self.a_sel @@ -432,8 +430,22 @@ def forward( a_sw = torch.squeeze(a_sw, -1) # beyond the cutoff sw should be 0.0 a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 a_nlist[a_nlist == -1] = 0 + # get node embedding + # [nframes, nloc, tebd_dim] + assert extended_atype_embd is not None + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + node_ebd = self.act(atype_embd) + n_dim = node_ebd.shape[-1] + + # get edge and angle embedding input + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 @@ -443,16 +455,37 @@ def forward( # nf x nloc x a_nnei x a_nnei # 1 - 1e-6 for torch.acos stability cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) - # nf x nloc x a_nnei x a_nnei x 1 - cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) - # nf x nloc x a_nnei x a_nnei x a_dim - angle_ebd = self.angle_embd(cosine_ij).reshape( - nframes, nloc, self.a_sel, self.a_sel, self.a_dim - ) + angle_input = 
cosine_ij.unsqueeze(-1) / (torch.pi**0.5) + + if self.use_dynamic_sel: + # get graph index + edge_index, angle_index = get_graph_index( + nlist, nlist_mask, a_nlist_mask, nall + ) + # flat all the tensors + # n_edge x 1 + edge_input = edge_input[nlist_mask] + # n_edge x 3 + h2 = h2[nlist_mask] + # n_edge x 1 + sw = sw[nlist_mask] + # nb x nloc x a_nnei x a_nnei + a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] + # n_angle x 1 + angle_input = angle_input[a_nlist_mask] + # n_angle x 1 + a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] + else: + # avoid jit assertion + edge_index = angle_index = torch.zeros( + [1, 3], device=nlist.device, dtype=nlist.dtype + ) + # get edge and angle embedding + # nb x nloc x nnei x e_dim [OR] n_edge x e_dim + edge_ebd = self.act(self.edge_embd(edge_input)) + # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim + angle_ebd = self.angle_embd(angle_input) - # set all padding positions to index of 0 - # if the a neighbor is real or not is indicated by nlist_mask - nlist[nlist == -1] = 0 # nb x nall x n_dim if comm_dict is None: assert mapping is not None @@ -533,10 +566,24 @@ def forward( a_nlist, a_nlist_mask, a_sw, + edge_index=edge_index, + angle_index=angle_index, ) # nb x nloc x 3 x e_dim - h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + h2g2 = ( + RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + if not self.use_dynamic_sel + else RepFlowLayer._cal_hg_dynamic( + edge_ebd, + h2, + sw, + owner=edge_index[:, 0], + num_owner=nframes * nloc, + nloc=nloc, + scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5), + ) + ) # (nb x nloc) x e_dim x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py new file mode 100644 index 0000000000..18798cf755 --- /dev/null +++ b/deepmd/pt/model/network/utils.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import torch + + +@torch.jit.export +def aggregate( + data: torch.Tensor, + owners: torch.Tensor, + average: bool = True, + num_owner: Optional[int] = None, +) -> torch.Tensor: + """ + Aggregate rows in data by specifying the owners. + + Parameters + ---------- + data : data tensor to aggregate [n_row, feature_dim] + owners : specify the owner of each row [n_row, 1] + average : if True, average the rows, if False, sum the rows. + Default = True + num_owner : the number of owners, this is needed if the + max idx of owner is not presented in owners tensor + Default = None + + Returns + ------- + output: [num_owner, feature_dim] + """ + bin_count = torch.bincount(owners) + bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) + + if (num_owner is not None) and (bin_count.shape[0] != num_owner): + difference = num_owner - bin_count.shape[0] + bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) + + # make sure this operation is done on the same device of data and owners + output = data.new_zeros([bin_count.shape[0], data.shape[1]]) + output = output.index_add_(0, owners, data) + if average: + output = (output.T / bin_count).T + return output + + +@torch.jit.export +def get_graph_index( + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + a_nlist_mask: torch.Tensor, + nall: int, +): + """ + Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`. + + Parameters + ---------- + nlist : nf x nloc x nnei + Neighbor list. 
(padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + nall + The number of extended atoms. + + Returns + ------- + edge_index : n_edge x 2 + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : n_angle x 3 + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). + """ + nf, nloc, nnei = nlist.shape + _, _, a_nnei = a_nlist_mask.shape + # nf x nloc x nnei x nnei + # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :] + a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] + n_edge = nlist_mask.sum().item() + # n_angle = a_nlist_mask_3d.sum().item() + + # following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index + + # 1. atom graph + # node(i) to edge(ij) index_select; edge(ij) to node aggregate + nlist_loc_index = torch.arange(0, nf * nloc, dtype=nlist.dtype, device=nlist.device) + # nf x nloc x nnei + n2e_index = nlist_loc_index.reshape(nf, nloc, 1).expand(-1, -1, nnei) + # n_edge + n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0] + + # node_ext(j) to edge(ij) index_select + frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * nall + shifted_nlist = nlist + frame_shift[:, None, None] + # n_edge + n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1] + + # 2. edge graph + # node(i) to angle(ijk) index_select + n2a_index = nlist_loc_index.reshape(nf, nloc, 1, 1).expand(-1, -1, a_nnei, a_nnei) + # n_angle + n2a_index = n2a_index[a_nlist_mask_3d] + + # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate + edge_id = torch.arange(0, n_edge, dtype=nlist.dtype, device=nlist.device) + # nf x nloc x nnei + edge_index = torch.zeros([nf, nloc, nnei], dtype=nlist.dtype, device=nlist.device) + edge_index[nlist_mask] = edge_id + # only cut a_nnei neighbors, to avoid nnei x nnei + edge_index = edge_index[:, :, :a_nnei] + edge_index_ij = edge_index.unsqueeze(-1).expand(-1, -1, -1, a_nnei) + # n_angle + eij2a_index = edge_index_ij[a_nlist_mask_3d] + + # edge(ik) to angle(ijk) index_select + edge_index_ik = edge_index.unsqueeze(-2).expand(-1, -1, a_nnei, -1) + # n_angle + eik2a_index = edge_index_ik[a_nlist_mask_3d] + + return torch.cat( + [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1 + ), torch.cat( + [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)], + dim=-1, + ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0260700165..c6fbf2513e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1597,6 +1597,18 @@ def dpa3_repflow_args(): default=False, # For compatability. 
This will be True in the future doc=doc_smooth_edge_update, ), + Argument( + "use_dynamic_sel", + bool, + optional=True, + default=False, + ), + Argument( + "sel_reduce_factor", + float, + optional=True, + default=10.0, + ), ] diff --git a/source/tests/pt/model/test_nosel.py b/source/tests/pt/model/test_nosel.py new file mode 100644 index 0000000000..fe349231ea --- /dev/null +++ b/source/tests/pt/model/test_nosel.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA3Nosel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + nme, + prec, + ect, + optim, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate + [1, 2], # n_multi_edge_message + ["float64"], # precision + [False], # use_econf_tebd + [True, False], # optim_update + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... 
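# NOTE (reviewer comment, not part of the original patch): dd0 below runs the
# static-sel path and dd1 the dynamic-sel path (use_dynamic_sel flipped before
# dd1 is built). With sel_reduce_factor=1.0 the dynamic normalization
# (nnei / sel_reduce_factor)**(-0.5) coincides with the static per-neighbor
# normalization, so the two descriptors should agree to the float64 tolerances
# above; with any other factor the two paths are not expected to match.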
+ + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei, + a_compress_rate=acr, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + optim_update=optim, + smooth_edge_update=True, + sel_reduce_factor=1.0, # test consistent when sel_reduce_factor == 1.0 + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + repflow.use_dynamic_sel = True + + # dpa3 new impl + dd1 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd1.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + # def test_jit( + # self, + # ) -> None: + # rng = np.random.default_rng(100) + # nf, nloc, nnei = self.nlist.shape + # davg = rng.normal(size=(self.nt, nnei, 4)) + # dstd = rng.normal(size=(self.nt, nnei, 4)) + # dstd = 0.1 + np.abs(dstd) + # + # for ( + # ua, + # rus, + # ruri, + # acr, + # nme, + # prec, + # ect, + # ) in itertools.product( + # [True, False], # update_angle + # ["res_residual"], # update_style + # ["norm", "const"], # update_residual_init + # [0, 1], # a_compress_rate + # [1, 2], # n_multi_edge_message + # ["float64"], # precision + # [False], # use_econf_tebd + # ): + # dtype = PRECISION_DICT[prec] + # rtol, atol = get_tols(prec) + # + # repflow = RepFlowArgs( + # n_dim=20, + # e_dim=10, + # a_dim=10, + # nlayers=3, + # e_rcut=self.rcut, + # e_rcut_smth=self.rcut_smth, + # e_sel=nnei, + # a_rcut=self.rcut - 0.1, + # a_rcut_smth=self.rcut_smth, + # a_sel=nnei - 1, + # a_compress_rate=acr, + # n_multi_edge_message=nme, + # axis_neuron=4, + # update_angle=ua, + # update_style=rus, + # update_residual_init=ruri, + # ) + # + # # dpa3 new impl + # dd0 = DescrptDPA3( + # self.nt, + # repflow=repflow, + # # kwargs for descriptor + # exclude_types=[], + # precision=prec, + # use_econf_tebd=ect, + # type_map=["O", "H"] if ect else None, + # seed=GLOBAL_SEED, + # ).to(env.DEVICE) + # + # dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + # dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + # model = torch.jit.script(dd0) From 3587d07ad54a1b39124b49b013c2ad781a405089 Mon Sep 17 00:00:00 2001 From: Duo 
<50307526+iProzd@users.noreply.github.com> Date: Mon, 26 May 2025 14:34:12 +0800 Subject: [PATCH 4/8] add exp sw --- deepmd/dpmodel/descriptor/dpa3.py | 2 ++ deepmd/pt/model/descriptor/dpa3.py | 1 + deepmd/pt/model/descriptor/env_mat.py | 11 ++++++++++- deepmd/pt/model/descriptor/repflows.py | 4 ++++ deepmd/pt/utils/preprocess.py | 12 ++++++++++++ deepmd/utils/argcheck.py | 7 +++++++ 6 files changed, 36 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 9c19e7c841..85b7980c23 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -151,6 +151,7 @@ def __init__( skip_stat: bool = False, optim_update: bool = True, smooth_edge_update: bool = False, + use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, ) -> None: @@ -179,6 +180,7 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0d2ba060df..de7b25749d 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class): fix_stat_std=self.repflow_args.fix_stat_std, optim_update=self.repflow_args.optim_update, smooth_edge_update=self.repflow_args.smooth_edge_update, + use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, exclude_types=exclude_types, diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index dc7142249a..c57ae209fd 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -3,6 +3,7 @@ import torch from deepmd.pt.utils.preprocess import ( + compute_exp_sw, compute_smooth_weight, ) @@ -14,6 +15,7 @@ def _make_env_mat( ruct_smth: float, radial_only: bool = False, protection: float = 0.0, + use_exp_switch: bool = False, ): """Make smooth environment matrix.""" bsz, natoms, nnei = nlist.shape @@ -33,7 +35,11 @@ def _make_env_mat( length = length + ~mask.unsqueeze(-1) t0 = 1 / (length + protection) t1 = diff / (length + protection) ** 2 - weight = compute_smooth_weight(length, ruct_smth, rcut) + weight = ( + compute_smooth_weight(length, ruct_smth, rcut) + if not use_exp_switch + else compute_exp_sw(length, ruct_smth, rcut) + ) weight = weight * mask.unsqueeze(-1) if radial_only: env_mat = t0 * weight @@ -52,6 +58,7 @@ def prod_env_mat( rcut_smth: float, radial_only: bool = False, protection: float = 0.0, + use_exp_switch: bool = False, ): """Generate smooth environment matrix from atom coordinates and other context. @@ -64,6 +71,7 @@ def prod_env_mat( - rcut_smth: Smooth hyper-parameter for pair force & energy. - radial_only: Whether to return a full description or a radial-only descriptor. - protection: Protection parameter to prevent division by zero errors during calculations. + - use_exp_switch: Whether to use the exponential switch function. 
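# NOTE (reviewer sketch, not part of the patch): shape of the exponential
# switch added in deepmd/pt/utils/preprocess.py below. With C = 20,
# a = C / rmin and b = rmin, exp_sw(d) = exp(-exp(a * (d - b))) is ~1 well
# inside rmin, equals exp(-1) ~= 0.37 exactly at d = rmin, and decays
# double-exponentially beyond it.
import torch

from deepmd.pt.utils.preprocess import compute_exp_sw

d = torch.tensor([0.1, 0.5, 1.0, 2.0])
print(compute_exp_sw(d, rmin=1.0, rmax=6.0))
# ~[1.0000, 1.0000, 0.3679, 0.0000]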
 
     Returns
     -------
@@ -76,6 +84,7 @@
         rcut_smth,
         radial_only,
         protection=protection,
+        use_exp_switch=use_exp_switch,
     )  # shape [n_atom, dim, 4 or 1]
     t_avg = mean[atype]  # [n_atom, dim, 4 or 1]
     t_std = stddev[atype]  # [n_atom, dim, 4 or 1]
diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index 991413180a..c38a9c1e40 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -186,6 +186,7 @@ def __init__(
         precision: str = "float64",
         fix_stat_std: float = 0.3,
         smooth_edge_update: bool = False,
+        use_exp_switch: bool = False,
         use_dynamic_sel: bool = False,
         sel_reduce_factor: float = 10.0,
         optim_update: bool = True,
@@ -220,6 +221,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.use_exp_switch = use_exp_switch
         self.use_dynamic_sel = use_dynamic_sel
         self.sel_reduce_factor = sel_reduce_factor
 
@@ -404,6 +406,7 @@ def forward(
             self.e_rcut,
             self.e_rcut_smth,
             protection=self.env_protection,
+            use_exp_switch=self.use_exp_switch,
         )
         nlist_mask = nlist != -1
         sw = torch.squeeze(sw, -1)
@@ -425,6 +428,7 @@ def forward(
             self.a_rcut,
             self.a_rcut_smth,
             protection=self.env_protection,
+            use_exp_switch=self.use_exp_switch,
         )
         a_nlist_mask = a_nlist != -1
         a_sw = torch.squeeze(a_sw, -1)
diff --git a/deepmd/pt/utils/preprocess.py b/deepmd/pt/utils/preprocess.py
index 8ab489dede..7161bac692 100644
--- a/deepmd/pt/utils/preprocess.py
+++ b/deepmd/pt/utils/preprocess.py
@@ -15,3 +15,15 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
     uu2 = uu * uu
     vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1
     return vv
+
+
+def compute_exp_sw(distance, rmin: float, rmax: float):
+    """Compute the exponential switch function for neighbor update."""
+    if rmin >= rmax:
+        raise ValueError("rmin should be less than rmax.")
+    distance = torch.clamp(distance, min=0.0, max=rmax)
+    C = 20
+    a = C / rmin
+    b = rmin
+    exp_sw = torch.exp(-torch.exp(a * (distance - b)))
+    return exp_sw
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index c6fbf2513e..0982e9b634 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1597,6 +1597,13 @@ def dpa3_repflow_args():
             default=False,  # For compatability. This will be True in the future
             doc=doc_smooth_edge_update,
         ),
+        Argument(
+            "use_exp_switch",
+            bool,
+            optional=True,
+            default=False,
+            alias=["use_env_envelope"],
+        ),
         Argument(
             "use_dynamic_sel",
             bool,
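
Note on the switch added above: compute_exp_sw(r, rmin, rmax) evaluates exp(-exp(a * (r - rmin))) with a = 20 / rmin, so the envelope is ~1 well inside rmin, ~exp(-1) ≈ 0.37 at r = rmin, and underflows to 0 shortly beyond it; rmax enters only through the clamp. This contrasts with compute_smooth_weight, which stays at 1 up to rmin and decays polynomially to reach 0 exactly at rmax. A standalone sketch that prints both envelopes side by side (the rmin/rmax values are illustrative, not taken from any shipped config):

import torch

from deepmd.pt.utils.preprocess import compute_exp_sw, compute_smooth_weight

rmin, rmax = 0.5, 6.0  # illustrative rcut_smth / rcut
r = torch.linspace(0.0, rmax, 7)

poly = compute_smooth_weight(r, rmin, rmax)  # 1 below rmin, smooth decay to 0 at rmax
expo = compute_exp_sw(r, rmin, rmax)  # exp(-exp(40 * (r - 0.5))) for these values

for ri, p, e in zip(r.tolist(), poly.tolist(), expo.tolist()):
    print(f"r = {ri:4.2f}  poly = {p:9.3e}  exp_sw = {e:9.3e}")

Since repflows.forward passes (length, rcut_smth, rcut), enabling use_exp_switch means the neighbor weighting is effectively governed by rcut_smth alone rather than spread over the [rcut_smth, rcut] window.
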

From e9e39ad70f95a66f0e0f479a12e51796558a897f Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Tue, 27 May 2025 15:08:53 +0800
Subject: [PATCH 5/8] feat add dist edge

---
 deepmd/dpmodel/descriptor/dpa3.py      |  2 ++
 deepmd/pt/model/descriptor/dpa3.py     |  1 +
 deepmd/pt/model/descriptor/repflows.py | 11 ++++++++++-
 deepmd/utils/argcheck.py               |  7 +++++++
 4 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py
index 85b7980c23..7da88485a2 100644
--- a/deepmd/dpmodel/descriptor/dpa3.py
+++ b/deepmd/dpmodel/descriptor/dpa3.py
@@ -151,6 +151,7 @@ def __init__(
         skip_stat: bool = False,
         optim_update: bool = True,
         smooth_edge_update: bool = False,
+        edge_init_use_dist: bool = False,
         use_exp_switch: bool = False,
         use_dynamic_sel: bool = False,
         sel_reduce_factor: float = 10.0,
@@ -180,6 +181,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.edge_init_use_dist = edge_init_use_dist
         self.use_exp_switch = use_exp_switch
         self.use_dynamic_sel = use_dynamic_sel
         self.sel_reduce_factor = sel_reduce_factor
diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py
index de7b25749d..16e9022baf 100644
--- a/deepmd/pt/model/descriptor/dpa3.py
+++ b/deepmd/pt/model/descriptor/dpa3.py
@@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class):
             fix_stat_std=self.repflow_args.fix_stat_std,
             optim_update=self.repflow_args.optim_update,
             smooth_edge_update=self.repflow_args.smooth_edge_update,
+            edge_init_use_dist=self.repflow_args.edge_init_use_dist,
             use_exp_switch=self.repflow_args.use_exp_switch,
             use_dynamic_sel=self.repflow_args.use_dynamic_sel,
             sel_reduce_factor=self.repflow_args.sel_reduce_factor,
diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index c38a9c1e40..954bc787fd 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -186,6 +186,7 @@ def __init__(
         precision: str = "float64",
         fix_stat_std: float = 0.3,
         smooth_edge_update: bool = False,
+        edge_init_use_dist: bool = False,
         use_exp_switch: bool = False,
         use_dynamic_sel: bool = False,
         sel_reduce_factor: float = 10.0,
@@ -222,6 +223,7 @@ def __init__(
         self.a_compress_use_split = a_compress_use_split
         self.optim_update = optim_update
         self.smooth_edge_update = smooth_edge_update
+        self.edge_init_use_dist = edge_init_use_dist
         self.use_exp_switch = use_exp_switch
         self.use_dynamic_sel = use_dynamic_sel
         self.sel_reduce_factor = sel_reduce_factor
@@ -450,6 +452,10 @@ def forward(
         # get edge and angle embedding input
         # nb x nloc x nnei x 1, nb x nloc x nnei x 3
         edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
+        if self.edge_init_use_dist:
+            # nb x nloc x nnei x 1
+            edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
+
         # nf x nloc x a_nnei x 3
         normalized_diff_i = a_diff / (
             torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
@@ -486,7 +492,10 @@ def forward(
         )
         # get edge and angle embedding
         # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
-        edge_ebd = self.act(self.edge_embd(edge_input))
+        if not self.edge_init_use_dist:
+            edge_ebd = self.act(self.edge_embd(edge_input))
+        else:
+            edge_ebd = self.edge_embd(edge_input)
         # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
         angle_ebd = self.angle_embd(angle_input)
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 0982e9b634..8a43501220 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1597,6 +1597,13 @@ def dpa3_repflow_args():
             default=False,  # For compatability. This will be True in the future
             doc=doc_smooth_edge_update,
         ),
+        Argument(
+            "edge_init_use_dist",
+            bool,
+            optional=True,
+            default=False,
+            alias=["edge_use_dist"],
+        ),
         Argument(
             "use_exp_switch",
             bool,
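
Note on edge_init_use_dist: the forward change above swaps the initial edge feature from the first column of the normalized environment matrix (the 1/r-weighted switch value s(r)) to the raw distance |r_ij|, and in that case the output of the edge embedding is left unactivated. A minimal standalone sketch of the two branches with dummy shapes; the plain Linear and SiLU below merely stand in for the descriptor's learned edge_embd layer and its configured activation:

import torch

nb, nloc, nnei, e_dim = 2, 4, 8, 16
dmatrix = torch.randn(nb, nloc, nnei, 4)  # normalized env mat: (s, s*x, s*y, s*z)
diff = torch.randn(nb, nloc, nnei, 3)  # r_j - r_i for each neighbor pair
edge_embd = torch.nn.Linear(1, e_dim)  # stand-in for the learned edge embedding
act = torch.nn.SiLU()  # stand-in for the configured activation

edge_init_use_dist = True
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
if edge_init_use_dist:
    # raw |r_ij| replaces the 1/r-weighted, statistics-normalized s(r)
    edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
    edge_ebd = edge_embd(edge_input)  # linear init, no activation
else:
    edge_ebd = act(edge_embd(edge_input))  # original path
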

From 106c9730160ff2dff32a310cd40cdb92e6e485e3 Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Fri, 18 Jul 2025 23:52:48 +0800
Subject: [PATCH 6/8] Use insert_or_assign for send_list in comm_dict

Replaces comm_dict.insert with comm_dict.insert_or_assign for the
'send_list' key in both DeepPotPT.cc and DeepSpinPT.cc. This ensures
that the value is updated if the key already exists, preventing
potential issues with duplicate key insertion.
---
 source/api_cc/src/DeepPotPT.cc  | 2 +-
 source/api_cc/src/DeepSpinPT.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc
index 5f03f7c5cb..ee8df74273 100644
--- a/source/api_cc/src/DeepPotPT.cc
+++ b/source/api_cc/src/DeepPotPT.cc
@@ -197,7 +197,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
         std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
     torch::Tensor sendlist_tensor =
         torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
-    comm_dict.insert("send_list", sendlist_tensor);
+    comm_dict.insert_or_assign("send_list", sendlist_tensor);
     comm_dict.insert("send_proc", sendproc_tensor);
     comm_dict.insert("recv_proc", recvproc_tensor);
     comm_dict.insert("send_num", sendnum_tensor);
diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc
index aa7bcd2657..c7784ba638 100644
--- a/source/api_cc/src/DeepSpinPT.cc
+++ b/source/api_cc/src/DeepSpinPT.cc
@@ -205,7 +205,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
     torch::Tensor sendlist_tensor =
         torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
     torch::Tensor has_spin = torch::tensor({1}, int32_option);
-    comm_dict.insert("send_list", sendlist_tensor);
+    comm_dict.insert_or_assign("send_list", sendlist_tensor);
     comm_dict.insert("send_proc", sendproc_tensor);
     comm_dict.insert("recv_proc", recvproc_tensor);
     comm_dict.insert("send_num", sendnum_tensor);

From 8fdc524402493075bbee930b7abd565a487ad06b Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Sun, 20 Jul 2025 13:56:27 +0800
Subject: [PATCH 7/8] Use insert_or_assign for comm_dict tensor assignments

Replaces comm_dict.insert with comm_dict.insert_or_assign for all
tensor assignments in DeepPotPT.cc and DeepSpinPT.cc. This ensures
that existing keys are updated rather than causing errors or
duplications, improving robustness when keys may already exist.
---
 source/api_cc/src/DeepPotPT.cc  | 10 +++++-----
 source/api_cc/src/DeepSpinPT.cc | 12 ++++++------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc
index ee8df74273..0f3a72b87f 100644
--- a/source/api_cc/src/DeepPotPT.cc
+++ b/source/api_cc/src/DeepPotPT.cc
@@ -198,11 +198,11 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
     torch::Tensor sendlist_tensor =
         torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
     comm_dict.insert_or_assign("send_list", sendlist_tensor);
-    comm_dict.insert("send_proc", sendproc_tensor);
-    comm_dict.insert("recv_proc", recvproc_tensor);
-    comm_dict.insert("send_num", sendnum_tensor);
-    comm_dict.insert("recv_num", recvnum_tensor);
-    comm_dict.insert("communicator", communicator_tensor);
+    comm_dict.insert_or_assign("send_proc", sendproc_tensor);
+    comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
+    comm_dict.insert_or_assign("send_num", sendnum_tensor);
+    comm_dict.insert_or_assign("recv_num", recvnum_tensor);
+    comm_dict.insert_or_assign("communicator", communicator_tensor);
   }
   if (lmp_list.mapping) {
     std::vector<int> mapping(nall_real);
diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc
index c7784ba638..d1a455e3f5 100644
--- a/source/api_cc/src/DeepSpinPT.cc
+++ b/source/api_cc/src/DeepSpinPT.cc
@@ -206,12 +206,12 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
         torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
     torch::Tensor has_spin = torch::tensor({1}, int32_option);
     comm_dict.insert_or_assign("send_list", sendlist_tensor);
-    comm_dict.insert("send_proc", sendproc_tensor);
-    comm_dict.insert("recv_proc", recvproc_tensor);
-    comm_dict.insert("send_num", sendnum_tensor);
-    comm_dict.insert("recv_num", recvnum_tensor);
-    comm_dict.insert("communicator", communicator_tensor);
-    comm_dict.insert("has_spin", has_spin);
+    comm_dict.insert_or_assign("send_proc", sendproc_tensor);
+    comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
+    comm_dict.insert_or_assign("send_num", sendnum_tensor);
+    comm_dict.insert_or_assign("recv_num", recvnum_tensor);
+    comm_dict.insert_or_assign("communicator", communicator_tensor);
+    comm_dict.insert_or_assign("has_spin", has_spin);
   }
 }
 at::Tensor firstneigh = createNlistTensor2(nlist_data.jlist);
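
Background for the two commits above: c10::Dict follows the std::map convention in which insert() does not overwrite an existing key, so reusing a comm_dict across compute() calls can leave stale tensors behind or trip duplicate-key handling, whereas insert_or_assign() always stores the latest value. A rough Python analogy of those semantics (this is not the real c10::Dict interface):

class MapLikeDict(dict):
    def insert(self, key, value):
        """Insert only if the key is absent; keep the old value otherwise."""
        if key in self:
            return False  # old value silently kept -- stale data
        self[key] = value
        return True

    def insert_or_assign(self, key, value):
        """Always leave the dict holding the latest value."""
        existed = key in self
        self[key] = value
        return not existed


comm_dict = MapLikeDict()
comm_dict.insert("send_list", "tensor@step1")
comm_dict.insert("send_list", "tensor@step2")  # no-op: step1 value kept
comm_dict.insert_or_assign("send_list", "tensor@step2")  # updated as intended
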

From 6371ea6c7e5c04512c795b1263691162154f5a08 Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Thu, 4 Sep 2025 10:12:51 +0800
Subject: [PATCH 8/8] Update repformers.py

---
 deepmd/pt/model/descriptor/repformers.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py
index 82773d1a78..5aeb4abe09 100644
--- a/deepmd/pt/model/descriptor/repformers.py
+++ b/deepmd/pt/model/descriptor/repformers.py
@@ -394,7 +394,6 @@ def forward(
     ):
         if comm_dict is None:
             assert mapping is not None
-            assert extended_atype_embd is not None
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         atype = extended_atype[:, :nloc]
@@ -418,13 +417,9 @@ def forward(
         sw = sw.masked_fill(~nlist_mask, 0.0)
 
         # [nframes, nloc, tebd_dim]
-        if comm_dict is None:
-            assert isinstance(extended_atype_embd, torch.Tensor)  # for jit
-            atype_embd = extended_atype_embd[:, :nloc, :]
-            assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
-        else:
-            atype_embd = extended_atype_embd
-        assert isinstance(atype_embd, torch.Tensor)  # for jit
+        assert extended_atype_embd is not None
+        atype_embd = extended_atype_embd[:, :nloc, :]
+        assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
         g1 = self.act(atype_embd)
         ng1 = g1.shape[-1]
         # nb x nloc x nnei x 1, nb x nloc x nnei x 3
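
For reference, the new repflow options introduced in patches 4 and 5 become user-facing through the argcheck entries above. A hypothetical descriptor section that would enable both, written as the Python dict deepmd configs deserialize into (key names follow the argcheck entries; every numeric value is illustrative only):

descriptor = {
    "type": "dpa3",
    "repflow": {
        "n_dim": 128,
        "e_dim": 64,
        "a_dim": 32,
        "nlayers": 6,
        "e_rcut": 6.0,
        "e_rcut_smth": 5.3,
        "e_sel": 120,
        "a_rcut": 4.0,
        "a_rcut_smth": 3.5,
        "a_sel": 20,
        "use_exp_switch": True,  # exponential envelope (alias: use_env_envelope)
        "edge_init_use_dist": True,  # init edges from |r_ij| (alias: edge_use_dist)
    },
}
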