Commit 44353d3

Merge branch 'main' into exynos-quantize-support
2 parents eaed9db + 29b4db8 commit 44353d3

11 files changed, +177 −93 lines changed

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
```diff
@@ -553,3 +553,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_linear_out
+
+- func: cadence::quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_conv_out
```

backends/cadence/aot/ops_registrations.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -571,6 +571,12 @@
     "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_conv(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
@@ -2589,3 +2595,32 @@ def quantized_w8a32_linear_meta(
     assert src_shape[-1] == weight_shape[-1]
     src_shape[-1] = weight_shape[0]
     return src.new_empty(src_shape, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_conv")
+def quantized_w8a32_conv_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [batch, in_channel, in_length]
+    # weight comes in shape [out_ch, in_ch, kernel_dim]
+    # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
+    assert len(src.shape) == 3
+
+    kernel_size, out_channels, in_channels = weight.shape
+    assert in_channels == src.shape[-1]
+
+    # Compute the output tensor size
+    output_size = get_conv1d_output_size(
+        src.permute(0, 2, 1).shape,
+        out_channels,
+        stride=1,
+        padding=0,
+        dilation=1,
+        kernel_size=kernel_size,
+        channel_last=False,
+    )
+    return src.new_empty(output_size, dtype=src.dtype)
```
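The meta kernel above only computes an output shape. As a sanity check outside the commit, here is a minimal Python sketch of the same stride-1, zero-padding, no-dilation conv1d shape math (out_length = in_length − kernel_size + 1), compared against `torch.nn.functional.conv1d` on plain float tensors; the sizes are illustrative.

```python
import torch
import torch.nn.functional as F

def conv1d_out_shape(batch: int, out_ch: int, in_len: int, kernel: int) -> tuple:
    # stride=1, padding=0, dilation=1  =>  L_out = L_in - K + 1
    return (batch, out_ch, in_len - kernel + 1)

batch, in_ch, out_ch, in_len, kernel = 2, 4, 8, 16, 3
x = torch.randn(batch, in_ch, in_len)     # activations in NCL layout
w = torch.randn(out_ch, in_ch, kernel)    # conv1d weight: [out_ch, in_ch, K]
b = torch.randn(out_ch)

y = F.conv1d(x, w, b, stride=1, padding=0, dilation=1)
assert tuple(y.shape) == conv1d_out_shape(batch, out_ch, in_len, kernel)
print(tuple(y.shape))  # (2, 8, 14)
```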

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 57 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
@@ -478,6 +479,52 @@ def get_args_and_kwargs_softmax(
         out_zero_point_tensor,
     )
     kwargs = {}
+
+    return args, kwargs
+
+
+def get_args_and_kwargs_mixed_w8a32_conv(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Stride, padding, dilation, groups not supported yet
+    if len(op_node.args) > 3:
+        assert op_node.args[3] == [1]  # Stride
+    if len(op_node.args) > 4:
+        assert op_node.args[4] == [0]  # Padding
+    if len(op_node.args) > 5:
+        assert op_node.args[5] == [1]  # Dilation
+    if len(op_node.args) > 6:
+        assert op_node.args[6] == 1  # Groups
+
+    assert len(dequants_weights) == 1
+    assert len(dequants_biases) == 1
+    W_scale_ = dequants_weights[0].args[1]
+    B_scale_ = dequants_biases[0].args[1]
+
+    transposed_inputs = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
+    )
+    transposed_weights = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (weights_inputs[0], [2, 0, 1]),  # NCL -> NLC
+    )
+
+    args = (
+        transposed_inputs,
+        transposed_weights,
+        W_scale_,
+        bias_inputs[0],
+        B_scale_,
+    )
+    kwargs = {}
+
     return args, kwargs
 
 
@@ -650,6 +697,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 bias_inputs,
                 dequants_biases,
             )
+        elif isinstance(pattern, MixedW8A32ConvPattern):
+            args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
+                graph_module,
+                other_inputs,
+                weights_inputs,
+                dequants_weights,
+                bias_inputs,
+                dequants_biases,
+                op_node,
+            )
 
         fused = graph_module.graph.call_function(
             pattern.replacement_op(),
```
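For readers following the layout changes: a small sketch (not part of the commit) of the two permutes the fusion helper inserts — activations from NCL to NLC via `permute(0, 2, 1)`, and the conv1d weight from `[out_ch, in_ch, K]` to `[K, out_ch, in_ch]` via `permute(2, 0, 1)` — which lines up with the `kernel_size, out_channels, in_channels = weight.shape` unpacking in the meta kernel registered above. Tensor sizes are illustrative.

```python
import torch

batch, in_ch, out_ch, length, kernel = 2, 4, 8, 16, 3

x = torch.randn(batch, in_ch, length)   # conv1d activation, NCL
w = torch.randn(out_ch, in_ch, kernel)  # conv1d weight, [out_ch, in_ch, K]

x_nlc = x.permute(0, 2, 1)              # -> [batch, length, in_ch]
w_koi = w.permute(2, 0, 1)              # -> [K, out_ch, in_ch]

assert tuple(x_nlc.shape) == (batch, length, in_ch)
assert tuple(w_koi.shape) == (kernel, out_ch, in_ch)
# Matches the meta kernel's `in_channels == src.shape[-1]` check in this layout:
assert w_koi.shape[-1] == x_nlc.shape[-1]
```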

backends/cadence/aot/quantizer/patterns.py

Lines changed: 62 additions & 0 deletions
```diff
@@ -599,3 +599,65 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_linear.default
+
+
+class MixedW8A32ConvPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        conv_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the arguments have different shapes than expected
+        # Stride, padding, dilation and groups are not supported
+        if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                conv_layer,
+            )
+
+        cnn_weights = conv_layer.args[1]
+        if hasattr(cnn_weights.meta, "tensor_meta"):
+            cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
+            # Bail if the channels are not multiple of 4 (SIMD)
+            if cnn_weights_shape[0] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            if cnn_weights_shape[1] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            # Bail if the kernel size is not 3
+            if cnn_weights_shape[2] != 3:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                weights=[(conv_layer, 1)],
+                biases=[(conv_layer, 2)],
+                output=[],
+                others=[(conv_layer, 0)],
+            ),
+            conv_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_conv.default
```
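To summarize the gating above: the pattern only claims an `aten.conv1d` called as `(input, weight, bias)` with default stride/padding/dilation/groups, whose weight has out- and in-channels divisible by 4 (for SIMD) and a kernel size of exactly 3. A hedged, standalone restatement of the shape check (illustrative helper, not code from the commit):

```python
import torch

def is_w8a32_conv_candidate(weight: torch.Tensor) -> bool:
    """Illustrative mirror of the shape gating in MixedW8A32ConvPattern."""
    if weight.dim() != 3:
        return False
    out_ch, in_ch, kernel = weight.shape
    return out_ch % 4 == 0 and in_ch % 4 == 0 and kernel == 3

assert is_w8a32_conv_candidate(torch.empty(8, 4, 3))        # eligible
assert not is_w8a32_conv_candidate(torch.empty(6, 4, 3))    # out_ch not a multiple of 4
assert not is_w8a32_conv_candidate(torch.empty(8, 4, 5))    # kernel size != 3
```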

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
@@ -321,6 +322,9 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
+        )
         super().__init__(quantizers)
 
```

docs/source/backends-qualcomm.md

Lines changed: 4 additions & 85 deletions
````diff
@@ -74,10 +74,9 @@ This example is verified with SM8550 and SM8450.
 - A compiler to compile AOT parts, e.g., the GCC compiler comes with Ubuntu LTS.
 - [Android NDK](https://developer.android.com/ndk). This example is verified with NDK 26c.
 - [Qualcomm AI Engine Direct SDK](https://developer.qualcomm.com/software/qualcomm-ai-engine-direct-sdk)
-  - Click the "Get Software" button to download a version of QNN SDK.
-  - However, at the moment of updating this tutorial, the above website doesn't provide QNN SDK newer than 2.22.6.
-  - The below is public links to download various QNN versions. Hope they can be publicly discoverable soon.
-  - [QNN 2.37.0](https://softwarecenter.qualcomm.com/api/download/software/sdks/Qualcomm_AI_Runtime_Community/All/2.37.0.250724/v2.37.0.250724.zip)
+  - Click the "Get Software" button to download the latest version of the QNN SDK.
+  - Although newer versions are available, we have verified and recommend using QNN 2.37.0 for stability.
+  - You can download it directly from the following link: [QNN 2.37.0](https://softwarecenter.qualcomm.com/api/download/software/sdks/Qualcomm_AI_Runtime_Community/All/2.37.0.250724/v2.37.0.250724.zip)
 
 The directory with installed Qualcomm AI Engine Direct SDK looks like:
 ```
@@ -136,86 +135,6 @@ cd $EXECUTORCH_ROOT
 ./backends/qualcomm/scripts/build.sh --release
 ```
 
-### AOT (Ahead-of-time) components:
-
-Python APIs on x64 are required to compile models to Qualcomm AI Engine Direct binary.
-
-```bash
-cd $EXECUTORCH_ROOT
-mkdir build-x86
-cd build-x86
-# Note that the below command might change.
-# Please refer to the above build.sh for latest workable commands.
-cmake .. \
-  -DCMAKE_INSTALL_PREFIX=$PWD \
-  -DEXECUTORCH_BUILD_QNN=ON \
-  -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
-  -DEXECUTORCH_BUILD_DEVTOOLS=ON \
-  -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-  -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-  -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
-  -DPYTHON_EXECUTABLE=python3
-
-# nproc is used to detect the number of available CPU.
-# If it is not applicable, please feel free to use the number you want.
-cmake --build $PWD --target "PyQnnManagerAdaptor" "PyQnnWrapperAdaptor" -j$(nproc)
-
-# install Python APIs to correct import path
-# The filename might vary depending on your Python and host version.
-cp -f backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so $EXECUTORCH_ROOT/backends/qualcomm/python
-cp -f backends/qualcomm/PyQnnWrapperAdaptor.cpython-310-x86_64-linux-gnu.so $EXECUTORCH_ROOT/backends/qualcomm/python
-
-# Workaround for .fbs files in exir/_serialize
-cp $EXECUTORCH_ROOT/schema/program.fbs $EXECUTORCH_ROOT/exir/_serialize/program.fbs
-cp $EXECUTORCH_ROOT/schema/scalar_type.fbs $EXECUTORCH_ROOT/exir/_serialize/scalar_type.fbs
-```
-
-### Runtime:
-
-An example `qnn_executor_runner` executable would be used to run the compiled `pte` model.
-
-Commands to build `qnn_executor_runner` for Android:
-
-```bash
-cd $EXECUTORCH_ROOT
-mkdir build-android
-cd build-android
-# build executorch & qnn_executorch_backend
-cmake .. \
-  -DCMAKE_INSTALL_PREFIX=$PWD \
-  -DEXECUTORCH_BUILD_QNN=ON \
-  -DQNN_SDK_ROOT=$QNN_SDK_ROOT \
-  -DEXECUTORCH_BUILD_DEVTOOLS=ON \
-  -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-  -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-  -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
-  -DPYTHON_EXECUTABLE=python3 \
-  -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
-  -DANDROID_ABI='arm64-v8a' \
-  -DANDROID_PLATFORM=android-30
-
-# nproc is used to detect the number of available CPU.
-# If it is not applicable, please feel free to use the number you want.
-cmake --build $PWD --target install -j$(nproc)
-
-cmake ../examples/qualcomm \
-  -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
-  -DANDROID_ABI='arm64-v8a' \
-  -DANDROID_PLATFORM=android-30 \
-  -DCMAKE_PREFIX_PATH="$PWD/lib/cmake/ExecuTorch;$PWD/third-party/gflags;" \
-  -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
-  -DPYTHON_EXECUTABLE=python3 \
-  -Bexamples/qualcomm
-
-cmake --build examples/qualcomm -j$(nproc)
-
-# qnn_executor_runner can be found under examples/qualcomm
-# The full path is $EXECUTORCH_ROOT/build-android/examples/qualcomm/executor_runner/qnn_executor_runner
-ls examples/qualcomm
-```
-
-**Note:** If you want to build for release, add `-DCMAKE_BUILD_TYPE=Release` to the `cmake` command options.
-
 
 ## Deploying and running on device
 
@@ -365,7 +284,7 @@ The model, inputs, and output location are passed to `qnn_executorch_runner` by
 
 ## Supported model list
 
-Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` to the list of supported models.
+Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `$EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` to the list of supported models.
 
 ## How to Support a Custom Model in HTP Backend
 
````

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 3 additions & 3 deletions
```diff
@@ -586,12 +586,12 @@ class StaticAttentionIOManager {
    * of the prompt and method's input length. Returns the position in the output
    * that corresponds to the end of the prompt during the last inference.
    */
-  template <typename TokenT>
+  template <typename TokenT, typename LogitT>
   size_t prefill(
       executorch::runtime::Span<TokenT> tokens,
       executorch::runtime::Span<TokenT> input_buffer,
       executorch::runtime::Method& method,
-      std::function<void(executorch::runtime::Span<const float>)>
+      std::function<void(executorch::runtime::Span<const LogitT>)>
           logits_callback = nullptr) {
     ET_LOG(Info, "Prefilling at position %zu", input_pos_);
     size_t input_len = input_buffer.size();
@@ -619,7 +619,7 @@
         batch_len);
     if (logits_callback) {
       auto logits_tensor = method.get_output(0).toTensor();
-      auto* logits = logits_tensor.const_data_ptr<float>();
+      auto* logits = logits_tensor.const_data_ptr<LogitT>();
       logits_callback(executorch::runtime::Span(
           logits,
           logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
```

examples/models/llama/static_attention.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -259,7 +259,7 @@ def __init__(
         }
 
         rope = Rope(config)
-        freqs = rope.get_freqs(None, config.max_seq_len)
+        freqs = rope.get_freqs(None, config.max_context_len)
         self.freqs_cos = freqs[0].to(dtype)
         self.freqs_sin = freqs[1].to(dtype)
```

examples/qualcomm/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,13 @@ This section outlines the essential APIs and utilities provided to streamline th
111111
Creates a clean directory for storing model outputs or intermediate results. If the directory already exists, it will be deleted and recreated to ensure a consistent environment for each run.
112112

113113
## Additional Dependency
114+
This example requires the following Python packages:
115+
- pandas and scikit-learn: used in the mobilebert multi-class text classification example.
116+
- graphviz (optional): used for visualizing QNN graphs during debugging.
114117

115-
The mobilebert multi-class text classification example requires `pandas` and `sklearn`.
116118
Please install them by something like
117-
118119
```bash
119-
pip install scikit-learn pandas
120+
pip install scikit-learn pandas graphviz
120121
```
121122

122123
## Limitation

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -38,6 +38,7 @@ We offer the following modes to execute the model:
 ### Step 1: Setup
 1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
 2. Follow the [tutorial](https://pytorch.org/executorch/main/backends-qualcomm) to build Qualcomm AI Engine Direct Backend.
+3. Please install the llm eval dependency via [examples/models/llama/install_requirements.sh](https://github.com/pytorch/executorch/blob/main/examples/models/llama/install_requirements.sh)
 
 ### Step 2: Prepare Model
 
```
