From 38282175294ac795ac3061c3764c955dc3983b2a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 26 Nov 2025 15:10:37 +0800 Subject: [PATCH 1/4] feat(pt): type embedding can still be compress even if attn_layer != 0 --- deepmd/pt/model/descriptor/dpa1.py | 73 +++--- deepmd/pt/model/descriptor/dpa2.py | 54 ++--- deepmd/pt/model/descriptor/se_atten.py | 17 +- deepmd/utils/argcheck.py | 2 +- doc/model/train-se-atten.md | 4 +- .../pt/test_model_compression_se_atten.py | 220 +++++++++++++++++- 6 files changed, 301 insertions(+), 69 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 78a277881c..264f13fdd3 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import warnings from typing import ( Any, Callable, @@ -304,7 +305,8 @@ def __init__( self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias self.type_map = type_map - self.compress = False + self.tebd_compress = False + self.geo_compress = False self.type_embedding = TypeEmbedNet( ntypes, tebd_dim, @@ -592,12 +594,18 @@ def enable_compression( check_frequency The overflow check frequency """ - # do some checks before the mocel compression process - if self.compress: + # do some checks before the model compression process + if self.tebd_compress or self.geo_compress: raise ValueError("Compression is already enabled.") + + assert self.tebd_input_mode != "strip", ( + "Cannot compress model when tebd_input_mode == 'strip'" + ) + assert not self.se_atten.resnet_dt, ( "Model compression error: descriptor resnet_dt must be false!" ) + for tt in self.se_atten.exclude_types: if (tt[0] not in range(self.se_atten.ntypes)) or ( tt[1] not in range(self.se_atten.ntypes) @@ -609,6 +617,7 @@ def enable_compression( + str(self.se_atten.ntypes) + "!" ) + if ( self.se_atten.ntypes * self.se_atten.ntypes - len(self.se_atten.exclude_types) @@ -618,38 +627,38 @@ def enable_compression( "Empty embedding-nets are not supported in model compression!" ) - if self.se_atten.attn_layer != 0: - raise RuntimeError("Cannot compress model when attention layer is not 0.") - - if self.tebd_input_mode != "strip": - raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'") - - data = self.serialize() - self.table = DPTabulate( - self, - data["neuron"], - data["type_one_side"], - data["exclude_types"], - ActivationFn(data["activation_function"]), - ) - self.table_config = [ - table_extrapolate, - table_stride_1, - table_stride_2, - check_frequency, - ] - self.lower, self.upper = self.table.build( - min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 - ) - - self.se_atten.enable_compression( - self.table.data, self.table_config, self.lower, self.upper - ) - # Enable type embedding compression self.se_atten.type_embedding_compression(self.type_embedding) + self.tebd_compress = True + + if self.se_atten.attn_layer == 0: + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) - self.compress = True + self.se_atten.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.geo_compress = True + else: + warnings.warn( + "Attention layer is not 0, only type embedding is compressed. Geometric part is not compressed.", + UserWarning, + ) def forward( self, diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 8985a92196..48dd09fdbf 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import warnings from typing import ( Any, Callable, @@ -938,38 +939,39 @@ def enable_compression( "Repinit empty embedding-nets are not supported in model compression!" ) - if self.repinit.attn_layer != 0: - raise RuntimeError( - "Cannot compress model when repinit attention layer is not 0." - ) - if self.repinit.tebd_input_mode != "strip": raise RuntimeError( "Cannot compress model when repinit tebd_input_mode == 'concat'" ) - # repinit doesn't have a serialize method - data = self.serialize() - self.table = DPTabulate( - self, - data["repinit_args"]["neuron"], - data["repinit_args"]["type_one_side"], - data["exclude_types"], - ActivationFn(data["repinit_args"]["activation_function"]), - ) - self.table_config = [ - table_extrapolate, - table_stride_1, - table_stride_2, - check_frequency, - ] - self.lower, self.upper = self.table.build( - min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 - ) + if self.repinit.attn_layer == 0: + # repinit doesn't have a serialize method + data = self.serialize() + self.table = DPTabulate( + self, + data["repinit_args"]["neuron"], + data["repinit_args"]["type_one_side"], + data["exclude_types"], + ActivationFn(data["repinit_args"]["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) - self.repinit.enable_compression( - self.table.data, self.table_config, self.lower, self.upper - ) + self.repinit.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + else: + warnings.warn( + "Attention layer is not 0, only type embedding is compressed. Geometric part is not compressed.", + UserWarning, + ) # Enable type embedding compression for repinit (se_atten) self.repinit.type_embedding_compression(self.type_embedding) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 30d6024e60..7c1de6146a 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -275,9 +275,10 @@ def __init__( self.filter_layers_strip = filter_layers_strip self.stats = None - # For geometric compression - self.compress = False + self.tebd_compress = False + self.geo_compress = False self.is_sorted = False + # For geometric compression self.compress_info = nn.ParameterList( [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))] ) @@ -452,7 +453,7 @@ def enable_compression( device="cpu", ) self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) - self.compress = True + self.geo_compress = True def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: """Enable type embedding compression for strip mode. @@ -504,6 +505,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: del self.type_embd_data self.register_buffer("type_embd_data", embd_tensor) + self.tebd_compress = True + def forward( self, nlist: torch.Tensor, @@ -630,7 +633,7 @@ def forward( # nf x (nl x nnei) nei_type = torch.gather(extended_atype, dim=1, index=nlist_index) if self.type_one_side: - if self.compress: + if self.tebd_compress: tt_full = self.type_embd_data else: # (ntypes+1, tebd_dim) -> (ntypes+1, ng) @@ -644,7 +647,7 @@ def forward( idx_j = nei_type.view(-1) # (nf x nl x nnei) idx = (idx_i + idx_j).to(torch.long) - if self.compress: + if self.tebd_compress: # ((ntypes+1)^2, ng) tt_full = self.type_embd_data else: @@ -671,7 +674,7 @@ def forward( gg_t = gg_t.reshape(nfnl, nnei, ng) if self.smooth: gg_t = gg_t * sw.reshape(-1, self.nnei, 1) - if self.compress: + if self.geo_compress: ss = ss.reshape(-1, 1) gg_t = gg_t.reshape(-1, gg_t.size(-1)) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( @@ -719,7 +722,7 @@ def forward( return ( result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron), gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]) - if not self.compress + if not self.geo_compress else None, dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:], rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3), diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0f7acb4266..a58180051e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -505,7 +505,7 @@ def descrpt_se_atten_common_args() -> list[Argument]: doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_attn = "The length of hidden vectors in attention layers" - doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'" + doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` works for any attn_layer value when tebd_input_mode=='strip'. When attn_layer!=0, only type embedding is compressed, geometric parts are not compressed." doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix" diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 6c0ca0817c..ad476caa67 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -134,7 +134,7 @@ You can use descriptor `"se_atten_v2"` and is not allowed to set `tebd_input_mod Practical evidence demonstrates that `"se_atten_v2"` offers better and more stable performance compared to `"se_atten"`. -Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer ` set to 0. +Notice: Model compression for the `se_atten_v2` descriptor is designed for models with any `attn_layer` value. When `attn_layer` is not 0, only type embedding will be compressed, while the geometric part will not be compressed. ## Type embedding @@ -182,7 +182,7 @@ DPA-1 supports both the [standard data format](../data/system.md) and the [mixed ## Model compression -Model compression is supported only when there is no attention layer (`attn_layer` is 0) and `tebd_input_mode` is `strip`. +Model compression is supported for any `attn_layer` value when `tebd_input_mode` is `strip`. When `attn_layer` is not 0, only type embedding will be compressed, while the geometric part will not be compressed. ## Training example diff --git a/source/tests/pt/test_model_compression_se_atten.py b/source/tests/pt/test_model_compression_se_atten.py index 034d847656..41818f81d7 100644 --- a/source/tests/pt/test_model_compression_se_atten.py +++ b/source/tests/pt/test_model_compression_se_atten.py @@ -39,7 +39,7 @@ def _init_models(): INPUT = str(tests_path / "input.json") jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) - # Configure se_atten descriptor with attn_layer=0 for compression compatibility + # Configure se_atten descriptor with attn_layer=0 for full compression compatibility (both type embedding and geometric parts) jdata["model"]["descriptor"] = { "type": "se_atten_v2", "sel": 120, @@ -123,6 +123,54 @@ def _init_models_exclude_types(): return INPUT, frozen_model, compressed_model +def _init_models_nonzero_attn_layer(): + """Initialize models with attn_layer > 0 for partial compression testing.""" + suffix = "-nonzero-attn" + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / f"dp-original-se-atten{suffix}.pth") + compressed_model = str(tests_path / f"dp-compressed-se-atten{suffix}.pth") + INPUT = str(tests_path / f"input{suffix}.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + + # Configure se_atten descriptor with attn_layer=2 for partial compression + # Only type embedding will be compressed, geometric parts (attention layers) will not + jdata["model"]["descriptor"] = { + "type": "se_atten_v2", + "sel": 120, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + "attn": 128, + "attn_layer": 2, # Non-zero attention layer for partial compression testing + "attn_dotr": True, + "attn_mask": False, + "precision": "float64", + } + + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT) + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + + " -i " + + frozen_model + + " -o " + + compressed_model + + " -t " + + INPUT + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + def _init_models_skip_neighbor_stat(): suffix = "-skip-neighbor-stat" data_file = str(tests_path / os.path.join("model_compression", "data")) @@ -177,6 +225,9 @@ def setUpModule() -> None: INPUT_ET, \ FROZEN_MODEL_ET, \ COMPRESSED_MODEL_ET, \ + INPUT_NONZERO_ATTN, \ + FROZEN_MODEL_NONZERO_ATTN, \ + COMPRESSED_MODEL_NONZERO_ATTN, \ FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \ COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() @@ -184,6 +235,9 @@ def setUpModule() -> None: _init_models_skip_neighbor_stat() ) INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() + INPUT_NONZERO_ATTN, FROZEN_MODEL_NONZERO_ATTN, COMPRESSED_MODEL_NONZERO_ATTN = ( + _init_models_nonzero_attn_layer() + ) def tearDownModule() -> None: @@ -198,6 +252,10 @@ def tearDownModule() -> None: _file_delete(INPUT_ET) _file_delete(FROZEN_MODEL_ET) _file_delete(COMPRESSED_MODEL_ET) + # Clean up files created by _init_models_nonzero_attn_layer + _file_delete(INPUT_NONZERO_ATTN) + _file_delete(FROZEN_MODEL_NONZERO_ATTN) + _file_delete(COMPRESSED_MODEL_NONZERO_ATTN) # Clean up other artifacts _file_delete("out.json") _file_delete("input_v2_compat.json") @@ -797,5 +855,165 @@ def test_2frame_atm(self) -> None: np.testing.assert_almost_equal(vv0, vv1, default_places) +class TestDeepPotATNonZeroAttnLayer(unittest.TestCase): + """Test model compression with attn_layer > 0 (partial compression).""" + + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL_NONZERO_ATTN) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL_NONZERO_ATTN) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_attrs(self) -> None: + """Test model attributes are consistent between original and compressed models.""" + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self) -> None: + """Test single frame evaluation with partial compression.""" + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values - should be identical even with partial compression + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + """Test single frame atomic evaluation with partial compression.""" + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values - should be identical even with partial compression + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self) -> None: + """Test multi-frame atomic evaluation with partial compression.""" + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values - should be identical even with partial compression + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_ase(self) -> None: + """Test ASE calculator integration with partial compression.""" + from ase import ( + Atoms, + ) + + from deepmd.calculator import ( + DP, + ) + + water0 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(FROZEN_MODEL_NONZERO_ATTN), + ) + water1 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(COMPRESSED_MODEL_NONZERO_ATTN), + ) + ee0 = water0.get_potential_energy() + ff0 = water0.get_forces() + ee1 = water1.get_potential_energy() + ff1 = water1.get_forces() + # nframes = 1 + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + + if __name__ == "__main__": unittest.main() From afe5dd4636dea8d3b29924e8582b60ca4f2ed832 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 27 Nov 2025 10:21:46 +0800 Subject: [PATCH 2/4] bug fix --- deepmd/pt/model/descriptor/dpa1.py | 5 ++--- deepmd/pt/model/descriptor/dpa2.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 264f13fdd3..7f600ccc2e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -598,9 +598,8 @@ def enable_compression( if self.tebd_compress or self.geo_compress: raise ValueError("Compression is already enabled.") - assert self.tebd_input_mode != "strip", ( - "Cannot compress model when tebd_input_mode == 'strip'" - ) + if self.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") assert not self.se_atten.resnet_dt, ( "Model compression error: descriptor resnet_dt must be false!" diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 48dd09fdbf..583f18f2be 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -941,7 +941,7 @@ def enable_compression( if self.repinit.tebd_input_mode != "strip": raise RuntimeError( - "Cannot compress model when repinit tebd_input_mode == 'concat'" + "Cannot compress model when repinit tebd_input_mode != 'strip'" ) if self.repinit.attn_layer == 0: From 4206f3cb86376ff766d807d807d4b69197f13788 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 28 Nov 2025 10:10:40 +0800 Subject: [PATCH 3/4] doc --- doc/model/dpa2.md | 6 +++++- doc/model/train-se-atten.md | 12 ++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index c8e60c514a..466a4de4f2 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -38,6 +38,10 @@ Type embedding is within this descriptor with the {ref}`tebd_dim ` is `strip`, but only the `repinit` part is compressed. +Model compression is supported when {ref}`repinit/tebd_input_mode ` is `strip`. + +- If {ref}`repinit/attn_layer ` is `0`, both the type embedding and geometric parts inside `repinit` are compressed. +- If `repinit/attn_layer` is not `0`, only the type embedding tables are compressed and the geometric attention layers remain as neural networks. + An example is given in `examples/water/dpa2/input_torch_compressible.json`. The performance improvement will be limited if other parts are more expensive. diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index ad476caa67..91a9d4932b 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -134,7 +134,9 @@ You can use descriptor `"se_atten_v2"` and is not allowed to set `tebd_input_mod Practical evidence demonstrates that `"se_atten_v2"` offers better and more stable performance compared to `"se_atten"`. -Notice: Model compression for the `se_atten_v2` descriptor is designed for models with any `attn_layer` value. When `attn_layer` is not 0, only type embedding will be compressed, while the geometric part will not be compressed. +:::{note} +Model compression support differs across backends. See [Model compression](#model-compression) for backend-specific requirements. +::: ## Type embedding @@ -182,7 +184,13 @@ DPA-1 supports both the [standard data format](../data/system.md) and the [mixed ## Model compression -Model compression is supported for any `attn_layer` value when `tebd_input_mode` is `strip`. When `attn_layer` is not 0, only type embedding will be compressed, while the geometric part will not be compressed. +### TensorFlow {{ tensorflow_icon }} + +Model compression is supported only when the descriptor attention depth {ref}`attn_layer ` is 0 and {ref}`tebd_input_mode ` is `"strip"`. Attention layers higher than 0 cannot be compressed in the TensorFlow implementation because the geometric part is tabulated from the static computation graph. + +### PyTorch {{ pytorch_icon }} {{ jax_icon }} {{ paddle_icon }} {{ dpmodel_icon }} + +Model compression is supported for any {ref}`attn_layer ` value when {ref}`tebd_input_mode ` is `"strip"`. When `attn_layer` is 0, both the type embedding and geometric parts are compressed. When `attn_layer` is not 0, only the type embedding is compressed while the geometric part keeps the neural network implementation (a warning is emitted during compression). ## Training example From 7026d9c297d58beb2860f1e12d8ec1dba795b3d4 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 30 Nov 2025 12:23:38 +0800 Subject: [PATCH 4/4] doc --- deepmd/utils/argcheck.py | 2 +- doc/model/train-se-atten.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a58180051e..5878ea473d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -505,7 +505,7 @@ def descrpt_se_atten_common_args() -> list[Argument]: doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_attn = "The length of hidden vectors in attention layers" - doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` works for any attn_layer value when tebd_input_mode=='strip'. When attn_layer!=0, only type embedding is compressed, geometric parts are not compressed." + doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` works for any attn_layer value (for pytorch backend only, for other backends, attn_layer=0 is still needed to compress) when tebd_input_mode=='strip'. When attn_layer!=0, only type embedding is compressed, geometric parts are not compressed." doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix" diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 91a9d4932b..2e0c236cf6 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -188,7 +188,7 @@ DPA-1 supports both the [standard data format](../data/system.md) and the [mixed Model compression is supported only when the descriptor attention depth {ref}`attn_layer ` is 0 and {ref}`tebd_input_mode ` is `"strip"`. Attention layers higher than 0 cannot be compressed in the TensorFlow implementation because the geometric part is tabulated from the static computation graph. -### PyTorch {{ pytorch_icon }} {{ jax_icon }} {{ paddle_icon }} {{ dpmodel_icon }} +### PyTorch {{ pytorch_icon }} Model compression is supported for any {ref}`attn_layer ` value when {ref}`tebd_input_mode ` is `"strip"`. When `attn_layer` is 0, both the type embedding and geometric parts are compressed. When `attn_layer` is not 0, only the type embedding is compressed while the geometric part keeps the neural network implementation (a warning is emitted during compression).