@@ -88,6 +88,7 @@ def debugprint(
8888 | FunctionGraph
8989 | Sequence [Variable | Apply | Function | FunctionGraph ],
9090 depth : int = - 1 ,
91+ inner_depth : int = - 1 ,
9192 print_type : bool = False ,
9293 print_shape : bool = False ,
9394 file : Literal ["str" ] | TextIO | None = None ,
@@ -299,7 +300,7 @@ def debugprint(
299300 isinstance (var .owner .op , HasInnerGraph )
300301 or hasattr (var .owner .op , "scalar_op" )
301302 and isinstance (var .owner .op .scalar_op , HasInnerGraph )
302- ) and var not in inner_graph_vars :
303+ ) and not inner_depth and var not in inner_graph_vars :
303304 inner_graph_vars .append (var )
304305 if print_op_info :
305306 op_information .update (op_debug_information (var .owner .op , var .owner ))
@@ -325,7 +326,7 @@ def debugprint(
325326 print_view_map = print_view_map ,
326327 )
327328
328- if len (inner_graph_vars ) > 0 and print_inner_graphs :
329+ if len (inner_graph_vars ) > 0 and inner_depth :
329330 print ("" , file = _file )
330331 prefix = ""
331332 new_prefix = prefix + " ← "
@@ -377,7 +378,7 @@ def debugprint(
377378 _debugprint (
378379 ig_var ,
379380 prefix = prefix ,
380- depth = depth ,
381+ depth = inner_depth ,
381382 done = done ,
382383 print_type = print_type ,
383384 print_shape = print_shape ,
@@ -400,7 +401,7 @@ def debugprint(
400401 _debugprint (
401402 inp ,
402403 prefix = " → " ,
403- depth = depth ,
404+ depth = inner_depth ,
404405 done = done ,
405406 print_type = print_type ,
406407 print_shape = print_shape ,
@@ -435,7 +436,7 @@ def debugprint(
435436 _debugprint (
436437 out ,
437438 prefix = new_prefix ,
438- depth = depth ,
439+ depth = inner_depth ,
439440 done = done ,
440441 print_type = print_type ,
441442 print_shape = print_shape ,
0 commit comments