@@ -311,7 +311,27 @@ def _switch_transpose_and_node(self, node, trans, update_shape=True):
311311 self ._g .set_shape (node .output [0 ], new_shape )
312312 self ._g .set_shape (trans .output [0 ], shape )
313313 return True
314-
314+ # this is for the case where node has multiple outputs. e.g. split node.
315+ def _switch_transpose_and_node_with_multiple_outputs (self , node , trans , update_shape = True ):
316+ input_index = self ._get_input_index_for_trans (node , trans )
317+ for idx ,_output in enumerate (node .output ):
318+ shape = self ._g .get_shape (_output )
319+ nxt_nodes = self ._g .find_output_consumers (_output )
320+ if idx == 0 :
321+ transpose = trans
322+ self ._g .replace_input (node , node .input [input_index ], transpose .input [0 ], input_index )
323+ self ._g .replace_input (trans , trans .input [0 ], _output , 0 )
324+ else :
325+ transpose = self ._g .make_node ("Transpose" , [_output ], attr = {"perm" : trans .get_attr_value ("perm" )})
326+ for nxt_node in nxt_nodes :
327+ self ._g .replace_input (nxt_node , _output , transpose .output [0 ])
328+
329+ if update_shape and shape :
330+ perm_inv = invert_perm (transpose .get_attr_value ("perm" ))
331+ new_shape = [shape [i ] for i in perm_inv ]
332+ self ._g .set_shape (_output , new_shape )
333+ self ._g .set_shape (transpose .output [0 ], shape )
334+ return True
315335 # if return value is True, then it means Transpose is handled as designed
316336 # otherwise, it means that we skip handling since it is not in our support set
317337 def _handle_nhwc_tranpose (self , trans ):
@@ -694,6 +714,21 @@ def _split_handler(self, trans, node):
694714 new_axes_const = self ._g .make_const (utils .make_name (node .inputs [1 ].name ), new_axes_np )
695715 self ._g .replace_inputs (node , [node .input [0 ], new_axes_const .output [0 ]])
696716 return True
717+ # handling having branches
718+ if len (node .output ) > 1 :
719+ trans_rank = get_transpose_rank (trans )
720+ axes = node .get_attr_value ("axis" , 0 )
721+ perm = trans .get_attr ("perm" ).ints
722+ axes = [axes + trans_rank if axes < 0 else axes ]
723+ if split :
724+ new_axes_np = np .array (split , dtype = np .int64 )
725+ new_axes_const = self ._g .make_const (utils .make_name (node .inputs [1 ].name ), new_axes_np )
726+ # [Transpose -> Split -> next_nodes] -> [Split -> Transpose -> next_nodes]
727+ if not self ._switch_transpose_and_node_with_multiple_outputs (node , trans , 1 ):
728+ return False
729+ new_axes = [perm [a ] for a in axes ]
730+ node .set_attr ("axes" , new_axes )
731+ return True
697732 return False
698733
699734 def _unsqueeze_handler (self , trans , node ):
0 commit comments