@@ -246,7 +246,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
246246 exceptions = []
247247 if initialized_tables is None :
248248 initialized_tables = {}
249-
249+
250250 ops = list (g .get_nodes ())
251251 for node in ops :
252252 logger .debug ("Process node: %s\n %s" , node .name , node .summary )
@@ -263,7 +263,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
263263 logger .error ("Tensorflow op [%s: %s] is not supported" , node .name , op )
264264 continue
265265 mapped_op [op ] += 1
266-
266+
267267 func , kwargs = map_info
268268 if kwargs :
269269 # if there is a tf_op/onnx_op key we'll map the old type to a new type
@@ -273,6 +273,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
273273 kwargs ["tfl_op" if is_tflite else "tf_op" ] = op
274274 node .type = converted_op
275275 body_graphs = node .get_body_graphs ()
276+
276277 if body_graphs :
277278 for attr , b_g in body_graphs .items ():
278279 logger .debug ("start handling subgraph of %s's attribute %s" , node .name , attr )
@@ -287,7 +288,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
287288 b_g .topological_sort (b_g .get_nodes ())
288289 exceptions .extend (body_exceptions )
289290 logger .debug ("finish handling subgraph of %s's attribute %s" , node .name , attr )
290-
291+
291292 try :
292293 func (g , node , ** kwargs , initialized_tables = initialized_tables , dequantize = dequantize )
293294 if not is_tflite :
@@ -302,7 +303,6 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
302303 logger .error ("Failed to convert node %r (fct=%r)\n %r" ,
303304 node .name , func , summary , exc_info = 1 )
304305 exceptions .append (ex )
305-
306306 return mapped_op , unmapped_op , exceptions
307307
308308
@@ -332,26 +332,96 @@ def transpose_inputs(ctx, inputs_as_nchw):
332332def transpose_outputs (ctx , outputs_as_nchw ):
333333 """Insert a transpose from NHWC to NCHW on model output on users request."""
334334 ops = []
335+
336+ # First pass: Find and handle edge cases in original nodes
337+ edge_case_handled = set ()
338+
335339 for node in ctx .get_nodes ():
336340 for output_name in node .output :
337- if output_name in outputs_as_nchw :
341+ # Check if this output is used to create a model output
342+ consumers = ctx .find_output_consumers (output_name )
343+
344+ # Look for edge case: output consumed by both model output node and other nodes
345+ model_output_consumers = []
346+ other_consumers = []
347+
348+ for consumer in consumers :
349+ if consumer .output and any (out in outputs_as_nchw for out in consumer .output ):
350+ model_output_consumers .append (consumer )
351+ else :
352+ other_consumers .append (consumer )
353+
354+ # Edge case: original node output goes to both model output and other layers
355+ if model_output_consumers and other_consumers :
356+ # Get shape for validation
357+ shape = ctx .get_shape (output_name )
358+ if len (shape ) != len (constants .NHWC_TO_NCHW ):
359+ continue
360+
361+ # Handle edge case: Use insert_node_on_output for proper structure
362+ # Step 1: Create Identity node and insert it on the original output
363+ identity_name = utils .make_name (node .name + "_identity" )
364+ identity = ctx .make_node ("Identity" , [output_name ],
365+ outputs = [identity_name + ":0" ], name = identity_name )
366+
367+ # Copy shape information
368+ ctx .copy_shape (output_name , identity .output [0 ])
369+ ctx .set_shape (identity .output [0 ], shape )
370+
371+ # Insert the identity on the original output - this will redirect ALL consumers
372+ ctx .insert_node_on_output (identity , output_name )
373+
374+ # Step 2: Create Transpose node and connect it to Identity
375+ transpose_name = utils .make_name (identity .name + "_transpose" )
376+ transpose = ctx .make_node ("Transpose" , [identity .output [0 ]],
377+ outputs = [transpose_name + ":0" ], name = transpose_name )
378+ transpose .set_attr ("perm" , constants .NHWC_TO_NCHW )
379+ ctx .copy_shape (identity .output [0 ], transpose .output [0 ])
380+ ctx .set_shape (transpose .output [0 ], np .array (shape )[constants .NHWC_TO_NCHW ])
381+
382+ # Step 3: Manually redirect ONLY the model output consumers to use transpose
383+ for consumer in model_output_consumers :
384+ ctx .replace_all_inputs (identity .output [0 ], transpose .output [0 ], ops = [consumer ])
385+
386+ # Mark this output as handled
387+ edge_case_handled .add (output_name )
388+
389+ ops .append (node )
390+ ops .append (identity )
391+ ops .append (transpose )
392+ break # Only handle one edge case per node
393+
394+ # If no edge case was handled for this node, add it normally
395+ if not any (out in edge_case_handled for out in node .output ):
396+ ops .append (node )
397+
398+ # Second pass: Handle normal cases (nodes that directly output to model outputs)
399+ final_ops = []
400+ for node in ops :
401+ handled = False
402+ for output_name in node .output :
403+ if output_name in outputs_as_nchw and output_name not in edge_case_handled :
404+ # Get shape for validation
338405 shape = ctx .get_shape (output_name )
339406 if len (shape ) != len (constants .NHWC_TO_NCHW ):
340407 logger .warning ("transpose_output for %s: shape must be rank 4, ignored" % output_name )
341- ops .append (node )
342408 continue
409+
343410 # insert transpose
344411 op_name = utils .make_name (node .name )
345- transpose = ctx .insert_new_node_on_output ("Transpose" , node . input [ 0 ] , name = op_name )
412+ transpose = ctx .insert_new_node_on_output ("Transpose" , output_name , name = op_name )
346413 transpose .set_attr ("perm" , constants .NHWC_TO_NCHW )
347- ctx .copy_shape (node .output [0 ], transpose .output [0 ])
348- ctx .set_shape (transpose .output [0 ], np .array (shape )[constants .NHWC_TO_NCHW ])
414+ ctx .copy_shape (output_name , transpose .output [0 ])
349415 ctx .set_shape (output_name , np .array (shape )[constants .NHWC_TO_NCHW ])
350- ops .append (transpose )
351- ops .append (node )
352- continue
353- ops .append (node )
354- ctx .reset_nodes (ops )
416+ final_ops .append (transpose )
417+ final_ops .append (node )
418+ handled = True
419+ break
420+
421+ if not handled :
422+ final_ops .append (node )
423+
424+ ctx .reset_nodes (final_ops )
355425
356426def topological_sort (g , continue_on_error ):
357427 ops = g .get_nodes ()
@@ -522,7 +592,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
522592 initialized_tables , is_tflite = False , dequantize = False ):
523593
524594 op_cnt , attr_cnt = g .dump_node_statistics (include_attrs = True , include_subgraphs = False )
525-
595+
526596 if is_tflite :
527597 tfl_rewriters = []
528598 if dequantize :
@@ -531,13 +601,16 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
531601 tfl_rewriters .append (rewrite_tfl_select_zero )
532602 tfl_rewriters .append (rewrite_tfl_rfft )
533603 run_rewriters (g , tfl_rewriters , continue_on_error )
604+
534605 tfl_ops_mapping = handler .tfl_op .create_tfl_to_tf_mapping ()
535606 _ , _ , exceptions = tensorflow_onnx_mapping (g , tfl_ops_mapping , is_tflite = True , dequantize = False )
607+
536608 if exceptions and not continue_on_error :
537609 raise exceptions [0 ]
538610
539611 # create ops mapping for the desired opsets
540612 ops_mapping = handler .tf_op .create_mapping (g .opset , g .extra_opset )
613+
541614
542615 # apply custom ops on top of the assembled opset. We can either complement the opset
543616 # or override existing ops with a custom op.
0 commit comments