Skip to content

Commit 2c829b2

Browse files
committed
feat(pt): Implement type embedding compression for two-side mode for se_atten
- Added functionality to enable type embedding compression in the class, specifically for two-side mode.
- Introduced a new method to precompute type embedding network outputs for all type pair combinations, optimizing inference by using precomputed values.
- Updated the method to utilize precomputed embeddings when compression is enabled, improving performance during inference.

This enhancement allows for more efficient handling of type embeddings in the descriptor model.
1 parent 018cccf commit 2c829b2

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,12 @@ def enable_compression(
645645
self.se_atten.enable_compression(
646646
self.table.data, self.table_config, self.lower, self.upper
647647
)
648+
649+
# Enable type embedding compression only for two-side mode
650+
# TODO: why not enable for one-side mode? (do not consider this for now)
651+
if not self.se_atten.type_one_side:
652+
self.se_atten.enable_type_embedding_compression(self.type_embedding)
653+
648654
self.compress = True
649655

650656
def forward(

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
MLPLayer,
2828
NetworkCollection,
2929
)
30+
from deepmd.pt.model.network.network import (
31+
TypeEmbedNet,
32+
)
3033
from deepmd.pt.utils import (
3134
env,
3235
)
@@ -281,6 +284,9 @@ def __init__(
281284
self.compress_data = nn.ParameterList(
282285
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
283286
)
287+
# For type embedding compression (strip mode, two-side only)
288+
self.compress_type_embd = False
289+
self.two_side_embd_data = None
284290

285291
def get_rcut(self) -> float:
286292
"""Returns the cut-off radius."""
@@ -447,6 +453,50 @@ def enable_compression(
447453
self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
448454
self.compress = True
449455

456+
def enable_type_embedding_compression(
    self, type_embedding_net: TypeEmbedNet
) -> None:
    """Enable type embedding compression for strip mode (two-side only).

    This method precomputes the type embedding network outputs for all possible
    type pairs, following the same approach as TF backend's compression:

    TF approach:
    1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations
    2. make_data(): applies embedding network to get precomputed outputs
    3. In forward: lookup precomputed values instead of real-time computation

    PyTorch implementation:
    - Precomputes all (ntypes+1)^2 type pair embedding network outputs
    - Stores the result in the plain tensor attribute
      ``self.two_side_embd_data`` (NOT a registered buffer, so it is not
      serialized with the state dict and must be rebuilt by calling this
      method again after loading)
    - Uses lookup during inference to avoid redundant computations

    Parameters
    ----------
    type_embedding_net : TypeEmbedNet
        The type embedding network that provides the get_full_embedding()
        method used to obtain the full (ntypes+1) x t_dim embedding table.
    """
    with torch.no_grad():
        # Full type embedding table: (ntypes + 1) x t_dim
        full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
        nt, t_dim = full_embd.shape

        # Build all [neighbor, center] type pair combinations.
        # In embd_nei, within a fixed row i every column j carries the
        # embedding of neighbor type j ...
        embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
        # ... while in embd_center every column of row i repeats the
        # embedding of center type i.
        embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
        # Flattened pair table; row index = center_type * nt + neighbor_type.
        # The forward-pass lookup must use this same index convention.
        two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
            -1, t_dim * 2
        )

        # Apply the strip embedding network once for every pair and cache
        # the outputs. .detach() is redundant under no_grad() but kept as a
        # cheap safeguard against the cache entering any autograd graph.
        self.two_side_embd_data = self.filter_layers_strip.networks[0](
            two_side_embd
        ).detach()
    self.compress_type_embd = True
450500
def forward(
451501
self,
452502
nlist: torch.Tensor,
@@ -605,7 +655,12 @@ def forward(
605655
two_side_type_embedding = torch.cat(
606656
[type_embedding_nei, type_embedding_center], -1
607657
).reshape(-1, nt * 2)
608-
tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
658+
if self.compress_type_embd and self.two_side_embd_data is not None:
659+
tt_full = self.two_side_embd_data
660+
else:
661+
tt_full = self.filter_layers_strip.networks[0](
662+
two_side_type_embedding
663+
)
609664
# (nf x nl x nnei) x ng
610665
gg_t = torch.gather(tt_full, dim=0, index=idx)
611666
# (nf x nl) x nnei x ng

0 commit comments

Comments
 (0)