@@ -275,7 +275,7 @@ def __init__(
275275 self .filter_layers_strip = filter_layers_strip
276276 self .stats = None
277277
278- # add for compression
278+ # For geometric compression
279279 self .compress = False
280280 self .is_sorted = False
281281 self .compress_info = nn .ParameterList (
@@ -284,9 +284,8 @@ def __init__(
284284 self .compress_data = nn .ParameterList (
285285 [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = env .DEVICE ))]
286286 )
287- # For type embedding compression (strip mode, two-side only)
288- self .compress_type_embd = False
289- self .two_side_embd_data = None
287+ # For type embedding compression
288+ self .type_embd_data = None
290289
291290 def get_rcut (self ) -> float :
292291 """Returns the cut-off radius."""
@@ -453,49 +452,44 @@ def enable_compression(
453452 self .compress_data [0 ] = table_data [net ].to (device = env .DEVICE , dtype = self .prec )
454453 self .compress = True
455454
456- def enable_type_embedding_compression (
457- self , type_embedding_net : TypeEmbedNet
458- ) -> None :
459- """Enable type embedding compression for strip mode (two-side only).
460-
461- This method precomputes the type embedding network outputs for all possible
462- type pairs, following the same approach as TF backend's compression:
463-
464- TF approach:
465- 1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations
466- 2. make_data(): applies embedding network to get precomputed outputs
467- 3. In forward: lookup precomputed values instead of real-time computation
455+ def type_embedding_compression (self , type_embedding_net : TypeEmbedNet ) -> None :
456+ """Enable type embedding compression for strip mode.
468457
469- PyTorch implementation:
470- - Precomputes all (ntypes+1)^2 type pair embedding network outputs
471- - Stores in buffer for proper serialization and device management
472- - Uses lookup during inference to avoid redundant computations
458+ Precomputes embedding network outputs for all type combinations:
459+ - One-side: (ntypes+1) combinations (neighbor types only)
460+ - Two-side: (ntypes+1)² combinations (neighbor x center type pairs)
473461
474462 Parameters
475463 ----------
476464 type_embedding_net : TypeEmbedNet
477465 The type embedding network that provides get_full_embedding() method
478466 """
479467 with torch .no_grad ():
480- # Get full type embedding: (ntypes+1) x t_dim
468+ # Get full type embedding: (ntypes+1) x tebd_dim
481469 full_embd = type_embedding_net .get_full_embedding (env .DEVICE )
482470 nt , t_dim = full_embd .shape
483471
484- # Create all type pair combinations [neighbor, center]
485- # for a fixed row i, all columns j have different neighbor types
486- embd_nei = full_embd .view (1 , nt , t_dim ).expand (nt , nt , t_dim )
487- # for a fixed row i, all columns j share the same center type i
488- embd_center = full_embd .view (nt , 1 , t_dim ).expand (nt , nt , t_dim )
489- two_side_embd = torch .cat ([embd_nei , embd_center ], dim = - 1 ).reshape (
490- - 1 , t_dim * 2
491- )
492-
493- # Apply strip embedding network and store
494- # index logic: index = center_type * nt + neighbor_type
495- self .two_side_embd_data = self .filter_layers_strip .networks [0 ](
496- two_side_embd
497- ).detach ()
498- self .compress_type_embd = True
472+ if self .type_one_side :
473+ # One-side: only neighbor types, much simpler!
474+ # Precompute for all (ntypes+1) neighbor types
475+ self .type_embd_data = self .filter_layers_strip .networks [0 ](
476+ full_embd
477+ ).detach ()
478+ else :
479+ # Two-side: all (ntypes+1)² type pair combinations
480+ # Create [neighbor, center] combinations
481+ # for a fixed row i, all columns j have different neighbor types
482+ embd_nei = full_embd .view (1 , nt , t_dim ).expand (nt , nt , t_dim )
483+ # for a fixed row i, all columns j share the same center type i
484+ embd_center = full_embd .view (nt , 1 , t_dim ).expand (nt , nt , t_dim )
485+ two_side_embd = torch .cat ([embd_nei , embd_center ], dim = - 1 ).reshape (
486+ - 1 , t_dim * 2
487+ )
488+ # Precompute for all type pairs
489+ # Index formula: idx = center_type * nt + neighbor_type
490+ self .type_embd_data = self .filter_layers_strip .networks [0 ](
491+ two_side_embd
492+ ).detach ()
499493
500494 def forward (
501495 self ,
@@ -622,22 +616,24 @@ def forward(
622616 nlist_index = nlist .reshape (nb , nloc * nnei )
623617 # nf x (nl x nnei)
624618 nei_type = torch .gather (extended_atype , dim = 1 , index = nlist_index )
625- # (nf x nl x nnei) x ng
626- nei_type_index = nei_type .view (- 1 , 1 ).expand (- 1 , ng ).type (torch .long )
627619 if self .type_one_side :
628- tt_full = self .filter_layers_strip .networks [0 ](type_embedding )
629- # (nf x nl x nnei) x ng
630- gg_t = torch .gather (tt_full , dim = 0 , index = nei_type_index )
620+ if self .type_embd_data is not None :
621+ tt_full = self .type_embd_data
622+ else :
623+ # (ntypes+1, tebd_dim) -> (ntypes+1, ng)
624+ tt_full = self .filter_layers_strip .networks [0 ](type_embedding )
625+ # (nf*nl*nnei,) -> (nf*nl*nnei, ng)
626+ gg_t = tt_full [nei_type .view (- 1 ).type (torch .long )]
631627 else :
632628 idx_i = torch .tile (
633629 atype .reshape (- 1 , 1 ) * ntypes_with_padding , [1 , nnei ]
634630 ).view (- 1 )
635631 idx_j = nei_type .view (- 1 )
636632 # (nf x nl x nnei)
637633 idx = (idx_i + idx_j ).to (torch .long )
638- if self .compress_type_embd and self . two_side_embd_data is not None :
634+ if self .type_embd_data is not None :
639635 # (ntypes^2, ng)
640- tt_full = self .two_side_embd_data
636+ tt_full = self .type_embd_data
641637 else :
642638 # (ntypes) * ntypes * nt
643639 type_embedding_nei = torch .tile (
0 commit comments