
Error importing onnx::Unsqueeze #180

Closed

jgarciacastano opened this issue May 23, 2019 · 19 comments

Labels: repro requested (Request more information about reproduction of issue), triaged (Issue has been triaged by maintainers)

Comments

@jgarciacastano

I'm trying to export a PyTorch model to TensorRT via ONNX.
The conversion from PyTorch to ONNX completes successfully. However, I get the following error when importing the ONNX model with onnx-tensorrt:

Traceback (most recent call last):
File "export_onnx_model_to_trt.py", line 9, in
engine = backend.prepare(model)#, device='CUDA:1', max_batch_size=1)
File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/onnx_tensorrt-0.1.0-py3.6-linux-x86_64.egg/onnx_tensorrt/backend.py", line 218, in prepare
return TensorRTBackendRep(model, device, **kwargs)
File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/onnx_tensorrt-0.1.0-py3.6-linux-x86_64.egg/onnx_tensorrt/backend.py", line 94, in init
raise RuntimeError(msg)
RuntimeError: While parsing node number 30:
builtin_op_importers.cpp:1987 In function importUnsqueeze:
[8] Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

The failure happens when importing an onnx::Unsqueeze node.

The layer in PyTorch is:

nn.BatchNorm1d(1000)

torch.onnx exports this layer as:

%60 : Tensor = onnx::Unsqueeze[axes=[2]], scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]
%61 : Tensor = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%60, %classifier.0.weight, %classifier.0.bias, %classifier.0.running_mean, %classifier.0.running_var), scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]
%62 : Float(1, 1000) = onnx::Squeeze[axes=[2]], scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]

and onnx.load generates the error.

I'm using PyTorch 1.1.0, ONNX 1.5.0, and TensorRT 5.1.
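
For reference, a minimal sketch of an export that reproduces this pattern (the model here is a hypothetical stand-in, not my actual network):

import torch
import torch.nn as nn

# A Linear followed by BatchNorm1d, mirroring the Sequential[classifier]/BatchNorm1d above
model = nn.Sequential(nn.Linear(512, 1000), nn.BatchNorm1d(1000))
model.eval()

dummy = torch.randn(1, 512)
# On PyTorch 1.1.0 the 2D BatchNorm1d input is wrapped in Unsqueeze/Squeeze nodes on export,
# which is where onnx-tensorrt later fails.
torch.onnx.export(model, dummy, "batchnorm1d_repro.onnx", verbose=True)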

Any help would be appreciated.

@laggui

laggui commented Aug 20, 2019

I'm getting the same error when trying to import my ONNX model (generated by exporting a version of resnet50 from PyTorch).

However, when using the onnx2trt executable, I am able to convert it to a TensorRT engine. How come?

@X-funbean

@laggui Hi, I've encountered the same problem, but it seems a little bit more complicated.

I couldn't convert the RenNet50_nFC.onnx model generated by PyTorch to .trt until I modified resnet.py in torchvision, where I changed x = x.reshape(x.size(0), -1) to x = x.reshape(1, -1); otherwise the following error occurs:

While parsing node number 175 [Gather -> "764"]:
ERROR: /home/xfb/Projects/ModelConvert/onnx-tensorrt/onnx2trt_utils.hpp:399 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims
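
For reference, the change in torchvision's resnet.py forward() was roughly the following (tensor shape is a hypothetical example):

import torch

x = torch.randn(1, 2048, 1, 1)   # hypothetical pooled ResNet50 features, batch size 1
# Original line -- the traced x.size(0) becomes a Gather on axis 0 in the exported graph,
# which trips the convert_axis assertion above:
#   x = x.reshape(x.size(0), -1)
# Modified line -- hard-code the batch to 1 so nothing in the graph reads the batch axis:
x = x.reshape(1, -1)             # shape (1, 2048)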

When I try to use onnx_tensorrt.backend, I hit the same problem as @jgarciacastano did:

RuntimeError: While parsing node number 175:
builtin_op_importers.cpp:1987 In function importUnsqueeze:
[8] Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

I've also followed the TensorRT Samples Support Guide and tried to import the .onnx model with OnnxParser. However, I got another error:

[TensorRT] ERROR: Network must have at least one output

and the engine was not built.

I wonder whether these errors are caused by unsupported operations like unsqueeze. If so, how can I work around it?

By the way, have you tried loading the TensorRT engine directly? I mean, since the '.trt' model is already generated, why not just import the engine instead of the '.onnx' model? I found some related functions like 'trt_network_to_trt_engine()' in TensorRT's legacy module, but they are deprecated and I haven't figured out how to use them. Any advice on that? Thanks a lot~~

@laggui

laggui commented Aug 21, 2019

@laggui Hi, I've encountered the same problem, but it seems a little bit more complicated.

I couldn't convert the RenNet50_nFC.onnx model generated by PyTorch to .trt until I modified resnet.py in torchvision, where I changed x = x.reshape(x.size(0), -1) to x = x.reshape(1, -1); otherwise the following error occurs:

While parsing node number 175 [Gather -> "764"]:
ERROR: /home/xfb/Projects/ModelConvert/onnx-tensorrt/onnx2trt_utils.hpp:399 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims

Yeah TensorRT doesn't support operations on the batch dim (axis 0), which is why you're seeing this error. A little bit annoying if you ask me since some operations converted from PyTorch end up manipulating the batch axis.
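
A sketch of the cleaner fix on the PyTorch side (this is what newer torchvision does, as noted later in this thread; the tensor shape here is just an example):

import torch

x = torch.randn(8, 512, 7, 7)   # e.g. pre-pooling ResNet features (hypothetical shape)
# Instead of x = x.reshape(x.size(0), -1), which traces x.size(0) into ops on axis 0,
# flatten everything after the batch dimension so the exported graph never touches the batch axis:
x = torch.flatten(x, 1)         # shape (8, 25088); batch dim untouched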

When I try to use onnx_tensorrt.backend, I hit the same problem as @jgarciacastano did:

RuntimeError: While parsing node number 175:
builtin_op_importers.cpp:1987 In function importUnsqueeze:
[8] Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

Do you know at which operation/layer in your model it's performing an unsqueeze?

I've also followed the TensorRT Samples Support Guide and tried to import the .onnx model with OnnxParser. However, I got another error:

[TensorRT] ERROR: Network must have at least one output

and the engine was not built.

I also tried the OnnxParser from TensorRT 5.1.5 initially, and I got the same error for the unsqueeze operation (which makes sense, since internally this repo detects your TensorRT version and uses trt.OnnxParser as well for versions >= 5).

[TensorRT] VERBOSE: 508:Flatten -> (4096)
[TensorRT] VERBOSE: /home/erisuser/p4sw/sw/gpgpu/MachineLearning/DIT/release/5.1/parsers/onnxOpenSource/builtin_op_importers.cpp:1981: Unsqueezing from (4096) to (4096, 35)
[ERROR] In node 176 (importUnsqueeze): UNSUPPORTED_NODE: Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

(and actually, on that note, it seems to me like Unsqueezing from (4096) to (4096, 35) is wrong. The unsqueeze should produce an output of (4096, 1).)
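
(For comparison, what the unsqueeze is supposed to do, in plain PyTorch terms:)

import torch

x = torch.randn(1, 4096)       # batch of 1, 4096 features
print(x.unsqueeze(2).shape)    # torch.Size([1, 4096, 1]) -- i.e. (4096, 1) once the batch dim is dropped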

In your case, I believe [TensorRT] ERROR: Network must have at least one output simply means the parser failed on the same node but kept going, and then couldn't find an output in the partially parsed network. If you reproduced the samples in trt and got this error, perhaps try this when parsing your ONNX model:

with open(onnx_model_file, 'rb') as model:
    print('Beginning ONNX file parsing')
    if not parser.parse(model.read()):
        err = parser.get_error(0)
        print(f'[ERROR] {err}')
        raise IOError('Failed to parse ONNX file')
print('Completed parsing of ONNX file')

The important part you're probably missing is catching the parser errors with parser.get_error; printing the error should give you the same failure you're seeing with onnx_tensorrt.backend.
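
Here's a rough sketch of the kind of setup I mean (assuming the implicit-batch API from TensorRT 5.x/6.x, i.e. builder.create_network() with no flags; the helper name is just for illustration):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def parse_onnx(onnx_model_file):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()           # implicit-batch network definition
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_model_file, 'rb') as model:
        parsed = parser.parse(model.read())
    if not parsed:
        # Print every parser error; otherwise you only see the misleading
        # "Network must have at least one output" later at build time.
        for i in range(parser.num_errors):
            print(f'[ERROR] {parser.get_error(i)}')
        raise IOError('Failed to parse ONNX file')
    return builder, network, parser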

I wonder whether these errors are caused by unsupported operations like unsqueeze. If so, how can I work around it?

By the way, have you tried loading the TensorRT engine directly? I mean, since the '.trt' model is already generated, why not just import the engine instead of the '.onnx' model? I found some related functions like 'trt_network_to_trt_engine()' in TensorRT's legacy module, but they are deprecated and I haven't figured out how to use them. Any advice on that? Thanks a lot~~

Yeah, that was the next step: I'm going to try loading the converted TensorRT engine today. (Edit: inference works with the trt engine.)

I would still like to know why the OnnxParser failed to convert my ONNX model at the unsqueeze operation, while the executable built from the C++ files did not. Was there a recent fix? I couldn't tell from a quick look at the latest commits.

@X-funbean

X-funbean commented Aug 22, 2019

Glad to hear that you've managed to get it working with trt, and thanks for your advice @laggui
I tried your code to parse the ONNX model, and the parser indeed failed:

[ERROR] In node 175 (importUnsqueeze): UNSUPPORTED_NODE: Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

I've realised that torchvision 0.4.0 has changed that reshape in resnet.py (to be exact, in the forward function of the ResNet class) to flatten, and I can successfully export and load a simple ResNet50 network. In my case, I believe the problem lies with BatchNorm1d as well.
I found a PyTorch issue, pytorch/pytorch#14946, which says that PyTorch >= 1.0 exports ONNX models with opset 9, while TensorRT >= 5.1 supports opset version 7. I tried both versions but got the same error (Failed to parse ONNX file). So the opset version may be an issue in general, but it doesn't explain my case.
To me it's a little weird that the trt engine works while the ONNX model that generated it fails to parse. Still, I'm really interested in how you managed to run inference with the .trt model. I haven't found examples of that. Could you share some code or function hints? I'd really appreciate it.

@laggui

laggui commented Aug 22, 2019

Glad to hear that you've managed to get it working with trt, and thanks for your advice @laggui
I tried your code to parse the ONNX model, and the parser indeed failed:

[ERROR] In node 175 (importUnsqueeze): UNSUPPORTED_NODE: Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)

Great, so at least the errors are the same.

I've realised that torchvision 0.4.0 has changed that reshape in resnet.py (to be exact, in the forward function of the ResNet class) to flatten, and I can successfully export and load a simple ResNet50 network. In my case, I believe the problem lies with BatchNorm1d as well.

In my case, the problem came from BatchNorm1D as well (and actually, it is still a problem when using the trt.OnnxParser or onnx_tensorrt.backend, which are pretty much the same, programmatically). The problem arises here specifically, with the unsqueeze operation that's added.

I only managed to convert my ONNX model to a trt engine by using the compiled onnx2trt tool from this repo, so I would still like to know why the results differ (perhaps there have been some recent changes to the C++ code that fixed this issue, while the underlying ONNX parser shipped with TensorRT doesn't have the fix yet?).

I found a PyTorch issue, pytorch/pytorch#14946, which says that PyTorch >= 1.0 exports ONNX models with opset 9, while TensorRT >= 5.1 supports opset version 7. I tried both versions but got the same error (Failed to parse ONNX file). So the opset version may be an issue in general, but it doesn't explain my case.

I don't think this is an opset version issue; in fact, the ONNX parser shipped with TensorRT 5.1.x supports ONNX IR (Intermediate Representation) version 0.0.3, opset version 9 (as stated here).

To me it's a little weird that the trt engine works while the ONNX model that generated it fails to parse. Still, I'm really interested in how you managed to run inference with the .trt model. I haven't found examples of that. Could you share some code or function hints? I'd really appreciate it.

To be clear, the generated ONNX model is fine; it just seems to have some incompatibility with the parser in trt. I'll post a code sample for inference tomorrow :)

Edit: here's a sample

# Read the serialized ICudaEngine
with open(trt_engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    # Deserialize ICudaEngine
    engine = runtime.deserialize_cuda_engine(f.read())
# Now just as with the onnx2trt samples...
# Create an IExecutionContext (context for executing inference)
with engine.create_execution_context() as context:
    # Allocate memory for inputs/outputs
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    # Set host input to the image
    inputs[0].host = image
    # Inference
    trt_outputs = infer(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    # Prediction
    pred_id = np.argmax(trt_outputs[-1])

And the function definitions for allocate_buffers and infer:

def allocate_buffers(engine):
    """
    Allocates all buffers required for the specified engine
    """
    inputs = []
    outputs = []
    bindings = []
    # Iterate over binding names in engine
    for binding in engine:
        # Get binding (tensor/buffer) size
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        # Get binding (tensor/buffer) data type (numpy-equivalent)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate page-locked memory (i.e., pinned memory) buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        # Allocate linear piece of device memory
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings
        bindings.append(int(device_mem))
        # Append to inputs/outputs list
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    # Create a stream (to eventually copy inputs/outputs and run inference)
    stream = cuda.Stream()
    return inputs, outputs, bindings, stream

def infer(context, bindings, inputs, outputs, stream, batch_size=1):
    """
    Infer outputs on the IExecutionContext for the specified inputs
    """
    # Transfer input data to the GPU
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return the host outputs
    return [out.host for out in outputs]
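
(These snippets assume the usual imports from the TensorRT Python samples plus a small HostDeviceMem helper; a minimal version of what they rely on looks something like this:)

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit   # creates and activates a CUDA context

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class HostDeviceMem:
    """Pairs a pinned host buffer with its device allocation (as in NVIDIA's TensorRT samples)."""
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem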

@DJMeng

DJMeng commented Nov 13, 2019

I met the same problem. If you normalize at the end of the network output, please try this:
x = x / x.pow(2).sum(dim=1, keepdim=True).sqrt()
It works for me.
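
(If it helps, the idea is to spell the normalization out as elementwise ops before exporting; the "before" line here is my assumption of what the original code looked like:)

import torch

x = torch.randn(4, 256)   # hypothetical embedding batch
# Before (assumed): an L2-normalization via norm()/F.normalize(), which can export
# reduction/expand patterns that older TensorRT parsers reject:
#   x = x / x.norm(p=2, dim=1, keepdim=True)
# After: the same math written out explicitly, which exports as simple elementwise ops:
x = x / x.pow(2).sum(dim=1, keepdim=True).sqrt()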

@ethanyhzhang

(quoting @laggui's full inference example above)

Hi, I used your example script to run inference on a batch of images, e.g. (5, 112, 112, 3), and set inputs[0].host = images. But only the first sample in the output is correct; the remaining outputs are zeros. Could you please help me with this issue?

@thancaocuong

Same question as @ethanyhzhang.

@laggui

laggui commented Jun 25, 2020

@ethanyhzhang @thancaocuong sorry for the late response.

The example script I provided was for a batch size of 1, so inputs[0].host = image is correct. I have not tried batch inference with a batch size larger than 1 as it was not required for my application at the time.

But it seems like @ethanyhzhang has managed to find a solution here: https://forums.developer.nvidia.com/t/batch-inference-wrong-in-python-api/112183/13
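
For anyone landing here later, a hedged, untested sketch of the kind of adjustment that solution describes (the shapes, the final reshape, and the assumption that the engine was built with max_batch_size >= N are all mine, not verified):

import numpy as np

N = 5                                                         # actual batch size for this call
images = np.random.rand(N, 3, 112, 112).astype(np.float32)    # hypothetical preprocessed batch
inputs, outputs, bindings, stream = allocate_buffers(engine)  # buffers are sized by engine.max_batch_size
# Copy the whole batch into the pinned host buffer, not just the first image:
np.copyto(inputs[0].host[:images.size], images.ravel())
# Tell the execution context how many samples are really in the batch:
trt_outputs = infer(context, bindings=bindings, inputs=inputs,
                    outputs=outputs, stream=stream, batch_size=N)
# Each output buffer is flat and sized for max_batch_size; keep only the first N rows:
preds = [out.reshape(engine.max_batch_size, -1)[:N] for out in trt_outputs]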

Sorry I don't have much time to look into this right now.

@daixiangzi

I met the same problem. If you normalize at the end of the network output, please try this:
x = x / x.pow(2).sum(dim=1, keepdim=True).sqrt()
It works for me.

Sure, I think so.

@kevinch-nv
Collaborator

Is anyone still having this issue with the latest TensorRT version (7.2)?

@kevinch-nv added the repro requested and triaged labels on Oct 25, 2020
@phamdat09

Is anyone still having this issue with the latest TensorRT version (7.2)?

I had the same issue :))

@kevinch-nv
Collaborator

@phamdat09 can you provide the model you are having troubles with?

@phamdat09

@phamdat09 can you provide the model you are having troubles with?

Yes! Here is my model: https://drive.google.com/file/d/1UUkasZSCbPvyq7-qATxHdliugK5i6aQi/view?usp=sharing
I exported from PyTorch to ONNX with opset=12, and then figured out that the problem comes from torch.nn.BatchNorm1d.

@hongsamvo

@phamdat09 I got an error with the BatchNorm1d layer too; did you find a solution?

@phamdat09

phamdat09 commented Nov 24, 2020

@phamdat09 I got an error with the BatchNorm1d layer too; did you find a solution?

Yes! With the latest commit, they have already fixed it.

Follow this issue: #566

@hongsamvo

@phamdat09 thank you!

@kevinch-nv
Collaborator

Closing this issue as it's now fixed.

@Drimz11

Drimz11 commented Feb 8, 2021

out_num = output.argmax(2)[-1].item()

Can you explain this code to me?
