@@ -553,6 +553,8 @@ def enable_compression(
553553 assert not self .se_ttebd .resnet_dt , (
554554 "Model compression error: descriptor resnet_dt must be false!"
555555 )
556+ if self .tebd_input_mode != "strip" :
557+ raise RuntimeError ("Cannot compress model when tebd_input_mode != 'strip'" )
556558 for tt in self .se_ttebd .exclude_types :
557559 if (tt [0 ] not in range (self .se_ttebd .ntypes )) or (
558560 tt [1 ] not in range (self .se_ttebd .ntypes )
@@ -573,9 +575,6 @@ def enable_compression(
573575 "Empty embedding-nets are not supported in model compression!"
574576 )
575577
576- if self .tebd_input_mode != "strip" :
577- raise RuntimeError ("Cannot compress model when tebd_input_mode == 'concat'" )
578-
579578 data = self .serialize ()
580579 self .table = DPTabulate (
581580 self ,
@@ -597,7 +596,11 @@ def enable_compression(
597596 )
598597
599598 self .se_ttebd .enable_compression (
600- self .table .data , self .table_config , self .lower , self .upper
599+ self .type_embedding ,
600+ self .table .data ,
601+ self .table_config ,
602+ self .lower ,
603+ self .upper ,
601604 )
602605 self .compress = True
603606
@@ -694,12 +697,17 @@ def __init__(
694697 self .stats = None
695698 # compression related variables
696699 self .compress = False
700+ # For geometric compression
697701 self .compress_info = nn .ParameterList (
698702 [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = "cpu" ))]
699703 )
700704 self .compress_data = nn .ParameterList (
701705 [nn .Parameter (torch .zeros (0 , dtype = self .prec , device = env .DEVICE ))]
702706 )
707+ # For type embedding compression
708+ self .register_buffer (
709+ "type_embd_data" , torch .zeros (0 , dtype = self .prec , device = env .DEVICE )
710+ )
703711
704712 def get_rcut (self ) -> float :
705713 """Returns the cut-off radius."""
@@ -986,31 +994,24 @@ def forward(
986994 nei_type_j = nei_type .unsqueeze (1 ).expand ([- 1 , nnei , - 1 ])
987995 idx_i = nei_type_i * ntypes_with_padding
988996 idx_j = nei_type_j
989- # (nf x nl x nt_i x nt_j) x ng
990- idx = (
991- (idx_i + idx_j )
992- .view (- 1 , 1 )
993- .expand (- 1 , ng )
994- .type (torch .long )
995- .to (torch .long )
996- )
997- # ntypes * (ntypes) * nt
998- type_embedding_i = torch .tile (
999- type_embedding .view (ntypes_with_padding , 1 , nt ),
1000- [1 , ntypes_with_padding , 1 ],
1001- )
1002- # (ntypes) * ntypes * nt
1003- type_embedding_j = torch .tile (
1004- type_embedding .view (1 , ntypes_with_padding , nt ),
1005- [ntypes_with_padding , 1 , 1 ],
1006- )
1007- # (ntypes * ntypes) * (nt+nt)
1008- two_side_type_embedding = torch .cat (
1009- [type_embedding_i , type_embedding_j ], - 1
1010- ).reshape (- 1 , nt * 2 )
1011- tt_full = self .filter_layers_strip .networks [0 ](two_side_type_embedding )
997+ idx = (idx_i + idx_j ).reshape (- 1 ).to (torch .long )
998+ if self .compress :
999+ tt_full = self .type_embd_data
1000+ else :
1001+ type_embedding_i = torch .tile (
1002+ type_embedding .view (ntypes_with_padding , 1 , nt ),
1003+ [1 , ntypes_with_padding , 1 ],
1004+ )
1005+ type_embedding_j = torch .tile (
1006+ type_embedding .view (1 , ntypes_with_padding , nt ),
1007+ [ntypes_with_padding , 1 , 1 ],
1008+ )
1009+ two_side_type_embedding = torch .cat (
1010+ [type_embedding_i , type_embedding_j ], - 1
1011+ ).reshape (- 1 , nt * 2 )
1012+ tt_full = self .filter_layers_strip .networks [0 ](two_side_type_embedding )
10121013 # (nfnl x nt_i x nt_j) x ng
1013- gg_t = torch . gather ( tt_full , dim = 0 , index = idx )
1014+ gg_t = tt_full [ idx ]
10141015 # (nfnl x nt_i x nt_j) x ng
10151016 gg_t = gg_t .reshape (nfnl , nnei , nnei , ng )
10161017 if self .smooth :
@@ -1042,6 +1043,7 @@ def forward(
10421043
10431044 def enable_compression (
10441045 self ,
1046+ type_embedding_net : TypeEmbedNet ,
10451047 table_data : dict ,
10461048 table_config : dict ,
10471049 lower : dict ,
@@ -1051,6 +1053,8 @@ def enable_compression(
10511053
10521054 Parameters
10531055 ----------
1056+ type_embedding_net : TypeEmbedNet
1057+ The type embedding network
10541058 table_data : dict
10551059 The tabulated data from DPTabulate
10561060 table_config : dict
@@ -1060,6 +1064,13 @@ def enable_compression(
10601064 upper : dict
10611065 Upper bounds for compression
10621066 """
1067+ if self .tebd_input_mode != "strip" :
1068+ raise RuntimeError ("Type embedding compression only works in strip mode" )
1069+ if self .filter_layers_strip is None :
1070+ raise RuntimeError (
1071+ "filter_layers_strip must exist for type embedding compression"
1072+ )
1073+
10631074 # Compress the main geometric embedding network (self.filter_layers)
10641075 net_key = "filter_net"
10651076 self .compress_info [0 ] = torch .as_tensor (
@@ -1078,6 +1089,22 @@ def enable_compression(
10781089 device = env .DEVICE , dtype = self .prec
10791090 )
10801091
1092+ # Compress the type embedding network (self.filter_layers_strip)
1093+ with torch .no_grad ():
1094+ full_embd = type_embedding_net .get_full_embedding (env .DEVICE )
1095+ nt , t_dim = full_embd .shape
1096+ type_embedding_i = full_embd .view (nt , 1 , t_dim ).expand (nt , nt , t_dim )
1097+ type_embedding_j = full_embd .view (1 , nt , t_dim ).expand (nt , nt , t_dim )
1098+ two_side_type_embedding = torch .cat (
1099+ [type_embedding_i , type_embedding_j ], dim = - 1
1100+ ).reshape (- 1 , t_dim * 2 )
1101+ embd_tensor = self .filter_layers_strip .networks [0 ](
1102+ two_side_type_embedding
1103+ ).detach ()
1104+ if hasattr (self , "type_embd_data" ):
1105+ del self .type_embd_data
1106+ self .register_buffer ("type_embd_data" , embd_tensor )
1107+
10811108 self .compress = True
10821109
10831110 def has_message_passing (self ) -> bool :
0 commit comments