
[BUG] Serving TensorRT model with CUDA graph results in weird inconsistent outputs #8550

@WingEdge777

Description

When serving a TensorRT engine with CUDA graph optimization enabled, we encountered a weird phenomenon.

We send requests sequentially, alternating between two fixed inputs in an AAAAABBBBBAAAABBBB pattern. At the start of each run of A (or B) requests, the first few responses often contain the results from the previous run's B (or A) input, which is clearly wrong for the current input.

(Screenshot: inconsistent outputs with CUDA graph optimization enabled.)

However, if we remove the CUDA graph optimization config, all results become consistent and correct with respect to the A (or B) input.

(Screenshot: consistent outputs with CUDA graph optimization disabled.)

Thus, we strongly suspect that Triton, the TensorRT backend, or TensorRT itself is reusing wrong/dirty/uninitialized input buffers.

CUDA graphs are an essential optimization for many of our use cases, and this issue prevents us from upgrading CUDA/TensorRT/Triton Server.

Triton Information
I am using the NVIDIA NGC Triton Server 25.10 container, which ships Triton Server 2.62.0 and TensorRT 10.13 according to the release notes. The server runs on L20/A10 GPUs with NVIDIA driver 535.161.08, using the NVIDIA compat lib.

To Reproduce
The reproduction steps are simple: export a ResNet-50 engine, serve it with CUDA graph optimization enabled, then send requests to the server.

Refer to the following scripts:

1. Export the ONNX model:

import torch
import torchvision.models as models

# ResNet-50 with random weights is enough to reproduce the issue.
model = models.resnet50(weights=None)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
output_onnx_file = "resnet50.onnx"

# Export with a dynamic batch dimension so the engine can be built for batch sizes 1-8.
torch.onnx.export(
    model,
    dummy_input,
    output_onnx_file,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    dynamo=False,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
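
(Optional, not part of the original repro.) As a quick sanity check before building the engine, the exported ONNX model can be compared against the PyTorch model with ONNX Runtime; this is a minimal sketch, assuming onnxruntime is installed, appended to the export script above:

import numpy as np
import onnxruntime as ort

# Run the same random input through both PyTorch and ONNX Runtime and compare.
sess = ort.InferenceSession(output_onnx_file, providers=["CPUExecutionProvider"])
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
with torch.no_grad():
    torch_out = model(torch.from_numpy(x)).numpy()
ort_out = sess.run(["output"], {"input": x})[0]
print("max abs diff (PyTorch vs ONNX Runtime):", np.abs(torch_out - ort_out).max())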

2. Build the TensorRT engine:

polygraphy convert ./resnet50.onnx --convert-to trt --output ./resnet50.plan \
    --fp16 \
    --trt-min-shapes input:[1,3,224,224]  \
    --trt-opt-shapes input:[4,3,224,224]   \
    --trt-max-shapes input:[8,3,224,224]
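
(Optional.) To confirm the built engine exposes the expected dynamic-batch profile before serving it, it can be deserialized with the TensorRT Python API; a minimal sketch, assuming the tensorrt Python package from the same release is installed:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("resnet50.plan", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

# List I/O tensors and their (dynamic) shapes.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name, engine.get_tensor_mode(name), engine.get_tensor_shape(name))

# Min/opt/max shapes of the "input" tensor for optimization profile 0.
print(engine.get_tensor_profile_shape("input", 0))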

3. Set up the model repository (model_zoo) and the model config config.pbtxt:

name: "resnet_50"
backend: "tensorrt"
max_batch_size: 8
model_warmup: {
    name: "sample"
    batch_size: 1
    inputs: {
        key: "input"
        value: {
            data_type: TYPE_FP32
            dims: [3, 224, 224]
            zero_data: true
        }
    }
}

optimization {
    graph: { level: 1 }
    eager_batching: 1
    cuda: {
        graphs: 1
        graph_spec: [
            { batch_size: 1 },
            { batch_size: 2 },
            { batch_size: 3 },
            { batch_size: 4 },
            { batch_size: 5 },
            { batch_size: 6 },
            { batch_size: 7 },
            { batch_size: 8 }
        ]
        busy_wait_events: 1
        output_copy_stream: 1
    }
}


dynamic_batching {
  preferred_batch_size: [4,8]
  max_queue_delay_microseconds: 2000
}
instance_group [ { count: 2 kind: KIND_GPU gpus:[0]}]
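
For reference, the workaround mentioned above (removing the CUDA graph settings, which makes the outputs correct) amounts to an optimization block roughly like this, with the rest of config.pbtxt unchanged:

optimization {
    graph: { level: 1 }
    eager_batching: 1
}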

4. Start the server:

tritonserver --strict-model-config=0 --metrics-port=8102 --http-port=8100 --grpc-port=8101 --model-repository=./model_zoo --log-verbose=0 --backend-config=python,shm-default-byte-size=335544320
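
(Optional.) Before sending inference requests, a quick readiness check against the HTTP port above confirms the model loaded; a minimal sketch using tritonclient:

import tritonclient.http as httpclient

# Assumes the HTTP endpoint started above on port 8100.
client = httpclient.InferenceServerClient(url="localhost:8100")
print("server ready:", client.is_server_ready())
print("model ready:", client.is_model_ready("resnet_50"))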

5. Client script sending requests:

import argparse
import numpy as np
import sys

import tritonclient.http as httpclient

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8100',
                        help='Inference server URL. Default is localhost:8100.')

    FLAGS = parser.parse_args()
    request_count = 50
    try:
        triton_client = httpclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose, concurrency=request_count)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit()

    # inference check with two fixed random inputs (case A and case B)
    model_name = "resnet_50"

    output_name = ["output"]
    np.random.seed(1024)
    input_data = np.random.randn(2, 3, 224, 224).astype(np.float32)
    # input0_data = load_image("bag.jpeg")
    print(input_data.shape)

    for i in range(5):
        batch_id = i % 2
        # Infer
        inputs = []
        outputs = []
        
        inputs.append(httpclient.InferInput('input', input_data[batch_id:batch_id+1].shape, "FP32"))


        # Initialize the input tensor with the data for the current case (A or B)
        inputs[0].set_data_from_numpy(input_data[batch_id:batch_id+1])

        for name in output_name:
            outputs.append(httpclient.InferRequestedOutput(name))
        headers = {}
        cnt = 7
        print(f"round {i}")
        for j in range(cnt):
            async_request = triton_client.async_infer(model_name=model_name,
                                        inputs=inputs,
                                        outputs=outputs, headers=headers)
            result = async_request.get_result()
            for name in output_name:
                out = result.as_numpy(name)
                print(f"input case : {batch_id}", ", output: ", name, out.shape, out[0][0])
                break
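
(Optional.) To detect the mix-up automatically rather than by eye, reference outputs for both inputs can be precomputed with ONNX Runtime and each server response checked against them; a sketch assuming onnxruntime and the resnet50.onnx file from step 1 are available:

import numpy as np
import onnxruntime as ort

# Reproduce the same two inputs the client sends (same seed as above).
np.random.seed(1024)
input_data = np.random.randn(2, 3, 224, 224).astype(np.float32)

sess = ort.InferenceSession("resnet50.onnx", providers=["CPUExecutionProvider"])
reference = [sess.run(["output"], {"input": input_data[k:k + 1]})[0] for k in range(2)]

def closest_case(server_out):
    # Index (0 = case A, 1 = case B) of the reference output closest to the response.
    diffs = [float(np.abs(server_out - r).max()) for r in reference]
    return int(np.argmin(diffs)), diffs

Calling closest_case(out) after each response would be expected to return the same case that was sent; given the behavior described above, with CUDA graphs enabled the first few responses of each round would instead match the other case.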

Expected behavior
The server should output consistent and correct results.
