Commit f5f8185

refactor(pt): Simplify index calculation and enhance type embedding handling in se_atten
- Streamlined the index calculation for neighbor types by removing the unnecessary reshaping and expansion.
- Improved clarity in the handling of type embeddings, keeping the logic for two-side embeddings intact while enhancing readability.
- Maintained functionality for both compressed and uncompressed type embeddings, optimizing the inference path.

These changes contribute to cleaner code and preserve the performance benefits introduced in previous commits.
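The heart of the change is swapping a broadcast-plus-gather row lookup (which also carried a redundant .type(torch.long).to(torch.long) double cast) for direct advanced indexing. Below is a minimal standalone sketch of the equivalence, using toy shapes and hypothetical values rather than the real descriptor tensors:

import torch

ntypes, ng, n_pairs = 4, 8, 10
# Hypothetical pairwise embedding table: one row per (center, neighbor) pair.
tt_full = torch.randn(ntypes * ntypes, ng)
idx_i = torch.randint(0, ntypes, (n_pairs,)) * ntypes  # center type * ntypes
idx_j = torch.randint(0, ntypes, (n_pairs,))           # neighbor type

# Old pattern: broadcast the flat row index to (n_pairs, ng), then gather
# along dim 0; every column of a row carries the same index value.
idx_expanded = (idx_i + idx_j).view(-1, 1).expand(-1, ng).to(torch.long)
gg_old = torch.gather(tt_full, dim=0, index=idx_expanded)

# New pattern: plain advanced indexing with the 1-D long index.
gg_new = tt_full[(idx_i + idx_j).to(torch.long)]

assert torch.equal(gg_old, gg_new)

Besides being shorter, the indexing form drops the expanded (n_pairs, ng) index tensor that gather's signature requires.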
1 parent 2c829b2 · commit f5f8185

File tree

1 file changed (+18, -23 lines)


deepmd/pt/model/descriptor/se_atten.py

Lines changed: 18 additions & 23 deletions
@@ -633,36 +633,31 @@ def forward(
             atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
         ).view(-1)
         idx_j = nei_type.view(-1)
-        # (nf x nl x nnei) x ng
-        idx = (
-            (idx_i + idx_j)
-            .view(-1, 1)
-            .expand(-1, ng)
-            .type(torch.long)
-            .to(torch.long)
-        )
-        # (ntypes) * ntypes * nt
-        type_embedding_nei = torch.tile(
-            type_embedding.view(1, ntypes_with_padding, nt),
-            [ntypes_with_padding, 1, 1],
-        )
-        # ntypes * (ntypes) * nt
-        type_embedding_center = torch.tile(
-            type_embedding.view(ntypes_with_padding, 1, nt),
-            [1, ntypes_with_padding, 1],
-        )
-        # (ntypes * ntypes) * (nt+nt)
-        two_side_type_embedding = torch.cat(
-            [type_embedding_nei, type_embedding_center], -1
-        ).reshape(-1, nt * 2)
+        # (nf x nl x nnei)
+        idx = (idx_i + idx_j).to(torch.long)
         if self.compress_type_embd and self.two_side_embd_data is not None:
+            # (ntypes^2, ng)
             tt_full = self.two_side_embd_data
         else:
+            # (ntypes) * ntypes * nt
+            type_embedding_nei = torch.tile(
+                type_embedding.view(1, ntypes_with_padding, nt),
+                [ntypes_with_padding, 1, 1],
+            )
+            # ntypes * (ntypes) * nt
+            type_embedding_center = torch.tile(
+                type_embedding.view(ntypes_with_padding, 1, nt),
+                [1, ntypes_with_padding, 1],
+            )
+            # (ntypes * ntypes) * (nt+nt)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_nei, type_embedding_center], -1
+            ).reshape(-1, nt * 2)
             tt_full = self.filter_layers_strip.networks[0](
                 two_side_type_embedding
             )
         # (nf x nl x nnei) x ng
-        gg_t = torch.gather(tt_full, dim=0, index=idx)
+        gg_t = tt_full[idx]
         # (nf x nl) x nnei x ng
         gg_t = gg_t.reshape(nfnl, nnei, ng)
         if self.smooth:
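Why a single flat index suffices: idx_i is atype * ntypes_with_padding and idx_j is the neighbor type, so idx_i + idx_j addresses row center * ntypes + neighbor of the row-major pairwise table. A toy sketch of that layout, mirroring the tile/cat logic in the else branch above (values are illustrative, not taken from the model):

import torch

ntypes, nt = 3, 2
type_embedding = torch.arange(ntypes * nt, dtype=torch.float32).view(ntypes, nt)

# Row (c * ntypes + n) of the table holds [embed(neighbor), embed(center)].
nei = torch.tile(type_embedding.view(1, ntypes, nt), [ntypes, 1, 1])
center = torch.tile(type_embedding.view(ntypes, 1, nt), [1, ntypes, 1])
table = torch.cat([nei, center], -1).reshape(-1, nt * 2)

c, n = 2, 1  # center type 2, neighbor type 1
row = table[c * ntypes + n]
assert torch.equal(row, torch.cat([type_embedding[n], type_embedding[c]]))

The same row ordering is what lets the compressed path substitute the precomputed (ntypes^2, ng) table two_side_embd_data without changing the lookup.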
