diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index e158dd3725..78a277881c 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -645,6 +645,10 @@ def enable_compression( self.se_atten.enable_compression( self.table.data, self.table_config, self.lower, self.upper ) + + # Enable type embedding compression + self.se_atten.type_embedding_compression(self.type_embedding) + self.compress = True def forward( diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 5858206cc3..8985a92196 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -970,4 +970,8 @@ def enable_compression( self.repinit.enable_compression( self.table.data, self.table_config, self.lower, self.upper ) + + # Enable type embedding compression for repinit (se_atten) + self.repinit.type_embedding_compression(self.type_embedding) + self.compress = True diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index bfcb510810..30d6024e60 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -27,6 +27,9 @@ MLPLayer, NetworkCollection, ) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) from deepmd.pt.utils import ( env, ) @@ -272,7 +275,7 @@ def __init__( self.filter_layers_strip = filter_layers_strip self.stats = None - # add for compression + # For geometric compression self.compress = False self.is_sorted = False self.compress_info = nn.ParameterList( @@ -281,6 +284,10 @@ def __init__( self.compress_data = nn.ParameterList( [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] ) + # For type embedding compression + self.register_buffer( + "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE) + ) def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -447,6 +454,56 @@ def enable_compression( self.compress_data[0] = 
table_data[net].to(device=env.DEVICE, dtype=self.prec) self.compress = True + def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: + """Enable type embedding compression for strip mode. + + Precomputes embedding network outputs for all type combinations: + - One-side: (ntypes+1) combinations (neighbor types only) + - Two-side: (ntypes+1)² combinations (neighbor x center type pairs) + + Parameters + ---------- + type_embedding_net : TypeEmbedNet + The type embedding network that provides get_full_embedding() method + """ + if self.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.filter_layers_strip is None: + raise RuntimeError( + "filter_layers_strip must be initialized for type embedding compression" + ) + + with torch.no_grad(): + # Get full type embedding: (ntypes+1) x tebd_dim + full_embd = type_embedding_net.get_full_embedding(env.DEVICE) + nt, t_dim = full_embd.shape + + if self.type_one_side: + # One-side: only neighbor types, much simpler! 
+ # Precompute for all (ntypes+1) neighbor types + embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach() + if hasattr(self, "type_embd_data"): + del self.type_embd_data + self.register_buffer("type_embd_data", embd_tensor) + else: + # Two-side: all (ntypes+1)² type pair combinations + # Create [neighbor, center] combinations + # for a fixed row i, all columns j have different neighbor types + embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim) + # for a fixed row i, all columns j share the same center type i + embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim) + two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape( + -1, t_dim * 2 + ) + # Precompute for all type pairs + # Index formula: idx = center_type * nt + neighbor_type + embd_tensor = self.filter_layers_strip.networks[0]( + two_side_embd + ).detach() + if hasattr(self, "type_embd_data"): + del self.type_embd_data + self.register_buffer("type_embd_data", embd_tensor) + def forward( self, nlist: torch.Tensor, @@ -572,42 +629,44 @@ def forward( nlist_index = nlist.reshape(nb, nloc * nnei) # nf x (nl x nnei) nei_type = torch.gather(extended_atype, dim=1, index=nlist_index) - # (nf x nl x nnei) x ng - nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long) if self.type_one_side: - tt_full = self.filter_layers_strip.networks[0](type_embedding) - # (nf x nl x nnei) x ng - gg_t = torch.gather(tt_full, dim=0, index=nei_type_index) + if self.compress: + tt_full = self.type_embd_data + else: + # (ntypes+1, tebd_dim) -> (ntypes+1, ng) + tt_full = self.filter_layers_strip.networks[0](type_embedding) + # (nf*nl*nnei,) -> (nf*nl*nnei, ng) + gg_t = tt_full[nei_type.view(-1).type(torch.long)] else: idx_i = torch.tile( atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei] ).view(-1) idx_j = nei_type.view(-1) + # (nf x nl x nnei) + idx = (idx_i + idx_j).to(torch.long) + if self.compress: + # ((ntypes+1)^2, ng) + tt_full = self.type_embd_data + else: + # 
(ntypes+1) * (ntypes+1) * nt + type_embedding_nei = torch.tile( + type_embedding.view(1, ntypes_with_padding, nt), + [ntypes_with_padding, 1, 1], + ) + # (ntypes+1) * (ntypes+1) * nt + type_embedding_center = torch.tile( + type_embedding.view(ntypes_with_padding, 1, nt), + [1, ntypes_with_padding, 1], + ) + # ((ntypes+1) * (ntypes+1)) * (nt+nt) + two_side_type_embedding = torch.cat( + [type_embedding_nei, type_embedding_center], -1 + ).reshape(-1, nt * 2) + tt_full = self.filter_layers_strip.networks[0]( + two_side_type_embedding + ) # (nf x nl x nnei) x ng - idx = ( - (idx_i + idx_j) - .view(-1, 1) - .expand(-1, ng) - .type(torch.long) - .to(torch.long) - ) - # (ntypes) * ntypes * nt - type_embedding_nei = torch.tile( - type_embedding.view(1, ntypes_with_padding, nt), - [ntypes_with_padding, 1, 1], - ) - # ntypes * (ntypes) * nt - type_embedding_center = torch.tile( - type_embedding.view(ntypes_with_padding, 1, nt), - [1, ntypes_with_padding, 1], - ) - # (ntypes * ntypes) * (nt+nt) - two_side_type_embedding = torch.cat( - [type_embedding_nei, type_embedding_center], -1 - ).reshape(-1, nt * 2) - tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) - # (nf x nl x nnei) x ng - gg_t = torch.gather(tt_full, dim=0, index=idx) + gg_t = tt_full[idx] # (nf x nl) x nnei x ng gg_t = gg_t.reshape(nfnl, nnei, ng) if self.smooth: diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index abf5d1af01..27b84879dc 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -249,6 +249,7 @@ def test_descriptor_block(self) -> None: # this is an old state dict, modify manually state_dict["compress_info.0"] = des.compress_info[0] state_dict["compress_data.0"] = des.compress_data[0] + state_dict["type_embd_data"] = des.type_embd_data des.load_state_dict(state_dict) coord = self.coord atype = self.atype @@ -377,5 +378,6 @@ def
translate_se_atten_and_type_embd_dicts_to_dpa1( target_dict[tk] = type_embd_dict[kk] record[all_keys.index("se_atten.compress_data.0")] = True record[all_keys.index("se_atten.compress_info.0")] = True + record[all_keys.index("se_atten.type_embd_data")] = True assert all(record) return target_dict diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index 6a859a497a..3fa6b86636 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -196,5 +196,6 @@ def translate_type_embd_dicts_to_dpa2( target_dict[tk] = type_embd_dict[kk] record[all_keys.index("repinit.compress_data.0")] = True record[all_keys.index("repinit.compress_info.0")] = True + record[all_keys.index("repinit.type_embd_data")] = True assert all(record) return target_dict