@@ -599,6 +599,7 @@ def enable_compression(
599599 self .se_ttebd .enable_compression (
600600 self .table .data , self .table_config , self .lower , self .upper
601601 )
602+ self .se_ttebd .type_embedding_compression (self .type_embedding )
602603 self .compress = True
603604
604605
@@ -700,6 +701,7 @@ def __init__(
700701 self .compress_data = nn .ParameterList (
701702 [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = env .DEVICE ))]
702703 )
704+ self .type_embd_data : Optional [torch .Tensor ] = None
703705
704706 def get_rcut (self ) -> float :
705707 """Returns the cut-off radius."""
@@ -838,6 +840,27 @@ def reinit_exclude(
838840 self .exclude_types = exclude_types
839841 self .emask = PairExcludeMask (self .ntypes , exclude_types = exclude_types )
840842
843+ def type_embedding_compression (self , type_embedding_net : TypeEmbedNet ) -> None :
844+ """Precompute strip-mode type embeddings for all type pairs."""
845+ if self .tebd_input_mode != "strip" :
846+ raise RuntimeError ("Type embedding compression only works in strip mode" )
847+ if self .filter_layers_strip is None :
848+ raise RuntimeError (
849+ "filter_layers_strip must exist for type embedding compression"
850+ )
851+
852+ with torch .no_grad ():
853+ full_embd = type_embedding_net .get_full_embedding (env .DEVICE )
854+ nt , t_dim = full_embd .shape
855+ type_embedding_i = full_embd .view (nt , 1 , t_dim ).expand (nt , nt , t_dim )
856+ type_embedding_j = full_embd .view (1 , nt , t_dim ).expand (nt , nt , t_dim )
857+ two_side_type_embedding = torch .cat (
858+ [type_embedding_i , type_embedding_j ], dim = - 1
859+ ).reshape (- 1 , t_dim * 2 )
860+ self .type_embd_data = self .filter_layers_strip .networks [0 ](
861+ two_side_type_embedding
862+ ).detach ()
863+
841864 def forward (
842865 self ,
843866 nlist : torch .Tensor ,
@@ -986,31 +1009,24 @@ def forward(
9861009 nei_type_j = nei_type .unsqueeze (1 ).expand ([- 1 , nnei , - 1 ])
9871010 idx_i = nei_type_i * ntypes_with_padding
9881011 idx_j = nei_type_j
989- # (nf x nl x nt_i x nt_j) x ng
990- idx = (
991- (idx_i + idx_j )
992- .view (- 1 , 1 )
993- .expand (- 1 , ng )
994- .type (torch .long )
995- .to (torch .long )
996- )
997- # ntypes * (ntypes) * nt
998- type_embedding_i = torch .tile (
999- type_embedding .view (ntypes_with_padding , 1 , nt ),
1000- [1 , ntypes_with_padding , 1 ],
1001- )
1002- # (ntypes) * ntypes * nt
1003- type_embedding_j = torch .tile (
1004- type_embedding .view (1 , ntypes_with_padding , nt ),
1005- [ntypes_with_padding , 1 , 1 ],
1006- )
1007- # (ntypes * ntypes) * (nt+nt)
1008- two_side_type_embedding = torch .cat (
1009- [type_embedding_i , type_embedding_j ], - 1
1010- ).reshape (- 1 , nt * 2 )
1011- tt_full = self .filter_layers_strip .networks [0 ](two_side_type_embedding )
1012+ idx = (idx_i + idx_j ).reshape (- 1 ).to (torch .long )
1013+ if self .type_embd_data is not None :
1014+ tt_full = self .type_embd_data
1015+ else :
1016+ type_embedding_i = torch .tile (
1017+ type_embedding .view (ntypes_with_padding , 1 , nt ),
1018+ [1 , ntypes_with_padding , 1 ],
1019+ )
1020+ type_embedding_j = torch .tile (
1021+ type_embedding .view (1 , ntypes_with_padding , nt ),
1022+ [ntypes_with_padding , 1 , 1 ],
1023+ )
1024+ two_side_type_embedding = torch .cat (
1025+ [type_embedding_i , type_embedding_j ], - 1
1026+ ).reshape (- 1 , nt * 2 )
1027+ tt_full = self .filter_layers_strip .networks [0 ](two_side_type_embedding )
10121028 # (nfnl x nt_i x nt_j) x ng
1013- gg_t = torch . gather ( tt_full , dim = 0 , index = idx )
1029+ gg_t = tt_full [ idx ]
10141030 # (nfnl x nt_i x nt_j) x ng
10151031 gg_t = gg_t .reshape (nfnl , nnei , nnei , ng )
10161032 if self .smooth :
0 commit comments