Skip to content

Commit 7f776fa

Browse files
committed
fix(pt): Enforce conditions for type embedding compression in se_atten
- Added runtime checks to ensure type embedding compression only operates in "strip" mode.
- Introduced validation for the initialization of `filter_layers_strip` to prevent runtime errors.
- Updated comments to clarify the expected dimensions of type embeddings, enhancing code readability.

These changes improve error handling and maintain the integrity of the type embedding compression logic.
1 parent d70a397 commit 7f776fa

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,13 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
464464
type_embedding_net : TypeEmbedNet
465465
The type embedding network that provides get_full_embedding() method
466466
"""
467+
if self.tebd_input_mode != "strip":
468+
raise RuntimeError("Type embedding compression only works in strip mode")
469+
if self.filter_layers_strip is None:
470+
raise RuntimeError(
471+
"filter_layers_strip must be initialized for type embedding compression"
472+
)
473+
467474
with torch.no_grad():
468475
# Get full type embedding: (ntypes+1) x tebd_dim
469476
full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
@@ -632,20 +639,20 @@ def forward(
632639
# (nf x nl x nnei)
633640
idx = (idx_i + idx_j).to(torch.long)
634641
if self.type_embd_data is not None:
635-
# (ntypes^2, ng)
642+
# ((ntypes+1)^2, ng)
636643
tt_full = self.type_embd_data
637644
else:
638-
# (ntypes) * ntypes * nt
645+
# (ntypes+1) * (ntypes+1) * nt
639646
type_embedding_nei = torch.tile(
640647
type_embedding.view(1, ntypes_with_padding, nt),
641648
[ntypes_with_padding, 1, 1],
642649
)
643-
# ntypes * (ntypes) * nt
650+
# (ntypes+1) * (ntypes+1) * nt
644651
type_embedding_center = torch.tile(
645652
type_embedding.view(ntypes_with_padding, 1, nt),
646653
[1, ntypes_with_padding, 1],
647654
)
648-
# (ntypes * ntypes) * (nt+nt)
655+
# ((ntypes+1) * (ntypes+1)) * (nt+nt)
649656
two_side_type_embedding = torch.cat(
650657
[type_embedding_nei, type_embedding_center], -1
651658
).reshape(-1, nt * 2)

0 commit comments

Comments
 (0)