Pytorch->Onnx->Tensorrt 8.5.2.2 conversion #200

Closed
aaronrmm opened this issue Feb 1, 2024 · 5 comments

@aaronrmm

aaronrmm commented Feb 1, 2024

The issue

I am trying to run the model with tensorrt==8.5.2.2 because that is the highest TensorRT version I can install on an Nvidia Jetson Orin.
Conversion to ONNX works fine. Conversion to TensorRT fails with:

[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:731: ERROR: builtin_op_importers.cpp:5427 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?" 

To Reproduce
cuda==11.8
tensorrt==8.5.2.2
torch==2.1.0a0+41361538.nv23.6
onnx==1.15.0

The conversion code:

import subprocess

import torch
import torch.nn as nn

# Note: `rtdetr` and `conversion_config` come from surrounding code that is not shown here.
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        rtdetr.model.training = False
        self.model = rtdetr.model.eval()

    def forward(self, images, orig_target_sizes):
        outputs = self.model(images, orig_target_sizes)
        return outputs

model = Model().eval()

# as per https://github.com/lyuwenyu/RT-DETR/blob/3330eca679a7d7cce16bbb10509099174a2f40bf/rtdetr_pytorch/tools/export_onnx.py#L48C21-L48C21
dynamic_axes = {
    "images": {0: "N"},
    "orig_target_sizes": {0: "N"},
}

# Example inputs used for tracing during export
image_tensor = torch.rand(4, 3, 640, 640).cuda()
orig_target_sizes = torch.tensor([[640, 640]], dtype=torch.int32).cuda()

torch.onnx.export(
    model,
    (image_tensor, orig_target_sizes),
    "temp/rtdetr.onnx",
    input_names=["images", "orig_target_sizes"],
    output_names=["labels", "boxes", "scores", "features"],
    dynamic_axes=dynamic_axes,
    opset_version=17,
    verbose=False,
)

# Build the TensorRT engine with trtexec, using a dynamic batch dimension (1 to 4)
command = (
    "/usr/src/tensorrt/bin/trtexec "
    f"--onnx={conversion_config.temp_dir/'rtdetr.onnx'}  "
    "--workspace=16096 "
    "--minShapes=images:1x3x640x640,orig_target_sizes:1x2 "
    "--optShapes=images:4x3x640x640,orig_target_sizes:4x2 "
    "--maxShapes=images:4x3x640x640,orig_target_sizes:4x2 "
    "--saveEngine=output/model.trt"
)
result = subprocess.run(
    command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
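
One thing stands out in the trtexec invocation echoed in the full error log for this report: it ends in `--saveEngine=/tests/output/model.trt--avgRuns=10--fp16`, and the "Save engine" line shows the same fused string, so the extra flags appear to have been appended without separating spaces. A minimal sketch of building the invocation as an argument list, which sidesteps that class of mistake. `build_trtexec_cmd` is a hypothetical helper, not part of the original code, and the paths are placeholders:

```python
import shlex


def build_trtexec_cmd(onnx_path, engine_path, extra_flags=()):
    """Return the trtexec invocation as a list with one flag per element.

    Hypothetical helper for illustration; it only assembles strings and
    does not validate that the flags exist in your trtexec version.
    """
    cmd = [
        "/usr/src/tensorrt/bin/trtexec",
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
    ]
    # Each extra flag stays its own list element, so flags can never
    # be glued onto the previous argument.
    cmd.extend(extra_flags)
    return cmd


cmd = build_trtexec_cmd(
    "temp/rtdetr.onnx",        # placeholder path
    "output/model.trt",        # placeholder path
    extra_flags=["--avgRuns=10", "--fp16"],
)

# Human-readable form for logging; the list itself is what gets executed.
print(shlex.join(cmd))
```

The list can be passed straight to `subprocess.run(cmd, capture_output=True, text=True)` without `shell=True`, so no manual quoting or spacing is needed.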

Full error
/usr/src/tensorrt/bin/trtexec --onnx=/temp/onnx/rtdetr.onnx --workspace=16096 --minShapes=images:1x3x640x640,orig_target_sizes:1x2 --optShapes=images:4x3x640x640,orig_target_sizes:4x2 --maxShapes=images:4x3x640x640,orig_target_sizes:4x2 --saveEngine=/tests/output/model.trt--avgRuns=10--fp16
[02/01/2024-20:28:33] [I] === Model Options ===
[02/01/2024-20:28:33] [I] Format: ONNX
[02/01/2024-20:28:33] [I] Model: /temp/onnx/rtdetr.onnx
[02/01/2024-20:28:33] [I] Output:
[02/01/2024-20:28:33] [I] === Build Options ===
[02/01/2024-20:28:33] [I] Max batch: explicit batch
[02/01/2024-20:28:33] [I] Memory Pools: workspace: 16096 MiB, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
[02/01/2024-20:28:33] [I] minTiming: 1
[02/01/2024-20:28:33] [I] avgTiming: 8
[02/01/2024-20:28:33] [I] Precision: FP32
[02/01/2024-20:28:33] [I] LayerPrecisions:
[02/01/2024-20:28:33] [I] Calibration:
[02/01/2024-20:28:33] [I] Refit: Disabled
[02/01/2024-20:28:33] [I] Sparsity: Disabled
[02/01/2024-20:28:33] [I] Safe mode: Disabled
[02/01/2024-20:28:33] [I] DirectIO mode: Disabled
[02/01/2024-20:28:33] [I] Restricted mode: Disabled
[02/01/2024-20:28:33] [I] Build only: Disabled
[02/01/2024-20:28:33] [I] Save engine: /tests/output/model.trt--avgRuns=10--fp16
[02/01/2024-20:28:33] [I] Load engine:
[02/01/2024-20:28:33] [I] Profiling verbosity: 0
[02/01/2024-20:28:33] [I] Tactic sources: Using default tactic sources
[02/01/2024-20:28:33] [I] timingCacheMode: local
[02/01/2024-20:28:33] [I] timingCacheFile:
[02/01/2024-20:28:33] [I] Heuristic: Disabled
[02/01/2024-20:28:33] [I] Preview Features: Use default preview flags.
[02/01/2024-20:28:33] [I] Input(s)s format: fp32:CHW
[02/01/2024-20:28:33] [I] Output(s)s format: fp32:CHW
[02/01/2024-20:28:33] [I] Input build shape: images=1x3x640x640+4x3x640x640+4x3x640x640
[02/01/2024-20:28:33] [I] Input build shape: orig_target_sizes=1x2+4x2+4x2
[02/01/2024-20:28:33] [I] Input calibration shapes: model
[02/01/2024-20:28:33] [I] === System Options ===
[02/01/2024-20:28:33] [I] Device: 0
[02/01/2024-20:28:33] [I] DLACore:
[02/01/2024-20:28:33] [I] Plugins:
[02/01/2024-20:28:33] [I] === Inference Options ===
[02/01/2024-20:28:33] [I] Batch: Explicit
[02/01/2024-20:28:33] [I] Input inference shape: orig_target_sizes=4x2
[02/01/2024-20:28:33] [I] Input inference shape: images=4x3x640x640
[02/01/2024-20:28:33] [I] Iterations: 10
[02/01/2024-20:28:33] [I] Duration: 3s (+ 200ms warm up)
[02/01/2024-20:28:33] [I] Sleep time: 0ms
[02/01/2024-20:28:33] [I] Idle time: 0ms
[02/01/2024-20:28:33] [I] Streams: 1
[02/01/2024-20:28:33] [I] ExposeDMA: Disabled
[02/01/2024-20:28:33] [I] Data transfers: Enabled
[02/01/2024-20:28:33] [I] Spin-wait: Disabled
[02/01/2024-20:28:33] [I] Multithreading: Disabled
[02/01/2024-20:28:33] [I] CUDA Graph: Disabled
[02/01/2024-20:28:33] [I] Separate profiling: Disabled
[02/01/2024-20:28:33] [I] Time Deserialize: Disabled
[02/01/2024-20:28:33] [I] Time Refit: Disabled
[02/01/2024-20:28:33] [I] NVTX verbosity: 0
[02/01/2024-20:28:33] [I] Persistent Cache Ratio: 0
[02/01/2024-20:28:33] [I] Inputs:
[02/01/2024-20:28:33] [I] === Reporting Options ===
[02/01/2024-20:28:33] [I] Verbose: Disabled
[02/01/2024-20:28:33] [I] Averages: 10 inferences
[02/01/2024-20:28:33] [I] Percentiles: 90,95,99
[02/01/2024-20:28:33] [I] Dump refittable layers:Disabled
[02/01/2024-20:28:33] [I] Dump output: Disabled
[02/01/2024-20:28:33] [I] Profile: Disabled
[02/01/2024-20:28:33] [I] Export timing to JSON file:
[02/01/2024-20:28:33] [I] Export output to JSON file:
[02/01/2024-20:28:33] [I] Export profile to JSON file:
[02/01/2024-20:28:33] [I]
[02/01/2024-20:28:33] [I] === Device Information ===
[02/01/2024-20:28:33] [I] Selected Device: Orin
[02/01/2024-20:28:33] [I] Compute Capability: 8.7
[02/01/2024-20:28:33] [I] SMs: 8
[02/01/2024-20:28:33] [I] Compute Clock Rate: 1.3 GHz
[02/01/2024-20:28:33] [I] Device Global Memory: 62800 MiB
[02/01/2024-20:28:33] [I] Shared Memory per SM: 164 KiB
[02/01/2024-20:28:33] [I] Memory Bus Width: 256 bits (ECC disabled)
[02/01/2024-20:28:33] [I] Memory Clock Rate: 0.612 GHz
[02/01/2024-20:28:33] [I]
[02/01/2024-20:28:33] [I] TensorRT version: 8.5.2
[02/01/2024-20:28:34] [I] [TRT] [MemUsageChange] Init CUDA: CPU +220, GPU +0, now: CPU 249, GPU 11291 (MiB)
[02/01/2024-20:28:37] [I] [TRT] [MemUsageChange] Init builder kernel library: CPU +302, GPU +431, now: CPU 574, GPU 11739 (MiB)
[02/01/2024-20:28:37] [I] Start parsing network model
[02/01/2024-20:28:37] [I] [TRT] ----------------------------------------------------------------
[02/01/2024-20:28:37] [I] [TRT] Input filename: /temp/onnx/rtdetr.onnx
[02/01/2024-20:28:37] [I] [TRT] ONNX IR version: 0.0.8
[02/01/2024-20:28:37] [I] [TRT] Opset version: 17
[02/01/2024-20:28:37] [I] [TRT] Producer name: pytorch
[02/01/2024-20:28:37] [I] [TRT] Producer version: 2.1.0
[02/01/2024-20:28:37] [I] [TRT] Domain:
[02/01/2024-20:28:37] [I] [TRT] Model version: 0
[02/01/2024-20:28:37] [I] [TRT] Doc string:
[02/01/2024-20:28:37] [I] [TRT] ----------------------------------------------------------------
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_2: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] /model/model/encoder/encoder.0/layers.0/self_attn/MatMul_1: broadcasting input1 to make tensors conform, dims(input0)=[400,-1,256][NONE] dims(input1)=[1,256,256][NONE].
[02/01/2024-20:28:37] [I] [TRT] No importer registered for op: LayerNormalization. Attempting to import as plugin.
[02/01/2024-20:28:37] [I] [TRT] Searching for plugin: LayerNormalization, plugin_version: 1, plugin_namespace:
[02/01/2024-20:28:37] [I] Finish parsing network model
&&&& FAILED TensorRT.trtexec [TensorRT v8502] # /usr/src/tensorrt/bin/trtexec --onnx=/temp/onnx/rtdetr.onnx --workspace=16096 --minShapes=images:1x3x640x640,orig_target_sizes:1x2 --optShapes=images:4x3x640x640,orig_target_sizes:4x2 --maxShapes=images:4x3x640x640,orig_target_sizes:4x2 --saveEngine=/tests/output/model.trt--avgRuns=10--fp16
Error:/libs/conversion/rtdetr_pytorch_tensorrt.py - [02/01/2024-20:28:33] [W] --workspace flag has been deprecated by --memPoolSize flag.
[02/01/2024-20:28:37] [W] [TRT] onnx2trt_utils.cpp:375: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:726: While parsing node number 171 [LayerNormalization -> "/model/model/encoder/encoder.0/layers.0/norm1/LayerNormalization_output_0"]:
[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:727: --- Begin node ---
[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:728: input: "/model/model/encoder/encoder.0/layers.0/Add_1_output_0"
input: "model.model.encoder.encoder.0.layers.0.norm1.weight"
input: "model.model.encoder.encoder.0.layers.0.norm1.bias"
output: "/model/model/encoder/encoder.0/layers.0/norm1/LayerNormalization_output_0"
name: "/model/model/encoder/encoder.0/layers.0/norm1/LayerNormalization"
op_type: "LayerNormalization"
attribute {
name: "axis"
i: -1
type: INT
}
attribute {
name: "epsilon"
f: 1e-05
type: FLOAT
}
[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:729: --- End node ---
[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:731: ERROR: builtin_op_importers.cpp:5427 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[02/01/2024-20:28:37] [E] Failed to parse onnx file
[02/01/2024-20:28:37] [E] Parsing model failed
[02/01/2024-20:28:37] [E] Failed to create engine from model or file.
[02/01/2024-20:28:37] [E] Engine set up failed
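
When a trtexec build fails like this, the op that the ONNX parser rejected can be pulled out of the log text instead of being read off by eye. A small sketch, assuming the log uses the same message wording as the output above; `unsupported_ops` and `failed_nodes` are hypothetical helper names:

```python
import re

# Patterns copied from the wording of the parser messages in the log above.
_NO_IMPORTER = re.compile(r"No importer registered for op: (\w+)")
_FAILED_NODE = re.compile(r'While parsing node number (\d+) \[(\w+) -> "([^"]+)"\]')


def unsupported_ops(log_text):
    """Return the sorted, de-duplicated op types the parser had no importer for."""
    return sorted(set(_NO_IMPORTER.findall(log_text)))


def failed_nodes(log_text):
    """Return (node_index, op_type, output_name) for each node that failed to parse."""
    return [(int(idx), op, out) for idx, op, out in _FAILED_NODE.findall(log_text)]


# Two lines taken from the log in this issue.
sample_log = (
    "[02/01/2024-20:28:37] [I] [TRT] No importer registered for op: "
    "LayerNormalization. Attempting to import as plugin.\n"
    "[02/01/2024-20:28:37] [E] [TRT] ModelImporter.cpp:726: While parsing node "
    'number 171 [LayerNormalization -> "/model/model/encoder/encoder.0/layers.0/'
    'norm1/LayerNormalization_output_0"]:\n'
)

print(unsupported_ops(sample_log))
print(failed_nodes(sample_log))
```

For this log it reports `LayerNormalization` at node 171, which is the op to focus on when deciding whether to change the opset or look for a plugin.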

@lyuwenyu
Owner

lyuwenyu commented Feb 2, 2024

dynamic_axes=dynamic_axes,
opset_version=17,

I have never encountered the "Plugin not found" problem for LayerNormalization before.

I think you can try a static input first, and try setting opset_version=16.
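
Some background on why the opset matters here, as far as I understand it: LayerNormalization only became a standalone ONNX operator in opset 17, and the log shows that the TensorRT 8.5.2 parser has no importer for it. With opset 16 the exporter has to express the same computation through elementary ops instead. The sketch below shows that arithmetic in plain Python for a single feature vector (axis -1, epsilon 1e-05, as in the failing node). The op names in the comments are my approximation of the decomposition, not a dump of the actual exported graph, and `layer_norm_decomposed` is a hypothetical name:

```python
import math


def layer_norm_decomposed(x, weight, bias, epsilon=1e-05):
    """Layer normalization over one vector, written as elementary steps."""
    mean = sum(x) / len(x)                              # ReduceMean
    centered = [v - mean for v in x]                    # Sub
    variance = sum(c * c for c in centered) / len(x)    # Pow + ReduceMean
    denom = math.sqrt(variance + epsilon)               # Add + Sqrt
    normalized = [c / denom for c in centered]          # Div
    # Scale and shift with the learned per-feature parameters.
    return [n * w + b for n, w, b in zip(normalized, weight, bias)]  # Mul + Add


x = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_decomposed(x, weight=[1.0] * 4, bias=[0.0] * 4)

# With unit weight and zero bias the result has roughly zero mean and unit variance.
out_mean = sum(out) / len(out)
out_var = sum((v - out_mean) ** 2 for v in out) / len(out)
print(out_mean, out_var)
```

Each of these steps maps to an operator that long predates opset 17, which would explain why the opset 16 export parses without needing a plugin.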

@aaronrmm
Author

Thanks so much! That did it: I set the opset to 16 and left out all the additional arguments, and that worked.

My working configuration for anyone else trying this:

FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3 # base docker image

torch==2.1.0a0+41361538.nv23.6
torchvision==0.15.2
onnx==1.14.0
tensorrt==8.5.2.2
numpy==1.23.1 # had to downgrade numpy for this version of tensorrt

command = (
    "/usr/src/tensorrt/bin/trtexec "
    "--onnx='rtdetr.onnx' "
    "--saveEngine='rtdetr.trt' "
    "--best"
)
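
A note for anyone scripting this step: the failed run earlier in this thread ends with a `&&&& FAILED TensorRT.trtexec` line, so the captured output can be checked for that marker rather than assuming the engine was written. A minimal sketch; the `&&&& PASSED` counterpart is my assumption about what a successful run prints, and `trtexec_passed` is a hypothetical helper:

```python
def trtexec_passed(output):
    """Return True only if the output has a PASSED marker and no FAILED marker.

    "&&&& FAILED" is taken from the log in this issue; "&&&& PASSED" is
    assumed to be the matching success marker.
    """
    lines = output.splitlines()
    passed = any(line.startswith("&&&& PASSED") for line in lines)
    failed = any(line.startswith("&&&& FAILED") for line in lines)
    return passed and not failed


# Shortened tail of the failing log from this issue.
failed_tail = (
    "[02/01/2024-20:28:37] [E] Engine set up failed\n"
    "&&&& FAILED TensorRT.trtexec [TensorRT v8502] # /usr/src/tensorrt/bin/trtexec\n"
)
print(trtexec_passed(failed_tail))
```

In practice `output` would be `result.stdout` from `subprocess.run(command, shell=True, capture_output=True, text=True)`, and checking `result.returncode` as well is a sensible second signal.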

@IamShubhamGupto

@aaronrmm Hello! I am facing a similar issue. How did you generate the ONNX file?

@aaronrmm
Author

@IamShubhamGupto
Try this and let me know if you have any issues. I simplified my code for loading the PyTorch model (the first two lines) in order to paste it here, but I haven't tested it.

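# modified from https://github.com/lyuwenyu/RT-DETR/blob/3330eca679a7d7cce16bbb10509099174a2f40bf/rtdetr_pytorch/tools/export_onnx.py#L48C21-L48C21

import torch
import torch.nn as nn

# Set model_weights_path to your checkpoint file; "cuda" matches the .cuda() inputs below.
rtdetr_model = torch.load(model_weights_path, map_location=torch.device("cuda"))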

class Model(nn.Module):
    def __init__(
        self,
    ) -> None:
        super().__init__()
        rtdetr_model.training = False
        self.model = rtdetr_model.eval()

    def forward(self, images, orig_target_sizes):
        outputs = self.model(images, orig_target_sizes)
        return outputs

model = Model().eval()

image_tensor = torch.rand(1, 3, 640, 640).cuda()
orig_target_sizes = torch.tensor([[640, 640]], dtype=torch.int32).cuda()

# modified to use static axes and opset 16
torch.onnx.export(
    model,
    (image_tensor, orig_target_sizes),
    "rtdetr.onnx",
    input_names=["images", "orig_target_sizes"],
    output_names=["labels", "boxes", "scores"],
    opset_version=16,
    verbose=False,
)

@Yuxinyi-Qiyu

Hey, can you show the full code? I ran into the same problem: In node 356 (importFallbackPluginImporter): UNSUPPORTED_NODE: Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?". I'd like to try running your code to see whether it works ☺️
