diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py
index e8ab0522ff..8225d8e4af 100644
--- a/deepmd/pt/model/descriptor/se_t_tebd.py
+++ b/deepmd/pt/model/descriptor/se_t_tebd.py
@@ -553,6 +553,8 @@ def enable_compression(
         assert not self.se_ttebd.resnet_dt, (
             "Model compression error: descriptor resnet_dt must be false!"
         )
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'")
         for tt in self.se_ttebd.exclude_types:
             if (tt[0] not in range(self.se_ttebd.ntypes)) or (
                 tt[1] not in range(self.se_ttebd.ntypes)
@@ -573,9 +575,6 @@ def enable_compression(
                 "Empty embedding-nets are not supported in model compression!"
             )
 
-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
-
         data = self.serialize()
         self.table = DPTabulate(
             self,
@@ -597,7 +596,11 @@ def enable_compression(
         )
 
         self.se_ttebd.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
+            self.type_embedding,
+            self.table.data,
+            self.table_config,
+            self.lower,
+            self.upper,
         )
         self.compress = True
 
@@ -694,12 +697,17 @@ def __init__(
         self.stats = None
         # compression related variables
         self.compress = False
+        # For geometric compression
         self.compress_info = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
         )
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -986,31 +994,24 @@ def forward(
             nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
             idx_i = nei_type_i * ntypes_with_padding
             idx_j = nei_type_j
-            # (nf x nl x nt_i x nt_j) x ng
-            idx = (
-                (idx_i + idx_j)
-                .view(-1, 1)
-                .expand(-1, ng)
-                .type(torch.long)
-                .to(torch.long)
-            )
-            # ntypes * (ntypes) * nt
-            type_embedding_i = torch.tile(
-                type_embedding.view(ntypes_with_padding, 1, nt),
-                [1, ntypes_with_padding, 1],
-            )
-            # (ntypes) * ntypes * nt
-            type_embedding_j = torch.tile(
-                type_embedding.view(1, ntypes_with_padding, nt),
-                [ntypes_with_padding, 1, 1],
-            )
-            # (ntypes * ntypes) * (nt+nt)
-            two_side_type_embedding = torch.cat(
-                [type_embedding_i, type_embedding_j], -1
-            ).reshape(-1, nt * 2)
-            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+            idx = (idx_i + idx_j).reshape(-1).to(torch.long)
+            if self.compress:
+                tt_full = self.type_embd_data
+            else:
+                type_embedding_i = torch.tile(
+                    type_embedding.view(ntypes_with_padding, 1, nt),
+                    [1, ntypes_with_padding, 1],
+                )
+                type_embedding_j = torch.tile(
+                    type_embedding.view(1, ntypes_with_padding, nt),
+                    [ntypes_with_padding, 1, 1],
+                )
+                two_side_type_embedding = torch.cat(
+                    [type_embedding_i, type_embedding_j], -1
+                ).reshape(-1, nt * 2)
+                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
             # (nfnl x nt_i x nt_j) x ng
-            gg_t = torch.gather(tt_full, dim=0, index=idx)
+            gg_t = tt_full[idx]
             # (nfnl x nt_i x nt_j) x ng
             gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
@@ -1042,6 +1043,7 @@ def forward(
 
     def enable_compression(
         self,
+        type_embedding_net: TypeEmbedNet,
         table_data: dict,
         table_config: dict,
         lower: dict,
@@ -1051,6 +1053,8 @@ def enable_compression(
 
         Parameters
         ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network
         table_data : dict
             The tabulated data from DPTabulate
         table_config : dict
@@ -1060,6 +1064,13 @@ def enable_compression(
         upper : dict
             Upper bounds for compression
         """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must exist for type embedding compression"
+            )
+        # Compress the main geometric embedding network (self.filter_layers)
         net_key = "filter_net"
         self.compress_info[0] = torch.as_tensor(
             [
@@ -1078,6 +1089,22 @@ def enable_compression(
             device=env.DEVICE, dtype=self.prec
         )
 
+        # Compress the type embedding network (self.filter_layers_strip)
+        with torch.no_grad():
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], dim=-1
+            ).reshape(-1, t_dim * 2)
+            embd_tensor = self.filter_layers_strip.networks[0](
+                two_side_type_embedding
+            ).detach()
+        if hasattr(self, "type_embd_data"):
+            del self.type_embd_data
+        self.register_buffer("type_embd_data", embd_tensor)
+
         self.compress = True
 
     def has_message_passing(self) -> bool:
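The sketch below is a standalone illustration of the idea this diff implements, not deepmd-kit code: in strip mode the strip network only ever sees the ntypes x ntypes pairs of type embeddings, so its outputs can be tabulated once at compression time and looked up by flat pair index `i * ntypes + j` at inference, exactly as `type_embd_data` and `tt_full[idx]` do above. The names `precompute_pair_table`, `strip_net`, and `type_embd` are placeholders for this sketch.

```python
# Minimal sketch of type-embedding tabulation (assumed names, not the deepmd-kit API).
import torch


def precompute_pair_table(type_embd: torch.Tensor, strip_net: torch.nn.Module) -> torch.Tensor:
    """Evaluate the strip network on every (i, j) pair of type embeddings.

    type_embd: (ntypes, t_dim) full type-embedding matrix.
    Returns a (ntypes * ntypes, ng) table; row (i * ntypes + j) holds the
    output for the ordered pair (type_i, type_j).
    """
    nt, t_dim = type_embd.shape
    # Broadcast the embeddings into all ordered pairs, then flatten to a batch.
    embd_i = type_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
    embd_j = type_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
    pairs = torch.cat([embd_i, embd_j], dim=-1).reshape(-1, 2 * t_dim)
    with torch.no_grad():
        return strip_net(pairs).detach()


if __name__ == "__main__":
    ntypes, t_dim, ng = 4, 8, 16
    strip_net = torch.nn.Sequential(torch.nn.Linear(2 * t_dim, ng), torch.nn.Tanh())
    type_embd = torch.randn(ntypes, t_dim)
    table = precompute_pair_table(type_embd, strip_net)  # (ntypes**2, ng)
    # At inference, a neighbor type pair (i, j) becomes a row lookup instead of
    # a forward pass through the strip network.
    type_i, type_j = 2, 1
    row = table[type_i * ntypes + type_j]
    direct = strip_net(torch.cat([type_embd[type_i], type_embd[type_j]]))
    assert torch.allclose(row, direct, atol=1e-6)
```

The lookup is valid only while the type embedding is fixed, which is why the table is rebuilt inside `enable_compression` after training rather than cached during training.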