From 71d1aead75f4bdc53480811593ee472435cb8fc7 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 18 Nov 2025 20:56:26 +0800 Subject: [PATCH 01/14] feat(pt): Implement type embedding compression for two-side mode for se_atten - Added functionality to enable type embedding compression in the class, specifically for two-side mode. - Introduced a new method to precompute type embedding network outputs for all type pair combinations, optimizing inference by using precomputed values. - Updated the method to utilize precomputed embeddings when compression is enabled, improving performance during inference. This enhancement allows for more efficient handling of type embeddings in the descriptor model. --- deepmd/pt/model/descriptor/dpa1.py | 6 +++ deepmd/pt/model/descriptor/se_atten.py | 57 +++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index e158dd3725..fb53746410 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -645,6 +645,12 @@ def enable_compression( self.se_atten.enable_compression( self.table.data, self.table_config, self.lower, self.upper ) + + # Enable type embedding compression only for two-side mode + # TODO: why not enable for one-side mode? (do not consider this for now) + if not self.se_atten.type_one_side: + self.se_atten.enable_type_embedding_compression(self.type_embedding) + self.compress = True def forward( diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index bfcb510810..08101d6bcf 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -27,6 +27,9 @@ MLPLayer, NetworkCollection, ) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) from deepmd.pt.utils import ( env, ) @@ -281,6 +284,9 @@ def __init__( self.compress_data = nn.ParameterList( [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] ) + # For type embedding compression (strip mode, two-side only) + self.compress_type_embd = False + self.two_side_embd_data = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -447,6 +453,50 @@ def enable_compression( self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) self.compress = True + def enable_type_embedding_compression( + self, type_embedding_net: TypeEmbedNet + ) -> None: + """Enable type embedding compression for strip mode (two-side only). + + This method precomputes the type embedding network outputs for all possible + type pairs, following the same approach as TF backend's compression: + + TF approach: + 1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations + 2. make_data(): applies embedding network to get precomputed outputs + 3. In forward: lookup precomputed values instead of real-time computation + + PyTorch implementation: + - Precomputes all (ntypes+1)^2 type pair embedding network outputs + - Stores in buffer for proper serialization and device management + - Uses lookup during inference to avoid redundant computations + + Parameters + ---------- + type_embedding_net : TypeEmbedNet + The type embedding network that provides get_full_embedding() method + """ + with torch.no_grad(): + # Get full type embedding: (ntypes+1) x t_dim + full_embd = type_embedding_net.get_full_embedding(env.DEVICE) + nt, t_dim = full_embd.shape + + # Create all type pair combinations [neighbor, center] + # for a fixed row i, all columns j have different neighbor types + embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim) + # for a fixed row i, all columns j share the same center type i + embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim) + two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape( + -1, t_dim * 2 + ) + + # Apply strip embedding network and store + # index logic: index = center_type * nt + neighbor_type + self.two_side_embd_data = self.filter_layers_strip.networks[0]( + two_side_embd + ).detach() + self.compress_type_embd = True + def forward( self, nlist: torch.Tensor, @@ -605,7 +655,12 @@ def forward( two_side_type_embedding = torch.cat( [type_embedding_nei, type_embedding_center], -1 ).reshape(-1, nt * 2) - tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) + if self.compress_type_embd and self.two_side_embd_data is not None: + tt_full = self.two_side_embd_data + else: + tt_full = self.filter_layers_strip.networks[0]( + two_side_type_embedding + ) # (nf x nl x nnei) x ng gg_t = torch.gather(tt_full, dim=0, index=idx) # (nf x nl) x nnei x ng From 14169fb8e8936444e97ef534021c30614c5ba49e Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 18 Nov 2025 21:12:55 +0800 Subject: [PATCH 02/14] refactor(pt): Simplify index calculation and enhance type embedding handling in se_atten - Streamlined the index calculation for neighbor types by removing unnecessary reshaping and expansion. - Improved clarity in the handling of type embeddings, ensuring that the logic for two-side embeddings remains intact while enhancing readability. - Maintained functionality for both compressed and uncompressed type embeddings, optimizing the inference process. These changes contribute to cleaner code and maintain the performance benefits introduced in previous commits. --- deepmd/pt/model/descriptor/se_atten.py | 41 +++++++++++--------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 08101d6bcf..2879e71a95 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -633,36 +633,31 @@ def forward( atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei] ).view(-1) idx_j = nei_type.view(-1) - # (nf x nl x nnei) x ng - idx = ( - (idx_i + idx_j) - .view(-1, 1) - .expand(-1, ng) - .type(torch.long) - .to(torch.long) - ) - # (ntypes) * ntypes * nt - type_embedding_nei = torch.tile( - type_embedding.view(1, ntypes_with_padding, nt), - [ntypes_with_padding, 1, 1], - ) - # ntypes * (ntypes) * nt - type_embedding_center = torch.tile( - type_embedding.view(ntypes_with_padding, 1, nt), - [1, ntypes_with_padding, 1], - ) - # (ntypes * ntypes) * (nt+nt) - two_side_type_embedding = torch.cat( - [type_embedding_nei, type_embedding_center], -1 - ).reshape(-1, nt * 2) + # (nf x nl x nnei) + idx = (idx_i + idx_j).to(torch.long) if self.compress_type_embd and self.two_side_embd_data is not None: + # (ntypes^2, ng) tt_full = self.two_side_embd_data else: + # (ntypes) * ntypes * nt + type_embedding_nei = torch.tile( + type_embedding.view(1, ntypes_with_padding, nt), + [ntypes_with_padding, 1, 1], + ) + # ntypes * (ntypes) * nt + type_embedding_center = torch.tile( + type_embedding.view(ntypes_with_padding, 1, nt), + [1, ntypes_with_padding, 1], + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = torch.cat( + [type_embedding_nei, type_embedding_center], -1 + ).reshape(-1, nt * 2) tt_full = self.filter_layers_strip.networks[0]( two_side_type_embedding ) # (nf x nl x nnei) x ng - gg_t = torch.gather(tt_full, dim=0, index=idx) + gg_t = tt_full[idx] # (nf x nl) x nnei x ng gg_t = gg_t.reshape(nfnl, nnei, ng) if self.smooth: From 85279758e2204c5bd3d8420212df32bf30c92b12 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 18 Nov 2025 22:43:27 +0800 Subject: [PATCH 03/14] refactor(pt): Streamline type embedding compression logic in dpa1 and se_atten - Simplified the type embedding compression process by consolidating methods and removing unnecessary conditions. - Enhanced clarity in the handling of one-side and two-side type embeddings, ensuring consistent functionality across both modes. - Updated comments for better understanding of the compression logic and its implications on performance. These changes contribute to cleaner code and improved maintainability of the descriptor model. --- deepmd/pt/model/descriptor/dpa1.py | 6 +- deepmd/pt/model/descriptor/se_atten.py | 82 ++++++++++++-------------- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index fb53746410..78a277881c 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -646,10 +646,8 @@ def enable_compression( self.table.data, self.table_config, self.lower, self.upper ) - # Enable type embedding compression only for two-side mode - # TODO: why not enable for one-side mode? (do not consider this for now) - if not self.se_atten.type_one_side: - self.se_atten.enable_type_embedding_compression(self.type_embedding) + # Enable type embedding compression + self.se_atten.type_embedding_compression(self.type_embedding) self.compress = True diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 2879e71a95..f19c319457 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -275,7 +275,7 @@ def __init__( self.filter_layers_strip = filter_layers_strip self.stats = None - # add for compression + # For geometric compression self.compress = False self.is_sorted = False self.compress_info = nn.ParameterList( @@ -284,9 +284,8 @@ def __init__( self.compress_data = nn.ParameterList( [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] ) - # For type embedding compression (strip mode, two-side only) - self.compress_type_embd = False - self.two_side_embd_data = None + # For type embedding compression + self.type_embd_data = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -453,23 +452,12 @@ def enable_compression( self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) self.compress = True - def enable_type_embedding_compression( - self, type_embedding_net: TypeEmbedNet - ) -> None: - """Enable type embedding compression for strip mode (two-side only). - - This method precomputes the type embedding network outputs for all possible - type pairs, following the same approach as TF backend's compression: - - TF approach: - 1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations - 2. make_data(): applies embedding network to get precomputed outputs - 3. In forward: lookup precomputed values instead of real-time computation + def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: + """Enable type embedding compression for strip mode. - PyTorch implementation: - - Precomputes all (ntypes+1)^2 type pair embedding network outputs - - Stores in buffer for proper serialization and device management - - Uses lookup during inference to avoid redundant computations + Precomputes embedding network outputs for all type combinations: + - One-side: (ntypes+1) combinations (neighbor types only) + - Two-side: (ntypes+1)² combinations (neighbor x center type pairs) Parameters ---------- @@ -477,25 +465,31 @@ def enable_type_embedding_compression( The type embedding network that provides get_full_embedding() method """ with torch.no_grad(): - # Get full type embedding: (ntypes+1) x t_dim + # Get full type embedding: (ntypes+1) x tebd_dim full_embd = type_embedding_net.get_full_embedding(env.DEVICE) nt, t_dim = full_embd.shape - # Create all type pair combinations [neighbor, center] - # for a fixed row i, all columns j have different neighbor types - embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim) - # for a fixed row i, all columns j share the same center type i - embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim) - two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape( - -1, t_dim * 2 - ) - - # Apply strip embedding network and store - # index logic: index = center_type * nt + neighbor_type - self.two_side_embd_data = self.filter_layers_strip.networks[0]( - two_side_embd - ).detach() - self.compress_type_embd = True + if self.type_one_side: + # One-side: only neighbor types, much simpler! + # Precompute for all (ntypes+1) neighbor types + self.type_embd_data = self.filter_layers_strip.networks[0]( + full_embd + ).detach() + else: + # Two-side: all (ntypes+1)² type pair combinations + # Create [neighbor, center] combinations + # for a fixed row i, all columns j have different neighbor types + embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim) + # for a fixed row i, all columns j share the same center type i + embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim) + two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape( + -1, t_dim * 2 + ) + # Precompute for all type pairs + # Index formula: idx = center_type * nt + neighbor_type + self.type_embd_data = self.filter_layers_strip.networks[0]( + two_side_embd + ).detach() def forward( self, @@ -622,12 +616,14 @@ def forward( nlist_index = nlist.reshape(nb, nloc * nnei) # nf x (nl x nnei) nei_type = torch.gather(extended_atype, dim=1, index=nlist_index) - # (nf x nl x nnei) x ng - nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long) if self.type_one_side: - tt_full = self.filter_layers_strip.networks[0](type_embedding) - # (nf x nl x nnei) x ng - gg_t = torch.gather(tt_full, dim=0, index=nei_type_index) + if self.type_embd_data is not None: + tt_full = self.type_embd_data + else: + # (ntypes+1, tebd_dim) -> (ntypes+1, ng) + tt_full = self.filter_layers_strip.networks[0](type_embedding) + # (nf*nl*nnei,) -> (nf*nl*nnei, ng) + gg_t = tt_full[nei_type.view(-1).type(torch.long)] else: idx_i = torch.tile( atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei] @@ -635,9 +631,9 @@ def forward( idx_j = nei_type.view(-1) # (nf x nl x nnei) idx = (idx_i + idx_j).to(torch.long) - if self.compress_type_embd and self.two_side_embd_data is not None: + if self.type_embd_data is not None: # (ntypes^2, ng) - tt_full = self.two_side_embd_data + tt_full = self.type_embd_data else: # (ntypes) * ntypes * nt type_embedding_nei = torch.tile( From 4f72994b7f7cedfaa5e336c86decd3a42e0c66bf Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 18 Nov 2025 23:16:35 +0800 Subject: [PATCH 04/14] fix(pt): Enforce conditions for type embedding compression in se_atten - Added runtime checks to ensure type embedding compression only operates in "strip" mode. - Introduced validation for the initialization of `filter_layers_strip` to prevent runtime errors. - Updated comments to clarify the expected dimensions of type embeddings, enhancing code readability. These changes improve error handling and maintain the integrity of the type embedding compression logic. --- deepmd/pt/model/descriptor/se_atten.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index f19c319457..28c060ea4f 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -464,6 +464,13 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: type_embedding_net : TypeEmbedNet The type embedding network that provides get_full_embedding() method """ + if self.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.filter_layers_strip is None: + raise RuntimeError( + "filter_layers_strip must be initialized for type embedding compression" + ) + with torch.no_grad(): # Get full type embedding: (ntypes+1) x tebd_dim full_embd = type_embedding_net.get_full_embedding(env.DEVICE) @@ -632,20 +639,20 @@ def forward( # (nf x nl x nnei) idx = (idx_i + idx_j).to(torch.long) if self.type_embd_data is not None: - # (ntypes^2, ng) + # ((ntypes+1)^2, ng) tt_full = self.type_embd_data else: - # (ntypes) * ntypes * nt + # ((ntypes+1)^2) * (ntypes+1)^2 * nt type_embedding_nei = torch.tile( type_embedding.view(1, ntypes_with_padding, nt), [ntypes_with_padding, 1, 1], ) - # ntypes * (ntypes) * nt + # (ntypes+1)^2 * ((ntypes+1)^2) * nt type_embedding_center = torch.tile( type_embedding.view(ntypes_with_padding, 1, nt), [1, ntypes_with_padding, 1], ) - # (ntypes * ntypes) * (nt+nt) + # ((ntypes+1)^2 * (ntypes+1)^2) * (nt+nt) two_side_type_embedding = torch.cat( [type_embedding_nei, type_embedding_center], -1 ).reshape(-1, nt * 2) From c42ffe5c36228afea211a5cc0d42fb09801163ff Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 19 Nov 2025 16:18:32 +0800 Subject: [PATCH 05/14] refactor(pt): Optimize type embedding handling in se_atten - Replaced direct assignment of `type_embd_data` with a call to `register_buffer` for better memory management. - Improved clarity by using a temporary variable `embd_tensor` for storing the output of the embedding network before registration. - Maintained functionality for both one-side and two-side type embeddings, ensuring consistent behavior across modes. These changes enhance the maintainability and performance of the descriptor model. --- deepmd/pt/model/descriptor/se_atten.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 28c060ea4f..23f08c0409 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -479,9 +479,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: if self.type_one_side: # One-side: only neighbor types, much simpler! # Precompute for all (ntypes+1) neighbor types - self.type_embd_data = self.filter_layers_strip.networks[0]( - full_embd - ).detach() + embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach() + self.register_buffer("type_embd_data", embd_tensor) else: # Two-side: all (ntypes+1)² type pair combinations # Create [neighbor, center] combinations @@ -494,9 +493,10 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: ) # Precompute for all type pairs # Index formula: idx = center_type * nt + neighbor_type - self.type_embd_data = self.filter_layers_strip.networks[0]( + embd_tensor = self.filter_layers_strip.networks[0]( two_side_embd ).detach() + self.register_buffer("type_embd_data", embd_tensor) def forward( self, From bffe4869ace9290773f8875feac917f1b7b58baf Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 20 Nov 2025 15:19:16 +0800 Subject: [PATCH 06/14] fix(pt): Specify type for type_embd_data in DescrptBlockSeAtten --- deepmd/pt/model/descriptor/se_atten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 23f08c0409..18d39cae92 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -285,7 +285,7 @@ def __init__( [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] ) # For type embedding compression - self.type_embd_data = None + self.type_embd_data: Optional[torch.Tensor] = None def get_rcut(self) -> float: """Returns the cut-off radius.""" From 83165db9cfd738cefb3926062d597c172b2929d4 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 20 Nov 2025 15:23:46 +0800 Subject: [PATCH 07/14] fix(pt): Remove existing type_embd_data before registering new buffer in DescrptBlockSeAtten --- deepmd/pt/model/descriptor/se_atten.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 18d39cae92..c388ec2ce7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -480,6 +480,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: # One-side: only neighbor types, much simpler! # Precompute for all (ntypes+1) neighbor types embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach() + if hasattr(self, "type_embd_data"): + del self.type_embd_data self.register_buffer("type_embd_data", embd_tensor) else: # Two-side: all (ntypes+1)² type pair combinations @@ -496,6 +498,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: embd_tensor = self.filter_layers_strip.networks[0]( two_side_embd ).detach() + if hasattr(self, "type_embd_data"): + del self.type_embd_data self.register_buffer("type_embd_data", embd_tensor) def forward( From 2d4602759a6fa0aebf5934feab4a423d79bcc427 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 21 Nov 2025 15:12:55 +0800 Subject: [PATCH 08/14] test(pt): add compression test for se_atten --- .../pt/test_model_compression_se_atten.py | 811 ++++++++++++++++++ 1 file changed, 811 insertions(+) create mode 100644 source/tests/pt/test_model_compression_se_atten.py diff --git a/source/tests/pt/test_model_compression_se_atten.py b/source/tests/pt/test_model_compression_se_atten.py new file mode 100644 index 0000000000..e7fefa22ef --- /dev/null +++ b/source/tests/pt/test_model_compression_se_atten.py @@ -0,0 +1,811 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest + +import numpy as np + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.infer.deep_eval import ( + DeepEval, +) + +from .common import ( + j_loader, + run_dp, + tests_path, +) + +if GLOBAL_NP_FLOAT_PRECISION == np.float32: + default_places = 4 +else: + default_places = 10 + + +def _file_delete(file) -> None: + if os.path.isdir(file): + os.rmdir(file) + elif os.path.isfile(file): + os.remove(file) + + +def _init_models(): + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / "dp-original-se-atten.pth") + compressed_model = str(tests_path / "dp-compressed-se-atten.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + + # Configure se_atten descriptor with attn_layer=0 for compression compatibility + jdata["model"]["descriptor"] = { + "type": "se_atten", + "sel": 120, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 0, # Must be 0 for compression + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "type_one_side": True, + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "seed": 1, + "tebd_input_mode": "strip", + } + + # Add type_embedding configuration + jdata["model"]["type_embedding"] = { + "neuron": [8, 16, 32], + "resnet_dt": False, + "seed": 1, + } + + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT) + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + +def _init_models_exclude_types(): + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / "dp-original-se-atten-exclude-types.pth") + compressed_model = str(tests_path / "dp-compressed-se-atten-exclude-types.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + + # Configure se_atten descriptor with exclude_types + jdata["model"]["descriptor"] = { + "type": "se_atten", + "exclude_types": [[0, 1]], + "sel": 120, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 0, # Must be 0 for compression + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "type_one_side": True, + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "seed": 1, + "tebd_input_mode": "strip", + } + + # Add type_embedding configuration + jdata["model"]["type_embedding"] = { + "neuron": [8, 16, 32], + "resnet_dt": False, + "seed": 1, + } + + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT) + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + +def _init_models_skip_neighbor_stat(): + suffix = "-skip-neighbor-stat" + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / f"dp-original-se-atten{suffix}.pth") + compressed_model = str(tests_path / f"dp-compressed-se-atten{suffix}.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + + # Configure se_atten descriptor + jdata["model"]["descriptor"] = { + "type": "se_atten", + "sel": 120, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 0, # Must be 0 for compression + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "type_one_side": True, + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "seed": 1, + "tebd_input_mode": "strip", + } + + # Add type_embedding configuration + jdata["model"]["type_embedding"] = { + "neuron": [8, 16, 32], + "resnet_dt": False, + "seed": 1, + } + + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat") + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + + " -i " + + frozen_model + + " -o " + + compressed_model + + " -t " + + INPUT + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + +def setUpModule() -> None: + global \ + INPUT, \ + FROZEN_MODEL, \ + COMPRESSED_MODEL, \ + INPUT_ET, \ + FROZEN_MODEL_ET, \ + COMPRESSED_MODEL_ET, \ + FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \ + COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + _, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = ( + _init_models_skip_neighbor_stat() + ) + INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() + + +class TestDeepPotAPBC(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_attrs(self) -> None: + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self) -> None: + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self) -> None: + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestDeepPotANoPBC(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = None + + def test_1frame(self) -> None: + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self) -> None: + coords2 = np.concatenate((self.coords, self.coords)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestDeepPotALargeBoxNoPBC(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([19.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_1frame(self) -> None: + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_ase(self) -> None: + from ase import ( + Atoms, + ) + + from deepmd.tf.calculator import ( + DP, + ) + + water0 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(FROZEN_MODEL), + ) + water1 = Atoms( + "OHHOHH", + positions=self.coords.reshape((-1, 3)), + cell=self.box.reshape((3, 3)), + calculator=DP(COMPRESSED_MODEL), + ) + ee0 = water0.get_potential_energy() + ff0 = water0.get_forces() + ee1 = water1.get_potential_energy() + ff1 = water1.get_forces() + # nframes = 1 + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + + +class TestDeepPotAPBCExcludeTypes(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL_ET) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL_ET) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + @classmethod + def tearDownClass(cls) -> None: + _file_delete(INPUT_ET) + _file_delete(FROZEN_MODEL_ET) + _file_delete(COMPRESSED_MODEL_ET) + _file_delete("out.json") + _file_delete("compress.json") + _file_delete("checkpoint") + _file_delete("lcurve.out") + _file_delete("model.ckpt") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression") + + def test_attrs(self) -> None: + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self) -> None: + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self) -> None: + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestSkipNeighborStat(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_attrs(self) -> None: + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self) -> None: + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self) -> None: + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self) -> None: + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + +if __name__ == "__main__": + unittest.main() From 3606359721e91c88b87c9fc24f03448abd4f2597 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 21 Nov 2025 15:26:42 +0800 Subject: [PATCH 09/14] fix(pt): update import path for DP calculator in test_model_compression_se_atten --- source/tests/pt/test_model_compression_se_atten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/test_model_compression_se_atten.py b/source/tests/pt/test_model_compression_se_atten.py index e7fefa22ef..c86cd7f1fc 100644 --- a/source/tests/pt/test_model_compression_se_atten.py +++ b/source/tests/pt/test_model_compression_se_atten.py @@ -521,7 +521,7 @@ def test_ase(self) -> None: Atoms, ) - from deepmd.tf.calculator import ( + from deepmd.calculator import ( DP, ) From cdd7272008b0f9ede9327f78643dff0556dcc68a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 21 Nov 2025 18:05:09 +0800 Subject: [PATCH 10/14] fix(pt): enhance file deletion logic and clean up artifacts in tests --- .../pt/test_model_compression_se_atten.py | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/source/tests/pt/test_model_compression_se_atten.py b/source/tests/pt/test_model_compression_se_atten.py index c86cd7f1fc..43f8c150e7 100644 --- a/source/tests/pt/test_model_compression_se_atten.py +++ b/source/tests/pt/test_model_compression_se_atten.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json import os +import shutil import unittest import numpy as np @@ -26,7 +27,7 @@ def _file_delete(file) -> None: if os.path.isdir(file): - os.rmdir(file) + shutil.rmtree(file) elif os.path.isfile(file): os.remove(file) @@ -45,11 +46,10 @@ def _init_models(): "rcut_smth": 0.50, "rcut": 6.00, "neuron": [25, 50, 100], + "tebd_dim": 8, "axis_neuron": 16, "attn": 64, "attn_layer": 0, # Must be 0 for compression - "attn_dotr": True, - "attn_mask": False, "activation_function": "tanh", "type_one_side": True, "scaling_factor": 1.0, @@ -59,13 +59,6 @@ def _init_models(): "tebd_input_mode": "strip", } - # Add type_embedding configuration - jdata["model"]["type_embedding"] = { - "neuron": [8, 16, 32], - "resnet_dt": False, - "seed": 1, - } - jdata["training"]["training_data"]["systems"] = data_file with open(INPUT, "w") as fp: json.dump(jdata, fp, indent=4) @@ -110,13 +103,6 @@ def _init_models_exclude_types(): "tebd_input_mode": "strip", } - # Add type_embedding configuration - jdata["model"]["type_embedding"] = { - "neuron": [8, 16, 32], - "resnet_dt": False, - "seed": 1, - } - jdata["training"]["training_data"]["systems"] = data_file with open(INPUT, "w") as fp: json.dump(jdata, fp, indent=4) @@ -161,13 +147,6 @@ def _init_models_skip_neighbor_stat(): "tebd_input_mode": "strip", } - # Add type_embedding configuration - jdata["model"]["type_embedding"] = { - "neuron": [8, 16, 32], - "resnet_dt": False, - "seed": 1, - } - jdata["training"]["training_data"]["systems"] = data_file with open(INPUT, "w") as fp: json.dump(jdata, fp, indent=4) @@ -206,6 +185,28 @@ def setUpModule() -> None: INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() +def tearDownModule() -> None: + # Clean up files created by _init_models + _file_delete(INPUT) + _file_delete(FROZEN_MODEL) + _file_delete(COMPRESSED_MODEL) + # Clean up files created by _init_models_skip_neighbor_stat + _file_delete(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) + _file_delete(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) + # Clean up files created by _init_models_exclude_types + _file_delete(INPUT_ET) + _file_delete(FROZEN_MODEL_ET) + _file_delete(COMPRESSED_MODEL_ET) + # Clean up other artifacts + _file_delete("out.json") + _file_delete("compress.json") + _file_delete("checkpoint") + _file_delete("lcurve.out") + _file_delete("model.ckpt") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression") + + class TestDeepPotAPBC(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -576,19 +577,6 @@ def setUpClass(cls) -> None: cls.atype = [0, 1, 1, 0, 1, 1] cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - @classmethod - def tearDownClass(cls) -> None: - _file_delete(INPUT_ET) - _file_delete(FROZEN_MODEL_ET) - _file_delete(COMPRESSED_MODEL_ET) - _file_delete("out.json") - _file_delete("compress.json") - _file_delete("checkpoint") - _file_delete("lcurve.out") - _file_delete("model.ckpt") - _file_delete("model-compression/checkpoint") - _file_delete("model-compression") - def test_attrs(self) -> None: self.assertEqual(self.dp_original.get_ntypes(), 2) self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) From 01f30a5448edd423e1925273ce0953b1f2b0cfac Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 22 Nov 2025 17:43:19 +0800 Subject: [PATCH 11/14] refactor(pt): Update type embedding handling in DescrptBlockSeAtten - Replaced the optional `type_embd_data` with a registered buffer to improve memory management. - Adjusted conditional checks for type embedding usage to rely on the `compress` attribute, enhancing clarity and functionality. - Removed the associated test file for model compression of se_atten, streamlining the test suite. These changes enhance the maintainability and performance of the descriptor model while ensuring consistent behavior in type embedding compression. --- deepmd/pt/model/descriptor/se_atten.py | 8 +- .../pt/test_model_compression_se_atten.py | 799 ------------------ 2 files changed, 5 insertions(+), 802 deletions(-) delete mode 100644 source/tests/pt/test_model_compression_se_atten.py diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c388ec2ce7..30d6024e60 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -285,7 +285,9 @@ def __init__( [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] ) # For type embedding compression - self.type_embd_data: Optional[torch.Tensor] = None + self.register_buffer( + "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE) + ) def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -628,7 +630,7 @@ def forward( # nf x (nl x nnei) nei_type = torch.gather(extended_atype, dim=1, index=nlist_index) if self.type_one_side: - if self.type_embd_data is not None: + if self.compress: tt_full = self.type_embd_data else: # (ntypes+1, tebd_dim) -> (ntypes+1, ng) @@ -642,7 +644,7 @@ def forward( idx_j = nei_type.view(-1) # (nf x nl x nnei) idx = (idx_i + idx_j).to(torch.long) - if self.type_embd_data is not None: + if self.compress: # ((ntypes+1)^2, ng) tt_full = self.type_embd_data else: diff --git a/source/tests/pt/test_model_compression_se_atten.py b/source/tests/pt/test_model_compression_se_atten.py deleted file mode 100644 index 43f8c150e7..0000000000 --- a/source/tests/pt/test_model_compression_se_atten.py +++ /dev/null @@ -1,799 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import json -import os -import shutil -import unittest - -import numpy as np - -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) -from deepmd.infer.deep_eval import ( - DeepEval, -) - -from .common import ( - j_loader, - run_dp, - tests_path, -) - -if GLOBAL_NP_FLOAT_PRECISION == np.float32: - default_places = 4 -else: - default_places = 10 - - -def _file_delete(file) -> None: - if os.path.isdir(file): - shutil.rmtree(file) - elif os.path.isfile(file): - os.remove(file) - - -def _init_models(): - data_file = str(tests_path / os.path.join("model_compression", "data")) - frozen_model = str(tests_path / "dp-original-se-atten.pth") - compressed_model = str(tests_path / "dp-compressed-se-atten.pth") - INPUT = str(tests_path / "input.json") - jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) - - # Configure se_atten descriptor with attn_layer=0 for compression compatibility - jdata["model"]["descriptor"] = { - "type": "se_atten", - "sel": 120, - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [25, 50, 100], - "tebd_dim": 8, - "axis_neuron": 16, - "attn": 64, - "attn_layer": 0, # Must be 0 for compression - "activation_function": "tanh", - "type_one_side": True, - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "seed": 1, - "tebd_input_mode": "strip", - } - - jdata["training"]["training_data"]["systems"] = data_file - with open(INPUT, "w") as fp: - json.dump(jdata, fp, indent=4) - - ret = run_dp("dp --pt train " + INPUT) - np.testing.assert_equal(ret, 0, "DP train failed!") - ret = run_dp("dp --pt freeze -o " + frozen_model) - np.testing.assert_equal(ret, 0, "DP freeze failed!") - ret = run_dp( - "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model - ) - np.testing.assert_equal(ret, 0, "DP model compression failed!") - return INPUT, frozen_model, compressed_model - - -def _init_models_exclude_types(): - data_file = str(tests_path / os.path.join("model_compression", "data")) - frozen_model = str(tests_path / "dp-original-se-atten-exclude-types.pth") - compressed_model = str(tests_path / "dp-compressed-se-atten-exclude-types.pth") - INPUT = str(tests_path / "input.json") - jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) - - # Configure se_atten descriptor with exclude_types - jdata["model"]["descriptor"] = { - "type": "se_atten", - "exclude_types": [[0, 1]], - "sel": 120, - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [25, 50, 100], - "axis_neuron": 16, - "attn": 64, - "attn_layer": 0, # Must be 0 for compression - "attn_dotr": True, - "attn_mask": False, - "activation_function": "tanh", - "type_one_side": True, - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "seed": 1, - "tebd_input_mode": "strip", - } - - jdata["training"]["training_data"]["systems"] = data_file - with open(INPUT, "w") as fp: - json.dump(jdata, fp, indent=4) - - ret = run_dp("dp --pt train " + INPUT) - np.testing.assert_equal(ret, 0, "DP train failed!") - ret = run_dp("dp --pt freeze -o " + frozen_model) - np.testing.assert_equal(ret, 0, "DP freeze failed!") - ret = run_dp( - "dp --pt compress " + " -i " + frozen_model + " -o " + compressed_model - ) - np.testing.assert_equal(ret, 0, "DP model compression failed!") - return INPUT, frozen_model, compressed_model - - -def _init_models_skip_neighbor_stat(): - suffix = "-skip-neighbor-stat" - data_file = str(tests_path / os.path.join("model_compression", "data")) - frozen_model = str(tests_path / f"dp-original-se-atten{suffix}.pth") - compressed_model = str(tests_path / f"dp-compressed-se-atten{suffix}.pth") - INPUT = str(tests_path / "input.json") - jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) - - # Configure se_atten descriptor - jdata["model"]["descriptor"] = { - "type": "se_atten", - "sel": 120, - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [25, 50, 100], - "axis_neuron": 16, - "attn": 64, - "attn_layer": 0, # Must be 0 for compression - "attn_dotr": True, - "attn_mask": False, - "activation_function": "tanh", - "type_one_side": True, - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "seed": 1, - "tebd_input_mode": "strip", - } - - jdata["training"]["training_data"]["systems"] = data_file - with open(INPUT, "w") as fp: - json.dump(jdata, fp, indent=4) - - ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat") - np.testing.assert_equal(ret, 0, "DP train failed!") - ret = run_dp("dp --pt freeze -o " + frozen_model) - np.testing.assert_equal(ret, 0, "DP freeze failed!") - ret = run_dp( - "dp --pt compress " - + " -i " - + frozen_model - + " -o " - + compressed_model - + " -t " - + INPUT - ) - np.testing.assert_equal(ret, 0, "DP model compression failed!") - return INPUT, frozen_model, compressed_model - - -def setUpModule() -> None: - global \ - INPUT, \ - FROZEN_MODEL, \ - COMPRESSED_MODEL, \ - INPUT_ET, \ - FROZEN_MODEL_ET, \ - COMPRESSED_MODEL_ET, \ - FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \ - COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT - INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() - _, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = ( - _init_models_skip_neighbor_stat() - ) - INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() - - -def tearDownModule() -> None: - # Clean up files created by _init_models - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) - _file_delete(COMPRESSED_MODEL) - # Clean up files created by _init_models_skip_neighbor_stat - _file_delete(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) - _file_delete(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) - # Clean up files created by _init_models_exclude_types - _file_delete(INPUT_ET) - _file_delete(FROZEN_MODEL_ET) - _file_delete(COMPRESSED_MODEL_ET) - # Clean up other artifacts - _file_delete("out.json") - _file_delete("compress.json") - _file_delete("checkpoint") - _file_delete("lcurve.out") - _file_delete("model.ckpt") - _file_delete("model-compression/checkpoint") - _file_delete("model-compression") - - -class TestDeepPotAPBC(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.dp_original = DeepEval(FROZEN_MODEL) - cls.dp_compressed = DeepEval(COMPRESSED_MODEL) - cls.coords = np.array( - [ - 12.83, - 2.56, - 2.18, - 12.09, - 2.87, - 2.74, - 00.25, - 3.32, - 1.68, - 3.36, - 3.00, - 1.81, - 3.51, - 2.51, - 2.60, - 4.27, - 3.22, - 1.56, - ] - ) - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - - def test_attrs(self) -> None: - self.assertEqual(self.dp_original.get_ntypes(), 2) - self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) - self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_original.get_dim_fparam(), 0) - self.assertEqual(self.dp_original.get_dim_aparam(), 0) - - self.assertEqual(self.dp_compressed.get_ntypes(), 2) - self.assertAlmostEqual( - self.dp_compressed.get_rcut(), 6.0, places=default_places - ) - self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) - self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) - - def test_1frame(self) -> None: - ee0, ff0, vv0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=False - ) - ee1, ff1, vv1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=False - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_1frame_atm(self) -> None: - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_2frame_atm(self) -> None: - coords2 = np.concatenate((self.coords, self.coords)) - box2 = np.concatenate((self.box, self.box)) - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - coords2, box2, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - coords2, box2, self.atype, atomic=True - ) - # check shape of the returns - nframes = 2 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - -class TestDeepPotANoPBC(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.dp_original = DeepEval(FROZEN_MODEL) - cls.dp_compressed = DeepEval(COMPRESSED_MODEL) - cls.coords = np.array( - [ - 12.83, - 2.56, - 2.18, - 12.09, - 2.87, - 2.74, - 00.25, - 3.32, - 1.68, - 3.36, - 3.00, - 1.81, - 3.51, - 2.51, - 2.60, - 4.27, - 3.22, - 1.56, - ] - ) - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = None - - def test_1frame(self) -> None: - ee0, ff0, vv0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=False - ) - ee1, ff1, vv1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=False - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_1frame_atm(self) -> None: - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_2frame_atm(self) -> None: - coords2 = np.concatenate((self.coords, self.coords)) - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - coords2, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - coords2, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 2 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - -class TestDeepPotALargeBoxNoPBC(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.dp_original = DeepEval(FROZEN_MODEL) - cls.dp_compressed = DeepEval(COMPRESSED_MODEL) - cls.coords = np.array( - [ - 12.83, - 2.56, - 2.18, - 12.09, - 2.87, - 2.74, - 00.25, - 3.32, - 1.68, - 3.36, - 3.00, - 1.81, - 3.51, - 2.51, - 2.60, - 4.27, - 3.22, - 1.56, - ] - ) - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = np.array([19.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - - def test_1frame(self) -> None: - ee0, ff0, vv0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=False - ) - ee1, ff1, vv1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=False - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_1frame_atm(self) -> None: - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_ase(self) -> None: - from ase import ( - Atoms, - ) - - from deepmd.calculator import ( - DP, - ) - - water0 = Atoms( - "OHHOHH", - positions=self.coords.reshape((-1, 3)), - cell=self.box.reshape((3, 3)), - calculator=DP(FROZEN_MODEL), - ) - water1 = Atoms( - "OHHOHH", - positions=self.coords.reshape((-1, 3)), - cell=self.box.reshape((3, 3)), - calculator=DP(COMPRESSED_MODEL), - ) - ee0 = water0.get_potential_energy() - ff0 = water0.get_forces() - ee1 = water1.get_potential_energy() - ff1 = water1.get_forces() - # nframes = 1 - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - - -class TestDeepPotAPBCExcludeTypes(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.dp_original = DeepEval(FROZEN_MODEL_ET) - cls.dp_compressed = DeepEval(COMPRESSED_MODEL_ET) - cls.coords = np.array( - [ - 12.83, - 2.56, - 2.18, - 12.09, - 2.87, - 2.74, - 00.25, - 3.32, - 1.68, - 3.36, - 3.00, - 1.81, - 3.51, - 2.51, - 2.60, - 4.27, - 3.22, - 1.56, - ] - ) - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - - def test_attrs(self) -> None: - self.assertEqual(self.dp_original.get_ntypes(), 2) - self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) - self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_original.get_dim_fparam(), 0) - self.assertEqual(self.dp_original.get_dim_aparam(), 0) - - self.assertEqual(self.dp_compressed.get_ntypes(), 2) - self.assertAlmostEqual( - self.dp_compressed.get_rcut(), 6.0, places=default_places - ) - self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) - self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) - - def test_1frame(self) -> None: - ee0, ff0, vv0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=False - ) - ee1, ff1, vv1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=False - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_1frame_atm(self) -> None: - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_2frame_atm(self) -> None: - coords2 = np.concatenate((self.coords, self.coords)) - box2 = np.concatenate((self.box, self.box)) - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - coords2, box2, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - coords2, box2, self.atype, atomic=True - ) - # check shape of the returns - nframes = 2 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - -class TestSkipNeighborStat(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) - cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) - cls.coords = np.array( - [ - 12.83, - 2.56, - 2.18, - 12.09, - 2.87, - 2.74, - 00.25, - 3.32, - 1.68, - 3.36, - 3.00, - 1.81, - 3.51, - 2.51, - 2.60, - 4.27, - 3.22, - 1.56, - ] - ) - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - - def test_attrs(self) -> None: - self.assertEqual(self.dp_original.get_ntypes(), 2) - self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) - self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_original.get_dim_fparam(), 0) - self.assertEqual(self.dp_original.get_dim_aparam(), 0) - - self.assertEqual(self.dp_compressed.get_ntypes(), 2) - self.assertAlmostEqual( - self.dp_compressed.get_rcut(), 6.0, places=default_places - ) - self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) - self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) - self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) - - def test_1frame(self) -> None: - ee0, ff0, vv0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=False - ) - ee1, ff1, vv1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=False - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_1frame_atm(self) -> None: - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - self.coords, self.box, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - self.coords, self.box, self.atype, atomic=True - ) - # check shape of the returns - nframes = 1 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - def test_2frame_atm(self) -> None: - coords2 = np.concatenate((self.coords, self.coords)) - box2 = np.concatenate((self.box, self.box)) - ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( - coords2, box2, self.atype, atomic=True - ) - ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( - coords2, box2, self.atype, atomic=True - ) - # check shape of the returns - nframes = 2 - natoms = len(self.atype) - self.assertEqual(ee0.shape, (nframes, 1)) - self.assertEqual(ff0.shape, (nframes, natoms, 3)) - self.assertEqual(vv0.shape, (nframes, 9)) - self.assertEqual(ae0.shape, (nframes, natoms, 1)) - self.assertEqual(av0.shape, (nframes, natoms, 9)) - self.assertEqual(ee1.shape, (nframes, 1)) - self.assertEqual(ff1.shape, (nframes, natoms, 3)) - self.assertEqual(vv1.shape, (nframes, 9)) - self.assertEqual(ae1.shape, (nframes, natoms, 1)) - self.assertEqual(av1.shape, (nframes, natoms, 9)) - - # check values - np.testing.assert_almost_equal(ff0, ff1, default_places) - np.testing.assert_almost_equal(ae0, ae1, default_places) - np.testing.assert_almost_equal(av0, av1, default_places) - np.testing.assert_almost_equal(ee0, ee1, default_places) - np.testing.assert_almost_equal(vv0, vv1, default_places) - - -if __name__ == "__main__": - unittest.main() From cc0be5981e369dcce1061eff0c14a32d88e8edef Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 23 Nov 2025 12:13:02 +0800 Subject: [PATCH 12/14] feat(pt): enable type embedding compression in repinit and update test validation --- deepmd/pt/model/descriptor/dpa2.py | 4 ++++ source/tests/pt/model/test_descriptor_dpa1.py | 1 + 2 files changed, 5 insertions(+) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 5858206cc3..8985a92196 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -970,4 +970,8 @@ def enable_compression( self.repinit.enable_compression( self.table.data, self.table_config, self.lower, self.upper ) + + # Enable type embedding compression for repinit (se_atten) + self.repinit.type_embedding_compression(self.type_embedding) + self.compress = True diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index abf5d1af01..56f4dd0c56 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -377,5 +377,6 @@ def translate_se_atten_and_type_embd_dicts_to_dpa1( target_dict[tk] = type_embd_dict[kk] record[all_keys.index("se_atten.compress_data.0")] = True record[all_keys.index("se_atten.compress_info.0")] = True + record[all_keys.index("se_atten.type_embd_data")] = True assert all(record) return target_dict From c1465d7e1f507105a936a137b6ccbe9ea002a21c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 23 Nov 2025 13:11:31 +0800 Subject: [PATCH 13/14] fix(pt): update type embedding dictionary handling in translate_type_embd_dicts_to_dpa2 --- source/tests/pt/model/test_descriptor_dpa2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index 6a859a497a..3fa6b86636 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -196,5 +196,6 @@ def translate_type_embd_dicts_to_dpa2( target_dict[tk] = type_embd_dict[kk] record[all_keys.index("repinit.compress_data.0")] = True record[all_keys.index("repinit.compress_info.0")] = True + record[all_keys.index("repinit.type_embd_data")] = True assert all(record) return target_dict From c35082da6525d635f17e3fc9e222fb469ac04098 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 23 Nov 2025 14:12:24 +0800 Subject: [PATCH 14/14] fix(pt): add type embedding data to state dict in test_descriptor --- source/tests/pt/model/test_descriptor_dpa1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index 56f4dd0c56..27b84879dc 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -249,6 +249,7 @@ def test_descriptor_block(self) -> None: # this is an old state dict, modify manually state_dict["compress_info.0"] = des.compress_info[0] state_dict["compress_data.0"] = des.compress_data[0] + state_dict["type_embd_data"] = des.type_embd_data des.load_state_dict(state_dict) coord = self.coord atype = self.atype