Error importing onnx::Unsqueeze #180
I'm getting the same error when trying to import my ONNX model (generated by exporting a version of resnet50 from PyTorch). However, when using the …
@laggui Hi, I've encountered the same problem, but it seems a little bit more complicated. I couldn't convert my model; parsing fails with:

```
While parsing node number 175 [Gather -> "764"]:
ERROR: /home/xfb/Projects/ModelConvert/onnx-tensorrt/onnx2trt_utils.hpp:399 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims
```

When I try it from Python, I get:

```
RuntimeError: While parsing node number 175:
builtin_op_importers.cpp:1987 In function importUnsqueeze:
[8] Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)
```

I've also followed the sample code, but it reported `[TensorRT] ERROR: Network must have at least one output` and the engine wasn't built. I wonder whether these errors are caused by unsupported operations like `Gather` and `Unsqueeze`.

By the way, have you tried to load the TensorRT engine directly? I mean, since the '.trt' model is generated, why not just import the engine instead of the '.onnx' model? I found some related functions like 'trt_network_to_trt_engine()' in the legacy module of TensorRT, but they are deprecated, and I haven't figured out how to use them. Any advice on that? Thanks a lot~~
Yeah, TensorRT doesn't support operations on the batch dim (axis 0), which is why you're seeing this error. A little bit annoying if you ask me, since some operations converted from PyTorch end up manipulating the batch axis.

Do you know at which operation/layer in your model this happens?
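For illustration, here is a minimal, hypothetical module (not the model from this thread) of the kind of PyTorch code that exports graph nodes operating on the batch axis; flattening with `.view(x.size(0), -1)` is a commonly reported source of `Gather`/`Unsqueeze` nodes on axis 0:

```python
import torch
import torch.nn as nn

class Flattener(nn.Module):
    def forward(self, x):
        # Using x.size(0) makes the exported graph compute the batch size at
        # runtime, so ONNX gets Shape -> Gather(axis=0) -> Unsqueeze -> Concat -> Reshape,
        # i.e. nodes that touch the (implicit) batch axis.
        return x.view(x.size(0), -1)

torch.onnx.export(Flattener(), torch.randn(1, 3, 224, 224), "flatten.onnx")
```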
I also tried the OnnxParser from TensorRT 5.1.5 initially, and I got the same error for the `Unsqueeze` node (and actually, on that note, it seems to me like …). In your case, I believe something like this should work:

```python
with open(onnx_model_file, 'rb') as model:
    print('Beginning ONNX file parsing')
    if not parser.parse(model.read()):
        err = parser.get_error(0)
        print(f'[ERROR] {err}')
        raise IOError('Failed to parse ONNX file')
    print('Completed parsing of ONNX file')
```

The important part you're probably missing is catching the parser errors with `parser.get_error()`.

I would still like to know why the `Unsqueeze` import fails, though.
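As a side note (not part of the original comment), the parser records every error it hits, so a slightly fuller version of the snippet above could report all of them; a sketch with a placeholder model path, using the same TensorRT Python API pattern as the official samples:

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Build a network, parse the ONNX file, and print every recorded parser error.
with trt.Builder(TRT_LOGGER) as builder, \
        builder.create_network() as network, \
        trt.OnnxParser(network, TRT_LOGGER) as parser:
    with open('model.onnx', 'rb') as model:  # placeholder path
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            for i in range(parser.num_errors):
                print(f'[ERROR] {parser.get_error(i)}')
            raise IOError('Failed to parse ONNX file')
        print('Completed parsing of ONNX file')
```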
Glad to hear that you've managed to get it working with trt, and thanks for your advice @laggui. Catching the parser error now gives:

```
[ERROR] In node 175 (importUnsqueeze): UNSUPPORTED_NODE: Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)
```

I've realised that torchvision 0.4.0 has changed the …
Great, so at least the errors are the same.

In my case, the problem came from …. I only managed to convert my ONNX model to a trt engine by using the compiled `onnx2trt` binary. I don't think this is an opset version issue.
To be clear, the generated ONNX model is good; it just seems to have some incompatibility with the parser in trt.

Edit: here's a sample for loading the serialized engine directly:

```python
# Read the serialized ICudaEngine
with open(trt_engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    # Deserialize ICudaEngine
    engine = runtime.deserialize_cuda_engine(f.read())
    # Now just as with the onnx2trt samples...
    # Create an IExecutionContext (context for executing inference)
    with engine.create_execution_context() as context:
        # Allocate memory for inputs/outputs
        inputs, outputs, bindings, stream = allocate_buffers(engine)
        # Set host input to the image
        inputs[0].host = image
        # Inference
        trt_outputs = infer(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
        # Prediction
        pred_id = np.argmax(trt_outputs[-1])
```

And the function definitions for `allocate_buffers` and `infer`:

```python
def allocate_buffers(engine):
    """
    Allocates all buffers required for the specified engine
    """
    inputs = []
    outputs = []
    bindings = []
    # Iterate over binding names in the engine
    for binding in engine:
        # Get binding (tensor/buffer) size
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        # Get binding (tensor/buffer) data type (numpy-equivalent)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate page-locked (i.e., pinned) host memory buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        # Allocate a linear piece of device memory
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to the device bindings
        bindings.append(int(device_mem))
        # Append to the inputs/outputs list
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    # Create a stream (to copy inputs/outputs and run inference)
    stream = cuda.Stream()
    return inputs, outputs, bindings, stream


def infer(context, bindings, inputs, outputs, stream, batch_size=1):
    """
    Infer outputs on the IExecutionContext for the specified inputs
    """
    # Transfer input data to the GPU
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return the host outputs
    return [out.host for out in outputs]
```
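For completeness, the snippets above rely on a few things they don't define; a minimal sketch of the assumed imports and the small `HostDeviceMem` helper, as found in the standard TensorRT Python samples:

```python
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- initializes a CUDA context on import

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class HostDeviceMem:
    """Pairs a page-locked host buffer with its device counterpart."""
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem
```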
I met the same problem. If you normalize at the end of the net's output, please try this code: …
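The snippet that went with this comment didn't survive in this thread; as a rough guess at the intent (moving the normalization out of the network and doing it on the raw engine output instead), it might look like this hypothetical helper:

```python
import numpy as np

def l2_normalize(embedding, eps=1e-10):
    # Hypothetical post-processing: L2-normalize the raw engine output on the
    # host instead of relying on a normalization layer inside the exported net.
    embedding = np.asarray(embedding, dtype=np.float32)
    return embedding / (np.linalg.norm(embedding) + eps)
```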
Hi, when I use your example script to run inference on a batch of images, e.g. (5, 112, 112, 3), and set `inputs[0].host = images`, only the first sample in the output is correct, with the rest being zeros. Could you please help me with this issue?
Same question as @ethanyhzhang.
@ethanyhzhang @thancaocuong sorry for the late response. The example script I provided was for a batch size of 1, so it won't handle a larger batch as-is. But it seems like @ethanyhzhang has managed to find a solution here: https://forums.developer.nvidia.com/t/batch-inference-wrong-in-python-api/112183/13. Sorry, I don't have much time to look into this right now.
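For what it's worth, a rough sketch of how the batch-size-1 helpers above might be adapted to a batch of N, reusing the names from the earlier snippet and assuming the engine was built with `max_batch_size >= N` (this is only a guess, not the exact fix from the linked forum thread):

```python
import numpy as np

batch = np.random.random((5, 112, 112, 3)).astype(np.float32)  # placeholder batch
n = batch.shape[0]

# allocate_buffers() sizes host buffers for engine.max_batch_size, so the whole
# batch fits in the flat, page-locked input buffer.
inputs[0].host[:batch.size] = batch.ravel()

# Pass the actual batch size so execute_async() runs all n samples.
trt_outputs = infer(context, bindings=bindings, inputs=inputs,
                    outputs=outputs, stream=stream, batch_size=n)

# Each output comes back flat; view it per sample and keep the first n rows.
preds = trt_outputs[-1].reshape(engine.max_batch_size, -1)[:n]
pred_ids = np.argmax(preds, axis=1)
```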
Sure, I think so.
Is anyone still having this issue with the latest TensorRT version (7.2)?

I had the same issue :))

@phamdat09 can you provide the model you are having trouble with?

Yes! Here is my model: https://drive.google.com/file/d/1UUkasZSCbPvyq7-qATxHdliugK5i6aQi/view?usp=sharing

@phamdat09 I got an error with the BatchNorm1d layer too, did you find a solution?

Yes! With the latest commit they already fixed it. Follow this issue: #566

@phamdat09 thank you!

Closing this issue as it's now fixed.

`out_num = output.argmax(2)[-1].item()` Can you explain this code to me?
I'm trying to export a PyTorch model to TensorRT using ONNX. Conversion from PyTorch to ONNX format completes successfully. However, I'm getting the following error when importing the ONNX model using onnx-tensorrt:
```
Traceback (most recent call last):
  File "export_onnx_model_to_trt.py", line 9, in <module>
    engine = backend.prepare(model)#, device='CUDA:1', max_batch_size=1)
  File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/onnx_tensorrt-0.1.0-py3.6-linux-x86_64.egg/onnx_tensorrt/backend.py", line 218, in prepare
    return TensorRTBackendRep(model, device, **kwargs)
  File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/onnx_tensorrt-0.1.0-py3.6-linux-x86_64.egg/onnx_tensorrt/backend.py", line 94, in __init__
    raise RuntimeError(msg)
RuntimeError: While parsing node number 30:
builtin_op_importers.cpp:1987 In function importUnsqueeze:
[8] Assertion failed: get_shape_size(layer->getOutput(0)->getDimensions()) == get_shape_size(old_shape)
```
It happens when trying to import an onnx::Unsqueeze layer. The layer in PyTorch is the `BatchNorm1d` inside the classifier `Sequential`. torch.onnx exports this layer as:

```
%60 : Tensor = onnx::Unsqueeze[axes=[2]], scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]
%61 : Tensor = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%60, %classifier.0.weight, %classifier.0.bias, %classifier.0.running_mean, %classifier.0.running_var), scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]
%62 : Float(1, 1000) = onnx::Squeeze[axes=[2]], scope: ViolenceModel/Sequential[classifier]/BatchNorm1d[0]
```

and loading this ONNX model generates the error above.
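For reference, a minimal stand-in (a hypothetical module, not the actual ViolenceModel) that reproduces the export pattern shown above:

```python
import torch
import torch.nn as nn

# Per the trace above, a 2D (N, C) input through BatchNorm1d is exported as
# Unsqueeze(axes=[2]) -> BatchNormalization -> Squeeze(axes=[2]).
classifier = nn.Sequential(nn.BatchNorm1d(1000))
classifier.eval()  # eval mode so batch statistics aren't required for N=1
dummy = torch.randn(1, 1000)
torch.onnx.export(classifier, dummy, "classifier.onnx", verbose=True)
```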
I'm using PyTorch 1.1.0, ONNX 1.5.0 and TensorRT 5.1. Any help?