Skip to content

Commit cf98216

Browse files
authored
Merge branch 'fastmachinelearning:main' into oneapi_backend/experiment
2 parents 0b8ef13 + 18ccc61 commit cf98216

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

hls4ml/model/graph.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ def remove_node(self, node, rewire=True):
565565
if outputs[0] == nxt_inp:
566566
next_node.inputs[i] = inputs[0]
567567

568+
if node.outputs[0] in self.outputs:
569+
prev_node = node.get_input_node(node.inputs[0])
570+
self.outputs[self.outputs.index(node.outputs[0])] = prev_node.outputs[0]
568571
del self.output_vars[node.outputs[0]]
569572
del self.graph[node.name]
570573

hls4ml/model/optimizer/passes/move_scales.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from hls4ml.model.layers import ApplyAlpha, Constant, Conv, MatMul, Merge
1313
from hls4ml.model.optimizer import OptimizerPass
1414

15+
# These attributes should not be copied. (Should add the output name to this)
16+
_attrs_not_to_copy = ['trace', 'precision', 'scale', 'bias', 'scale_data', 'bias_data']
17+
1518

1619
class ScaleDownMatMul(OptimizerPass):
1720
'''Shift an ApplyAlpha below a MatMul'''
@@ -62,7 +65,7 @@ def transform(self, model, node):
6265

6366
output = node.get_output_variable()
6467
# to remove warning, since these get set again
65-
new_attrs = {k: v for k, v in apply_alpha.attributes.items() if k not in ('trace', 'precision')}
68+
new_attrs = {k: v for k, v in apply_alpha.attributes.items() if k not in _attrs_not_to_copy + apply_alpha.outputs}
6669

6770
can_propagate = False
6871
if not bias.shape and bias == 0:
@@ -258,7 +261,7 @@ def transform(self, model, node):
258261
return False
259262

260263
# to remove warning, since these get set again
261-
new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')}
264+
new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs}
262265
new_name = in0.name
263266
model.remove_node(in0)
264267

@@ -305,7 +308,7 @@ def transform(self, model, node):
305308
return False
306309

307310
# to remove warning, since these get set again
308-
new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')}
311+
new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs}
309312
new_name = in1.name
310313
model.remove_node(in1)
311314

@@ -329,7 +332,7 @@ def transform(self, model, node):
329332
return False
330333

331334
# to remove warning, since these get set again
332-
new_attrs = {k: v for k, v in in2.attributes.items() if k not in ('trace', 'precision')}
335+
new_attrs = {k: v for k, v in in2.attributes.items() if k not in _attrs_not_to_copy + in2.outputs}
333336
new_name = in2.name
334337
model.remove_node(in2)
335338

@@ -391,7 +394,7 @@ def transform(self, model, node):
391394
return False
392395

393396
# to remove warning, since these get set again
394-
new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')}
397+
new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs}
395398
new_name = in1.name
396399
model.remove_node(in0)
397400
model.remove_node(in1)
@@ -415,7 +418,7 @@ def transform(self, model, node):
415418
return False
416419

417420
# to remove warning, since these get set again
418-
new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')}
421+
new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs}
419422
new_name = in0.name
420423
model.remove_node(in0)
421424
model.remove_node(in2)
@@ -442,7 +445,7 @@ def transform(self, model, node):
442445
return False
443446

444447
# to remove warning, since these get set again
445-
new_attrs = {k: v for k, v in in1.attributes.items() if k not in ('trace', 'precision')}
448+
new_attrs = {k: v for k, v in in1.attributes.items() if k not in _attrs_not_to_copy + in1.outputs}
446449
new_name = in1.name
447450
model.remove_node(in1)
448451
model.remove_node(in2)
@@ -478,7 +481,7 @@ def transform(self, model, node):
478481
return False
479482

480483
# to remove warning, since these get set again
481-
new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')}
484+
new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs}
482485
new_name = in0.name
483486
model.remove_node(in0)
484487
model.remove_node(in1)

0 commit comments

Comments
 (0)