1212from hls4ml .model .layers import ApplyAlpha , Constant , Conv , MatMul , Merge
1313from hls4ml .model .optimizer import OptimizerPass
1414
15+ # These attributes should not be copied. (Should add the output name to this)
16+ _attrs_not_to_copy = ['trace' , 'precision' , 'scale' , 'bias' , 'scale_data' , 'bias_data' ]
17+
1518
1619class ScaleDownMatMul (OptimizerPass ):
1720 '''Shift an ApplyAlpha below a MatMul'''
@@ -62,7 +65,7 @@ def transform(self, model, node):
6265
6366 output = node .get_output_variable ()
6467 # to remove warning, since these get set again
65- new_attrs = {k : v for k , v in apply_alpha .attributes .items () if k not in ( 'trace' , 'precision' ) }
68+ new_attrs = {k : v for k , v in apply_alpha .attributes .items () if k not in _attrs_not_to_copy + apply_alpha . outputs }
6669
6770 can_propagate = False
6871 if not bias .shape and bias == 0 :
@@ -258,7 +261,7 @@ def transform(self, model, node):
258261 return False
259262
260263 # to remove warning, since these get set again
261- new_attrs = {k : v for k , v in in0 .attributes .items () if k not in ( 'trace' , 'precision' ) }
264+ new_attrs = {k : v for k , v in in0 .attributes .items () if k not in _attrs_not_to_copy + in0 . outputs }
262265 new_name = in0 .name
263266 model .remove_node (in0 )
264267
@@ -305,7 +308,7 @@ def transform(self, model, node):
305308 return False
306309
307310 # to remove warning, since these get set again
308- new_attrs = {k : v for k , v in in0 .attributes .items () if k not in ( 'trace' , 'precision' ) }
311+ new_attrs = {k : v for k , v in in0 .attributes .items () if k not in _attrs_not_to_copy + in0 . outputs }
309312 new_name = in1 .name
310313 model .remove_node (in1 )
311314
@@ -329,7 +332,7 @@ def transform(self, model, node):
329332 return False
330333
331334 # to remove warning, since these get set again
332- new_attrs = {k : v for k , v in in2 .attributes .items () if k not in ( 'trace' , 'precision' ) }
335+ new_attrs = {k : v for k , v in in2 .attributes .items () if k not in _attrs_not_to_copy + in2 . outputs }
333336 new_name = in2 .name
334337 model .remove_node (in2 )
335338
@@ -391,7 +394,7 @@ def transform(self, model, node):
391394 return False
392395
393396 # to remove warning, since these get set again
394- new_attrs = {k : v for k , v in in0 .attributes .items () if k not in ( 'trace' , 'precision' ) }
397+ new_attrs = {k : v for k , v in in0 .attributes .items () if k not in _attrs_not_to_copy + in0 . outputs }
395398 new_name = in1 .name
396399 model .remove_node (in0 )
397400 model .remove_node (in1 )
@@ -415,7 +418,7 @@ def transform(self, model, node):
415418 return False
416419
417420 # to remove warning, since these get set again
418- new_attrs = {k : v for k , v in in0 .attributes .items () if k not in ( 'trace' , 'precision' ) }
421+ new_attrs = {k : v for k , v in in0 .attributes .items () if k not in _attrs_not_to_copy + in0 . outputs }
419422 new_name = in0 .name
420423 model .remove_node (in0 )
421424 model .remove_node (in2 )
@@ -442,7 +445,7 @@ def transform(self, model, node):
442445 return False
443446
444447 # to remove warning, since these get set again
445- new_attrs = {k : v for k , v in in1 .attributes .items () if k not in ( 'trace' , 'precision' ) }
448+ new_attrs = {k : v for k , v in in1 .attributes .items () if k not in _attrs_not_to_copy + in1 . outputs }
446449 new_name = in1 .name
447450 model .remove_node (in1 )
448451 model .remove_node (in2 )
@@ -478,7 +481,7 @@ def transform(self, model, node):
478481 return False
479482
480483 # to remove warning, since these get set again
481- new_attrs = {k : v for k , v in in0 .attributes .items () if k not in ( 'trace' , 'precision' ) }
484+ new_attrs = {k : v for k , v in in0 .attributes .items () if k not in _attrs_not_to_copy + in0 . outputs }
482485 new_name = in0 .name
483486 model .remove_node (in0 )
484487 model .remove_node (in1 )
0 commit comments