
Commit d70a397

refactor(pt): Streamline type embedding compression logic in dpa1 and se_atten
- Simplified the type embedding compression process by consolidating methods and removing unnecessary conditions.
- Enhanced clarity in the handling of one-side and two-side type embeddings, ensuring consistent functionality across both modes.
- Updated comments for better understanding of the compression logic and its implications on performance.

These changes contribute to cleaner code and improved maintainability of the descriptor model.
1 parent f5f8185 commit d70a397

File tree: 2 files changed (+41, -47 lines)


deepmd/pt/model/descriptor/dpa1.py

Lines changed: 2 additions & 4 deletions
@@ -646,10 +646,8 @@ def enable_compression(
             self.table.data, self.table_config, self.lower, self.upper
         )

-        # Enable type embedding compression only for two-side mode
-        # TODO: why not enable for one-side mode? (do not consider this for now)
-        if not self.se_atten.type_one_side:
-            self.se_atten.enable_type_embedding_compression(self.type_embedding)
+        # Enable type embedding compression
+        self.se_atten.type_embedding_compression(self.type_embedding)

         self.compress = True

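Because the new type_embedding_compression handles both modes internally, the descriptor's call site no longer branches on type_one_side. A minimal, hypothetical sketch of that dispatch pattern (ToyAtten and ToyDescriptor are illustrative stand-ins, not deepmd-kit classes):

from typing import Optional, Sequence


class ToyAtten:
    """Illustrative stand-in for the se_atten block (not a deepmd-kit class)."""

    def __init__(self, type_one_side: bool) -> None:
        self.type_one_side = type_one_side
        self.type_embd_data: Optional[list] = None

    def type_embedding_compression(self, type_embedding: Sequence) -> None:
        # The one-side/two-side branch lives in the callee, so callers
        # never need to inspect type_one_side.
        n = len(type_embedding)
        rows = n if self.type_one_side else n * n
        self.type_embd_data = [f"row {i}" for i in range(rows)]  # placeholder table


class ToyDescriptor:
    """Illustrative stand-in for the dpa1 descriptor."""

    def __init__(self, se_atten: ToyAtten, type_embedding: Sequence) -> None:
        self.se_atten = se_atten
        self.type_embedding = type_embedding
        self.compress = False

    def enable_compression(self) -> None:
        # Mirrors the dpa1.py hunk above: one unconditional call for either mode.
        self.se_atten.type_embedding_compression(self.type_embedding)
        self.compress = True


desc = ToyDescriptor(ToyAtten(type_one_side=True), type_embedding=["H", "O", "pad"])
desc.enable_compression()  # works identically with type_one_side=False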
deepmd/pt/model/descriptor/se_atten.py

Lines changed: 39 additions & 43 deletions
@@ -275,7 +275,7 @@ def __init__(
         self.filter_layers_strip = filter_layers_strip
         self.stats = None

-        # add for compression
+        # For geometric compression
         self.compress = False
         self.is_sorted = False
         self.compress_info = nn.ParameterList(
@@ -284,9 +284,8 @@ def __init__(
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
-        # For type embedding compression (strip mode, two-side only)
-        self.compress_type_embd = False
-        self.two_side_embd_data = None
+        # For type embedding compression
+        self.type_embd_data = None

     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -453,49 +452,44 @@ def enable_compression(
         self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
         self.compress = True

-    def enable_type_embedding_compression(
-        self, type_embedding_net: TypeEmbedNet
-    ) -> None:
-        """Enable type embedding compression for strip mode (two-side only).
-
-        This method precomputes the type embedding network outputs for all possible
-        type pairs, following the same approach as TF backend's compression:
-
-        TF approach:
-        1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations
-        2. make_data(): applies embedding network to get precomputed outputs
-        3. In forward: lookup precomputed values instead of real-time computation
+    def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
+        """Enable type embedding compression for strip mode.

-        PyTorch implementation:
-        - Precomputes all (ntypes+1)^2 type pair embedding network outputs
-        - Stores in buffer for proper serialization and device management
-        - Uses lookup during inference to avoid redundant computations
+        Precomputes embedding network outputs for all type combinations:
+        - One-side: (ntypes+1) combinations (neighbor types only)
+        - Two-side: (ntypes+1)² combinations (neighbor x center type pairs)

         Parameters
         ----------
         type_embedding_net : TypeEmbedNet
             The type embedding network that provides get_full_embedding() method
         """
         with torch.no_grad():
-            # Get full type embedding: (ntypes+1) x t_dim
+            # Get full type embedding: (ntypes+1) x tebd_dim
             full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
             nt, t_dim = full_embd.shape

-            # Create all type pair combinations [neighbor, center]
-            # for a fixed row i, all columns j have different neighbor types
-            embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
-            # for a fixed row i, all columns j share the same center type i
-            embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
-            two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
-                -1, t_dim * 2
-            )
-
-            # Apply strip embedding network and store
-            # index logic: index = center_type * nt + neighbor_type
-            self.two_side_embd_data = self.filter_layers_strip.networks[0](
-                two_side_embd
-            ).detach()
-        self.compress_type_embd = True
+            if self.type_one_side:
+                # One-side: only neighbor types, much simpler!
+                # Precompute for all (ntypes+1) neighbor types
+                self.type_embd_data = self.filter_layers_strip.networks[0](
+                    full_embd
+                ).detach()
+            else:
+                # Two-side: all (ntypes+1)² type pair combinations
+                # Create [neighbor, center] combinations
+                # for a fixed row i, all columns j have different neighbor types
+                embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+                # for a fixed row i, all columns j share the same center type i
+                embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+                two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(
+                    -1, t_dim * 2
+                )
+                # Precompute for all type pairs
+                # Index formula: idx = center_type * nt + neighbor_type
+                self.type_embd_data = self.filter_layers_strip.networks[0](
+                    two_side_embd
+                ).detach()

     def forward(
         self,
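The two-side branch above amounts to building a lookup table once and indexing it later. A minimal, self-contained sketch of the same construction, with a plain torch.nn.Linear standing in for filter_layers_strip.networks[0] and arbitrary toy sizes (nt, t_dim, ng) rather than deepmd-kit values:

import torch

torch.manual_seed(0)
nt, t_dim, ng = 4, 3, 6                       # toy sizes: (ntypes+1), tebd_dim, output width
strip_net = torch.nn.Linear(2 * t_dim, ng)    # stand-in for the strip embedding network
full_embd = torch.randn(nt, t_dim)            # stand-in for get_full_embedding()

with torch.no_grad():
    # Row i, column j pairs neighbor type j with center type i.
    embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
    embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
    two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape(-1, 2 * t_dim)
    table = strip_net(two_side_embd)          # (nt*nt, ng) precomputed outputs

# Index formula from the diff: idx = center_type * nt + neighbor_type.
center, neighbor = 2, 1
idx = center * nt + neighbor
direct = strip_net(torch.cat([full_embd[neighbor], full_embd[center]]))
assert torch.allclose(table[idx], direct)     # table lookup == on-the-fly evaluation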
@@ -622,22 +616,24 @@ def forward(
             nlist_index = nlist.reshape(nb, nloc * nnei)
             # nf x (nl x nnei)
             nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
-            # (nf x nl x nnei) x ng
-            nei_type_index = nei_type.view(-1, 1).expand(-1, ng).type(torch.long)
             if self.type_one_side:
-                tt_full = self.filter_layers_strip.networks[0](type_embedding)
-                # (nf x nl x nnei) x ng
-                gg_t = torch.gather(tt_full, dim=0, index=nei_type_index)
+                if self.type_embd_data is not None:
+                    tt_full = self.type_embd_data
+                else:
+                    # (ntypes+1, tebd_dim) -> (ntypes+1, ng)
+                    tt_full = self.filter_layers_strip.networks[0](type_embedding)
+                # (nf*nl*nnei,) -> (nf*nl*nnei, ng)
+                gg_t = tt_full[nei_type.view(-1).type(torch.long)]
             else:
                 idx_i = torch.tile(
                     atype.reshape(-1, 1) * ntypes_with_padding, [1, nnei]
                 ).view(-1)
                 idx_j = nei_type.view(-1)
                 # (nf x nl x nnei)
                 idx = (idx_i + idx_j).to(torch.long)
-                if self.compress_type_embd and self.two_side_embd_data is not None:
+                if self.type_embd_data is not None:
                     # (ntypes^2, ng)
-                    tt_full = self.two_side_embd_data
+                    tt_full = self.type_embd_data
                 else:
                     # (ntypes) * ntypes * nt
                     type_embedding_nei = torch.tile(
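At inference time the precomputed table turns the strip-network evaluation into plain indexing. A hedged sketch of the lookup for both modes, using random placeholder tables (table_one, table_two) in place of the real precomputed type_embd_data and toy shapes for atype / nei_type:

import torch

torch.manual_seed(0)
nt, ng = 4, 6                 # (ntypes+1) including the padding type, output width
nloc, nnei = 5, 3             # toy numbers of local atoms and neighbors

atype = torch.randint(0, nt, (1, nloc))            # center types, nf = 1
nei_type = torch.randint(0, nt, (1, nloc * nnei))  # neighbor types per center atom

# Placeholder tables standing in for the precomputed strip-network outputs.
table_one = torch.randn(nt, ng)        # one-side: one row per neighbor type
table_two = torch.randn(nt * nt, ng)   # two-side: one row per (center, neighbor) pair

# One-side lookup: index rows by neighbor type only.
gg_t_one = table_one[nei_type.view(-1).long()]               # (nloc*nnei, ng)

# Two-side lookup: idx = center_type * nt + neighbor_type.
idx_i = torch.tile(atype.reshape(-1, 1) * nt, [1, nnei]).view(-1)
idx_j = nei_type.view(-1)
idx = (idx_i + idx_j).long()
gg_t_two = table_two[idx]                                    # (nloc*nnei, ng)

print(gg_t_one.shape, gg_t_two.shape)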
