@@ -393,121 +393,12 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
 
 
 def weights_torch(model, fmt='longform', plot='boxplot'):
+    from hls4ml.utils.profiling_utils import WeightsTorch
+
     wt = WeightsTorch(model, fmt, plot)
     return wt.get_weights()
 
 
-def _torch_batchnorm(layer):
-    weights = list(layer.parameters())
-    epsilon = layer.eps
-
-    gamma = weights[0]
-    beta = weights[1]
-    if layer.track_running_stats:
-        mean = layer.running_mean
-        var = layer.running_var
-    else:
-        mean = torch.tensor(np.ones(20))
-        var = torch.tensor(np.zeros(20))
-
-    scale = gamma / np.sqrt(var + epsilon)
-    bias = beta - gamma * mean / np.sqrt(var + epsilon)
-
-    return [scale, bias], ['s', 'b']
-
-
-def _torch_layer(layer):
-    return list(layer.parameters()), ['w', 'b']
-
-
-def _torch_rnn(layer):
-    return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0']
-
-
-torch_process_layer_map = defaultdict(
-    lambda: _torch_layer,
-    {
-        'BatchNorm1d': _torch_batchnorm,
-        'BatchNorm2d': _torch_batchnorm,
-        'RNN': _torch_rnn,
-        'LSTM': _torch_rnn,
-        'GRU': _torch_rnn,
-    },
-)
-
-
-class WeightsTorch:
-    def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
-        self.model = model
-        self.fmt = fmt
-        self.plot = plot
-        self.registered_layers = list()
-        self._find_layers(self.model, self.model.__class__.__name__)
-
-    def _find_layers(self, model, module_name):
-        for name, module in model.named_children():
-            if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
-                self._find_layers(module, module_name + "." + name)
-            elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
-                if len(list(module.named_children())) != 0:
-                    # custom nn.Module, continue search
-                    self._find_layers(module, module_name + "." + name)
-                else:
-                    self._register_layer(module_name + "." + name)
-
-    def _is_registered(self, name: str) -> bool:
-        return name in self.registered_layers
-
-    def _register_layer(self, name: str) -> None:
-        if self._is_registered(name) is False:
-            self.registered_layers.append(name)
-
-    def _is_parameterized(self, module: torch.nn.Module) -> bool:
-        return any(p.requires_grad for p in module.parameters())
-
-    def _get_weights(self) -> pandas.DataFrame | list[dict]:
-        if self.fmt == 'longform':
-            data = {'x': [], 'layer': [], 'weight': []}
-        elif self.fmt == 'summary':
-            data = []
-        for layer_name in self.registered_layers:
-            layer = self._get_layer(layer_name, self.model)
-            name = layer.__class__.__name__
-            weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
-            for i, w in enumerate(weights):
-                label = f'{name}/{suffix[i]}'
-                w = weights[i].detach().numpy()
-                w = w.flatten()
-                w = abs(w[w != 0])
-                n = len(w)
-                if n == 0:
-                    print(f'Weights for {name} are only zeros, ignoring.')
-                    break
-                if self.fmt == 'longform':
-                    data['x'].extend(w.tolist())
-                    data['layer'].extend([name] * n)
-                    data['weight'].extend([label] * n)
-                elif self.fmt == 'summary':
-                    data.append(array_to_summary(w, fmt=self.plot))
-                    data[-1]['layer'] = name
-                    data[-1]['weight'] = label
-
-        if self.fmt == 'longform':
-            data = pandas.DataFrame(data)
-        return data
-
-    def get_weights(self) -> pandas.DataFrame | list[dict]:
-        return self._get_weights()
-
-    def get_layers(self) -> list[str]:
-        return self.registered_layers
-
-    def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
-        for name in layer_name.split('.')[1:]:
-            module = getattr(module, name)
-        return module
-
-
 def activations_torch(model, X, fmt='longform', plot='boxplot'):
     X = torch.Tensor(X)
     if fmt == 'longform':
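
For reference, a minimal usage sketch of the refactored entry point. The import path below is an assumption (this hunk appears to belong to hls4ml/model/profiling.py, but the file name is not shown in the diff), and the toy model is illustrative only. The caller-facing API is unchanged: weights_torch now lazily imports WeightsTorch from hls4ml.utils.profiling_utils and delegates to it.

import torch

# Assumed module path for the file patched above (not confirmed by the hunk itself).
from hls4ml.model.profiling import weights_torch

# Toy model, illustrative only; ReLU has no parameters and is skipped by WeightsTorch.
model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

# fmt='longform' returns a pandas.DataFrame of nonzero |weights| per layer;
# fmt='summary' would instead return a list of per-layer summary dicts.
df = weights_torch(model, fmt='longform', plot='boxplot')
print(df.head())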