Convnext model fails to lower on QNN #14049

@GregoryComer

🐛 Describe the bug

The convnext_small model from torchvision fails to lower on the QNN backend. Lowering fails with the error "number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4". The error is raised from inside torch, but the model lowers and runs without delegation, so the problem appears to be in a transformation made during QNN lowering. The same error message shows up in several other models, so it would be good to resolve it.
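For readers unfamiliar with the message, here is a minimal sketch of the rank check behind it, in plain Python. `permuted_shape` is a hypothetical helper written for illustration, not a torch or ExecuTorch API, and the shapes in the usage note are illustrative rather than taken from the failing graph.

```python
from typing import Sequence, Tuple


def permuted_shape(shape: Sequence[int], dims: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape obtained by reordering `shape` according to `dims`.

    Hypothetical helper that mimics the rank check a permute performs.
    """
    rank = len(shape)
    # A permutation must name every axis exactly once, so the number of
    # entries in `dims` has to match the rank of the input.
    if rank != len(dims):
        raise RuntimeError(
            f"input.dim() = {rank} is not equal to len(dims) = {len(dims)}"
        )
    # Normalize negative axes, then check that each axis appears exactly once.
    if sorted(d % rank for d in dims) != list(range(rank)):
        raise RuntimeError(f"dims {tuple(dims)} is not a permutation of the axes")
    return tuple(shape[d] for d in dims)
```

With a rank-4 input, `permuted_shape((1, 56, 56, 96), (0, 3, 1, 2))` returns `(1, 96, 56, 56)`. Passing a rank-3 shape such as `(1, 3136, 96)` with the same four-element `dims` raises the mismatch reported here, which is what would happen if a pass changed a tensor's rank without updating the permute that consumes it.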

Output excerpt:

[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
[QNN Partitioner Op Support]: aten.linear.default | True
[QNN Partitioner Op Support]: aten.view_copy.default | True
[QNN Partitioner Op Support]: aten.permute_copy.default | True
[QNN Partitioner Op Support]: aten.native_layer_norm.default | True
[QNN Partitioner Op Support]: aten.permute_copy.default | True
[QNN Partitioner Op Support]: aten.adaptive_avg_pool2d.default | True
[QNN Partitioner Op Support]: aten.add.Tensor | True
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755] failed while attempting to run meta for aten.permute.default
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755] Traceback (most recent call last):
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]   File "/home/gregory/miniconda3/envs/executorch/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2751, in _dispatch_impl
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]     r = func(*args, **kwargs)
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]         ^^^^^^^^^^^^^^^^^^^^^
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]   File "/home/gregory/miniconda3/envs/executorch/lib/python3.12/site-packages/torch/_ops.py", line 840, in __call__
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]     return self._op(*args, **kwargs)
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755]            ^^^^^^^^^^^^^^^^^^^^^^^^^
E0907 21:42:05.974000 30617 site-packages/torch/_subclasses/fake_tensor.py:2755] RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4

This can be reproduced with the following test case command or standalone script.

python -m executorch.backends.test.suite.runner models --flow qnn --filter "test_convnext_small_qnn_float32$"

Standalone repro:

from typing import Tuple

import executorch
import torch
import torchvision

from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)

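# Example input: a single 224x224 RGB image.
inputs = (torch.randn(1, 3, 224, 224),)
model = torchvision.models.convnext_small().eval()

# Plain torch.export without delegation; `ep` is not used by the lowering call below.
ep = torch.export.export(model, inputs)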

backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)

compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,
    backend_options=backend_options,
)

model = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec
).to_executorch()

Note that running the backend test case requires executorch's python bindings to be built with the QNN backend. An example build command is below. The library paths still need to be set up properly, as described in the ET QNN docs.

CMAKE_ARGS="-DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT" ./install_executorch.sh --editable

Versions

Commit fbda3a9, x86-64 simulator, WSL

cc @cccclai @winskuo-quic @shewu-quic @cbilgin

Labels

backend tester (This bug was found by the backend test suite.)
module: qnn (Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/)
partner: qualcomm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Qualcomm)
