
 def export(
     gm: torch.fx.GraphModule,
-    cross_compile_flag: Optional[bool] = False,
+    cross_compile_module: Optional[bool] = False,
 ) -> ExportedProgram:
     """Export the result of TensorRT compilation into the desired output format.

     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
+        cross_compile_module (bool): Flag indicating whether cross-compilation is enabled
     """
-    patched_module = transform(gm, cross_compile_flag)
+    patched_module = transform(gm, cross_compile_module)
     exp_program = create_trt_exp_program(patched_module)
     return exp_program

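# [Editor's sketch, not part of the patch] Minimal usage of the renamed flag,
# assuming `trt_gm` came from torch_tensorrt.dynamo.compile and the program is
# saved for later loading on the Windows target (file name hypothetical):
#
#     exp_program = export(trt_gm, cross_compile_module=True)
#     torch.export.save(exp_program, "trt_windows.ep")
#
# With cross_compile_module=True the graph carries no_op_placeholder nodes in
# place of execute_engine nodes until the Windows-side load swaps them back.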

 def transform(
     gm: torch.fx.GraphModule,
-    cross_compile_flag: Optional[bool] = False,
+    cross_compile_module: Optional[bool] = False,
 ) -> torch.fx.GraphModule:
     """
     Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -48,7 +48,7 @@ def transform(
     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        cross_compile_flag (bool): Flag to indicated whether it is cross_compilation enabled or not
+        cross_compile_module (bool): Flag indicating whether cross-compilation is enabled

     Returns an inlined torch.fx.GraphModule
     """
@@ -57,7 +57,7 @@ def transform(
     gm = copy.deepcopy(gm)

     # Inline TensorRT submodules
-    inline_trt_modules(gm, cross_compile_flag)
+    inline_trt_modules(gm, cross_compile_module)

     # Inline pytorch submodules
     inline_torch_modules(gm)
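# [Editor's note, not part of the patch] A quick, illustrative way to inspect
# which engine-dispatch nodes the transform produced:
#
#     for node in patched_module.graph.nodes:
#         if node.op == "call_function":
#             print(node.target)  # execute_engine vs. no_op_placeholder_for_execute_engine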
@@ -356,7 +356,7 @@ def create_trt_exp_program(


 def inline_trt_modules(
-    gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
+    gm: torch.fx.GraphModule, cross_compile_module: Optional[bool] = False
 ) -> torch.fx.GraphModule:
     """
     Replace TRT submodules with trt engine nodes.
@@ -380,7 +380,16 @@ def inline_trt_modules(
         num_outputs = len(trt_module_node.meta["val"])
         # Insert a call_function node to perform inference on TRT engine
         with gm.graph.inserting_before(trt_module_node):
-            if not cross_compile_flag:
+            if cross_compile_module:
+                engine_info = trt_module._pack_engine_info()
+                engine_bytes = engine_info[ENGINE_IDX]
+                engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
+                # Insert a no-op placeholder node in the graph; it is replaced by the
+                # actual execute_engine node when the program is loaded on Windows.
+                trt_node = gm.graph.call_function(
+                    torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
+                    (trt_module_node.args, *engine_info),
+                )
+            else:
                 # for the normal workflow: use the execute_engine node
                 engine_name = f"{name}_engine"
                 setattr(gm, engine_name, trt_module.engine)
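# [Editor's note, not part of the patch] The serialized engine is raw bytes, so
# it is base64-encoded above to travel as a string in the node args. A matching
# decode on the load side would look roughly like (names from the code above):
#
#     import base64
#     engine_bytes = base64.b64decode(engine_info[ENGINE_IDX])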
@@ -396,16 +405,6 @@ def inline_trt_modules(
                 engine_node.meta["val"] = CustomObjArgument(
                     name=engine_node.name, class_fqn=""
                 )
-            else:
-                # for the cross compile for windows workflow: use the no_op_placeholder node
-                engine_info = trt_module._pack_engine_info()
-                engine_bytes = engine_info[ENGINE_IDX]
-                engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
-                # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
-                trt_node = gm.graph.call_function(
-                    torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
-                    (trt_module_node.args, *engine_info),
-                )
             # set trt_node.meta with trt_module_node.meta
             assert num_outputs > 0
             trt_node.meta["val"] = trt_module_node.meta["val"]
@@ -464,16 +463,10 @@ def replace_execute_engine_no_op_node(
             name=engine_node.name, class_fqn=""
         )

-    if len(no_op_placeholder_node.meta["val"]) == 1:
-        with gm.graph.inserting_after(trt_node):
-            getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
-            getitem_output.meta["val"] = trt_node.meta["val"]
-            no_op_placeholder_node.replace_all_uses_with(getitem_output)
-    else:
-        no_op_placeholder_node.replace_all_uses_with(trt_node)
-        getitem_nodes = trt_node.users
-        for idx, getitem_node in enumerate(getitem_nodes):
-            getitem_node.meta["val"] = trt_node.meta["val"][idx]
+    no_op_placeholder_node.replace_all_uses_with(trt_node)
+    getitem_nodes = trt_node.users
+    for idx, getitem_node in enumerate(getitem_nodes):
+        getitem_node.meta["val"] = trt_node.meta["val"][idx]

     gm.graph.erase_node(no_op_placeholder_node)

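# [Editor's note, not part of the patch] The dropped single-output branch means
# the replacement now assumes every consumer of the placeholder is already an
# operator.getitem node, so output metadata is copied uniformly, e.g.:
#
#     import operator
#     assert all(user.target is operator.getitem for user in trt_node.users)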