
Commit 5163e74

feat(pt): Implement type embedding compression for se_e3_tebd (#5059)
The unit test for this descriptor will be committed through another PR.

Summary by CodeRabbit:
* Breaking Changes: the model compression configuration is updated to include new required parameters.
* Improvements: precomputed type-embedding data improves model compression efficiency and inference performance.
1 parent 86cb76f commit 5163e74

File tree

1 file changed (+55, −28)

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 55 additions & 28 deletions
@@ -553,6 +553,8 @@ def enable_compression(
         assert not self.se_ttebd.resnet_dt, (
             "Model compression error: descriptor resnet_dt must be false!"
         )
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'")
         for tt in self.se_ttebd.exclude_types:
             if (tt[0] not in range(self.se_ttebd.ntypes)) or (
                 tt[1] not in range(self.se_ttebd.ntypes)
@@ -573,9 +575,6 @@ def enable_compression(
                 "Empty embedding-nets are not supported in model compression!"
             )
 
-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
-
         data = self.serialize()
         self.table = DPTabulate(
             self,
@@ -597,7 +596,11 @@ def enable_compression(
         )
 
         self.se_ttebd.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
+            self.type_embedding,
+            self.table.data,
+            self.table_config,
+            self.lower,
+            self.upper,
         )
         self.compress = True
 
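For orientation, here is a hedged sketch of how a caller reaches this updated path. The argument names (min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2, check_frequency) follow the usual deepmd-kit enable_compression signature and are an assumption here, not taken from this diff:

# Hypothetical usage sketch; argument names assumed, values illustrative.
descrpt.enable_compression(
    min_nbor_dist=0.5,
    table_extrapolate=5.0,
    table_stride_1=0.01,
    table_stride_2=0.1,
    check_frequency=-1,
)
# The wrapper now forwards its type_embedding into the block, so
# descrpt.se_ttebd can bake the pairwise type-embedding table up front.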
@@ -694,12 +697,17 @@ def __init__(
         self.stats = None
         # compression related variables
         self.compress = False
+        # For geometric compression
         self.compress_info = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
         )
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -986,31 +994,24 @@ def forward(
             nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
             idx_i = nei_type_i * ntypes_with_padding
             idx_j = nei_type_j
-            # (nf x nl x nt_i x nt_j) x ng
-            idx = (
-                (idx_i + idx_j)
-                .view(-1, 1)
-                .expand(-1, ng)
-                .type(torch.long)
-                .to(torch.long)
-            )
-            # ntypes * (ntypes) * nt
-            type_embedding_i = torch.tile(
-                type_embedding.view(ntypes_with_padding, 1, nt),
-                [1, ntypes_with_padding, 1],
-            )
-            # (ntypes) * ntypes * nt
-            type_embedding_j = torch.tile(
-                type_embedding.view(1, ntypes_with_padding, nt),
-                [ntypes_with_padding, 1, 1],
-            )
-            # (ntypes * ntypes) * (nt+nt)
-            two_side_type_embedding = torch.cat(
-                [type_embedding_i, type_embedding_j], -1
-            ).reshape(-1, nt * 2)
-            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+            idx = (idx_i + idx_j).reshape(-1).to(torch.long)
+            if self.compress:
+                tt_full = self.type_embd_data
+            else:
+                type_embedding_i = torch.tile(
+                    type_embedding.view(ntypes_with_padding, 1, nt),
+                    [1, ntypes_with_padding, 1],
+                )
+                type_embedding_j = torch.tile(
+                    type_embedding.view(1, ntypes_with_padding, nt),
+                    [ntypes_with_padding, 1, 1],
+                )
+                two_side_type_embedding = torch.cat(
+                    [type_embedding_i, type_embedding_j], -1
+                ).reshape(-1, nt * 2)
+                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
             # (nfnl x nt_i x nt_j) x ng
-            gg_t = torch.gather(tt_full, dim=0, index=idx)
+            gg_t = tt_full[idx]
             # (nfnl x nt_i x nt_j) x ng
             gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
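The forward change above replaces the gather-with-broadcast lookup by plain row indexing; for a 2-D table the two are equivalent, and the 1-D index lets the compressed branch reuse the precomputed tt_full directly. A standalone equivalence check (toy shapes, not deepmd code):

import torch

ng = 4
tt_full = torch.randn(9, ng)      # (ntypes * ntypes) rows of width ng
idx = torch.randint(0, 9, (12,))  # flattened (i, j) pair indices

# Old path: broadcast the index across columns, gather along dim 0.
old = torch.gather(tt_full, dim=0, index=idx.view(-1, 1).expand(-1, ng))
# New path: plain row indexing with a 1-D long index.
new = tt_full[idx]

assert torch.equal(old, new)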
@@ -1042,6 +1043,7 @@ def forward(
 
     def enable_compression(
         self,
+        type_embedding_net: TypeEmbedNet,
         table_data: dict,
         table_config: dict,
         lower: dict,
@@ -1051,6 +1053,8 @@ def enable_compression(
 
         Parameters
         ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network
         table_data : dict
             The tabulated data from DPTabulate
         table_config : dict
@@ -1060,6 +1064,13 @@ def enable_compression(
         upper : dict
             Upper bounds for compression
         """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must exist for type embedding compression"
+            )
+
         # Compress the main geometric embedding network (self.filter_layers)
         net_key = "filter_net"
         self.compress_info[0] = torch.as_tensor(
@@ -1078,6 +1089,22 @@ def enable_compression(
             device=env.DEVICE, dtype=self.prec
         )
 
+        # Compress the type embedding network (self.filter_layers_strip)
+        with torch.no_grad():
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], dim=-1
+            ).reshape(-1, t_dim * 2)
+            embd_tensor = self.filter_layers_strip.networks[0](
+                two_side_type_embedding
+            ).detach()
+            if hasattr(self, "type_embd_data"):
+                del self.type_embd_data
+            self.register_buffer("type_embd_data", embd_tensor)
+
         self.compress = True
 
     def has_message_passing(self) -> bool:
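The bake step runs filter_layers_strip once over all ordered type pairs, flattened so that the row for pair (i, j) sits at index i * nt + j, which is exactly the idx_i + idx_j used in forward; the buffer is deleted and re-registered because register_buffer refuses an existing attribute name. A minimal sketch of the precompute-then-lookup equivalence, with a toy linear layer standing in for the strip network:

import torch

nt, t_dim, ng = 3, 5, 4
full_embd = torch.randn(nt, t_dim)    # one embedding row per type
net = torch.nn.Linear(2 * t_dim, ng)  # stand-in for the strip network

# Bake: all ordered pairs (i, j); after reshape, row index = i * nt + j.
ei = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
ej = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
pairs = torch.cat([ei, ej], dim=-1).reshape(-1, 2 * t_dim)
with torch.no_grad():
    table = net(pairs)                # (nt * nt, ng), computed once

# Lookup for pair (i, j) matches evaluating the network directly.
i, j = 2, 1
with torch.no_grad():
    direct = net(torch.cat([full_embd[i], full_embd[j]]).view(1, -1))
assert torch.allclose(table[i * nt + j], direct[0])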
