Commit 93bb015

feat(pt): Enhance type embedding compression in se_t_tebd
- Introduced a new method for type embedding compression, allowing precomputation of strip-mode type embeddings for all type pairs.
- Added runtime checks to ensure compatibility with the strip input mode and the existence of the necessary filter layers.
- Updated the forward method to use precomputed type embeddings when available, improving performance during inference.

These changes optimize the handling of type embeddings, improving the efficiency of the descriptor model.
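As a rough illustration of the optimization described above, the following is a minimal PyTorch sketch, not the repository's code: the names `type_embd` and `pair_mlp` and all shapes are made-up stand-ins. The idea is to run the pair network once for every (i, j) type pair, keep the resulting table, and look rows up by a flat index at inference instead of recomputing them on every forward pass.

import torch

# Illustrative stand-ins only; these names and sizes are not from deepmd.
ntypes, t_dim, ng = 4, 8, 16
type_embd = torch.randn(ntypes, t_dim)     # one embedding vector per atom type
pair_mlp = torch.nn.Linear(2 * t_dim, ng)  # plays the role of the strip filter network

# One-time precomputation: evaluate the pair network for all (i, j) type pairs
# and keep the resulting (ntypes * ntypes, ng) table.
with torch.no_grad():
    ei = type_embd.view(ntypes, 1, t_dim).expand(ntypes, ntypes, t_dim)
    ej = type_embd.view(1, ntypes, t_dim).expand(ntypes, ntypes, t_dim)
    pair_table = pair_mlp(torch.cat([ei, ej], dim=-1).reshape(-1, 2 * t_dim))

# At inference, a flat index i * ntypes + j replaces the per-call network evaluation.
type_i = torch.tensor([0, 2, 3])
type_j = torch.tensor([1, 1, 0])
features = pair_table[type_i * ntypes + type_j]  # shape (3, ng)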
1 parent 4f72994 commit 93bb015

File tree

1 file changed: +40 −24 lines changed


deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 40 additions & 24 deletions
@@ -599,6 +599,7 @@ def enable_compression(
         self.se_ttebd.enable_compression(
             self.table.data, self.table_config, self.lower, self.upper
         )
+        self.se_ttebd.type_embedding_compression(self.type_embedding)
         self.compress = True


@@ -700,6 +701,7 @@ def __init__(
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        self.type_embd_data: Optional[torch.Tensor] = None

     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -838,6 +840,27 @@ def reinit_exclude(
         self.exclude_types = exclude_types
         self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

+    def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
+        """Precompute strip-mode type embeddings for all type pairs."""
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must exist for type embedding compression"
+            )
+
+        with torch.no_grad():
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], dim=-1
+            ).reshape(-1, t_dim * 2)
+            self.type_embd_data = self.filter_layers_strip.networks[0](
+                two_side_type_embedding
+            ).detach()
+
     def forward(
         self,
         nlist: torch.Tensor,
@@ -986,31 +1009,24 @@ def forward(
             nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
             idx_i = nei_type_i * ntypes_with_padding
             idx_j = nei_type_j
-            # (nf x nl x nt_i x nt_j) x ng
-            idx = (
-                (idx_i + idx_j)
-                .view(-1, 1)
-                .expand(-1, ng)
-                .type(torch.long)
-                .to(torch.long)
-            )
-            # ntypes * (ntypes) * nt
-            type_embedding_i = torch.tile(
-                type_embedding.view(ntypes_with_padding, 1, nt),
-                [1, ntypes_with_padding, 1],
-            )
-            # (ntypes) * ntypes * nt
-            type_embedding_j = torch.tile(
-                type_embedding.view(1, ntypes_with_padding, nt),
-                [ntypes_with_padding, 1, 1],
-            )
-            # (ntypes * ntypes) * (nt+nt)
-            two_side_type_embedding = torch.cat(
-                [type_embedding_i, type_embedding_j], -1
-            ).reshape(-1, nt * 2)
-            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+            idx = (idx_i + idx_j).reshape(-1).to(torch.long)
+            if self.type_embd_data is not None:
+                tt_full = self.type_embd_data
+            else:
+                type_embedding_i = torch.tile(
+                    type_embedding.view(ntypes_with_padding, 1, nt),
+                    [1, ntypes_with_padding, 1],
+                )
+                type_embedding_j = torch.tile(
+                    type_embedding.view(1, ntypes_with_padding, nt),
+                    [ntypes_with_padding, 1, 1],
+                )
+                two_side_type_embedding = torch.cat(
+                    [type_embedding_i, type_embedding_j], -1
+                ).reshape(-1, nt * 2)
+                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
             # (nfnl x nt_i x nt_j) x ng
-            gg_t = torch.gather(tt_full, dim=0, index=idx)
+            gg_t = tt_full[idx]
             # (nfnl x nt_i x nt_j) x ng
             gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
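A note on the lookup change in the last hunk: the old code expanded the flat pair index across the feature dimension so that torch.gather could select rows of tt_full, while the new code uses plain integer indexing on the first dimension, which selects the same rows without materializing the expanded index. A small self-contained check of this equivalence (illustrative shapes only, not the descriptor's real tensors):

import torch

# Six precomputed rows of width five, looked up by four pair indices.
tt_full = torch.randn(6, 5)
idx = torch.tensor([4, 0, 2, 2])

# Old pattern: broadcast the row index across the feature axis, then gather.
expanded_idx = idx.view(-1, 1).expand(-1, tt_full.shape[1])
gathered = torch.gather(tt_full, dim=0, index=expanded_idx)

# New pattern: integer indexing on dim 0 selects the same rows directly.
indexed = tt_full[idx]

assert torch.equal(gathered, indexed)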
