Commit bb60a20

Merge branch 'main' into jz/add-phi4

2 parents ef717db + 7aa6494

25 files changed (+495 −87 lines)
.ci/scripts/test_ane_static_llama.sh (new file; path as invoked by the trunk.yml job below)

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -exu
+
+source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
+
+export EXECUTORCH_ROOT="$(dirname "${BASH_SOURCE[0]}")/../.."
+
+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+
+which "${PYTHON_EXECUTABLE}"
+
+pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
+
+# Download stories llama110m artifacts
+download_stories_model_artifacts
+
+python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
+
+popd
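
The export command is the heart of the new CI script. For local debugging, a minimal Python equivalent of that step (assumptions: the executorch repo root is the working directory, and stories110M.pt / params.json are already in place, as download_stories_model_artifacts would leave them):

import os
import subprocess

# Mirror the script: honor PYTHON_EXECUTABLE, falling back to python3.
python = os.environ.get("PYTHON_EXECUTABLE", "python3")

# Same flags as the export line in the diff above.
subprocess.run(
    [
        python, "export.py",
        "-n", "model.pte",        # output program name
        "-p", "params.json",      # model hyperparameters
        "-c", "stories110M.pt",   # checkpoint
        "--seq_length", "32",
        "--max_seq_length", "64",
        "--dtype", "fp16",
        "--coreml-quantize", "c4w",  # quantization mode used by the CI script
    ],
    cwd="examples/apple/coreml/llama",
    check=True,  # raise on failure, like set -e in the script
)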

.ci/scripts/test_model.sh

Lines changed: 8 additions & 1 deletion
@@ -172,6 +172,7 @@ test_model_with_qnn() {
   export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
   export PYTHONPATH=$EXECUTORCH_ROOT/..

+  EXTRA_FLAGS=""
   if [[ "${MODEL_NAME}" == "dl3" ]]; then
     EXPORT_SCRIPT=deeplab_v3
   elif [[ "${MODEL_NAME}" == "mv3" ]]; then
@@ -184,6 +185,12 @@ test_model_with_qnn() {
     EXPORT_SCRIPT=inception_v3
   elif [[ "${MODEL_NAME}" == "vit" ]]; then
     EXPORT_SCRIPT=torchvision_vit
+  elif [[ "${MODEL_NAME}" == "mb" ]]; then
+    EXPORT_SCRIPT=mobilebert_fine_tune
+    EXTRA_FLAGS="--num_epochs 1"
+    pip install scikit-learn
+  elif [[ "${MODEL_NAME}" == "w2l" ]]; then
+    EXPORT_SCRIPT=wav2letter
   elif [[ "${MODEL_NAME}" == "edsr" ]]; then
     EXPORT_SCRIPT=edsr
     # Additional deps for edsr
@@ -197,7 +204,7 @@ test_model_with_qnn() {
   # TODO(guangyang): Make QNN chipset matches the target device
   QNN_CHIPSET=SM8450

-  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only
+  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only $EXTRA_FLAGS
   EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
 }
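
Initializing EXTRA_FLAGS to an empty string before the case ladder keeps the command construction uniform: the export line can append $EXTRA_FLAGS unconditionally, and only the mb branch fills it. A rough Python analogue of the dispatch (the cmake-out directory is an assumption; only mappings visible in these hunks are included):

# Hypothetical mirror of the shell case ladder; mappings cut off by the
# hunk boundaries (mv3, mv2, ic4, ...) are omitted here.
def qnn_export_command(model_name: str, python: str = "python3") -> list[str]:
    scripts = {
        "dl3": "deeplab_v3",
        "ic3": "inception_v3",
        "vit": "torchvision_vit",
        "mb": "mobilebert_fine_tune",
        "w2l": "wav2letter",
        "edsr": "edsr",
    }
    # Only the mobilebert branch adds extra flags (and installs scikit-learn).
    extra_flags = ["--num_epochs", "1"] if model_name == "mb" else []
    return [
        python,
        "-m", f"examples.qualcomm.scripts.{scripts[model_name]}",
        "-b", "cmake-out",  # assumption for CMAKE_OUTPUT_DIR
        "-m", "SM8450",     # QNN_CHIPSET hard-coded in the script
        "--compile_only",
        *extra_flags,
    ]

print(" ".join(qnn_export_command("mb")))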

.github/workflows/trunk.yml

Lines changed: 23 additions & 1 deletion
@@ -229,6 +229,28 @@ jobs:
       # see if we can import the module successfully
       ${CONDA_RUN} python -c "from executorch.extension.pybindings import portable_lib; print('success!')"

+  test-static-llama-ane:
+    name: test-static-llama-ane
+    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+    with:
+      runner: macos-m1-stable
+      python-version: '3.11'
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+        bash .ci/scripts/setup-conda.sh
+        eval "$(conda shell.bash hook)"
+
+        # Install requirements
+        sh install_requirements.sh
+        sh backends/apple/coreml/scripts/install_requirements.sh
+        python install_executorch.py --pybind coreml
+        sh examples/models/llama/install_requirements.sh
+
+        # Test ANE llama
+        sh .ci/scripts/test_ane_static_llama.sh
+
   test-llama-runner-macos:
     name: test-llama-runner-mac
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
@@ -311,7 +333,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        model: [dl3, mv3, mv2, ic4, ic3, vit]
+        model: [dl3, mv3, mv2, ic4, ic3, vit, mb, w2l]
      fail-fast: false
     with:
       runner: linux.2xlarge

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -9,5 +9,6 @@ python_library(
        "//executorch/backends/transforms:replace_scalar_with_tensor",
        "//executorch/backends/xnnpack/_passes:xnnpack_passes",
        "//executorch/exir:lib",
+       "//executorch/backends/transforms:utils",
    ],
)

backends/arm/_passes/arm_pass_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 82 additions & 48 deletions
@@ -6,10 +6,15 @@
 # pyre-unsafe

 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()

-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into
         a parent convolution."""
         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True

+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
             )

             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]

             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")

             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )

-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
+            # Create fused weights and bias to conv and replace conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )
+
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
                     )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+                else:
+                    fused_conv_bias_node = None
+
+            conv.args = (
+                conv.args[0],
+                fused_conv_weight_node,
+                fused_conv_bias_node,
+                *conv.args[3:],
+            )

-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # pyre-ignore[60]
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
-
-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True

         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
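
The rewrite changes graph plumbing, not arithmetic. fuse_conv_bn_weights still folds the batchnorm statistics into the convolution, W' = W * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta; the new part is that the fused tensors are installed as freshly created constant placeholders instead of being written over the original parameters in the state_dict. A self-contained sketch that checks the folding numerically:

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

# A conv followed by a batchnorm in eval mode (fixed running statistics).
conv = torch.nn.Conv2d(3, 8, kernel_size=3)
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

# Fold the batchnorm into the conv weight/bias, as the pass does.
fused_weight, fused_bias = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

fused_conv = torch.nn.Conv2d(3, 8, kernel_size=3)
fused_conv.weight = fused_weight
fused_conv.bias = fused_bias

x = torch.randn(2, 3, 16, 16)
# A single fused conv reproduces conv followed by batchnorm.
assert torch.allclose(fused_conv(x), bn(conv(x)), atol=1e-4)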

backends/arm/test/models/test_w2l_arm.py

Lines changed: 0 additions & 1 deletion
@@ -131,7 +131,6 @@ def test_w2l_u55_BI(self):

     @pytest.mark.slow
     @pytest.mark.corstone_fvp
-    @unittest.skip("Blocked by MLBEDSW-10420")
     @conftest.expectedFailureOnFVP  # TODO: MLBEDSW-10093
     def test_w2l_u85_BI(self):
         tester = self._test_w2l_ethos_BI_pipeline(

backends/arm/test/passes/test_fuse_batchnorm_pass.py

Lines changed: 4 additions & 4 deletions
@@ -85,13 +85,13 @@ def forward(self, x):
         return x


-class MergeNoBN(torch.nn.Module):
+class MergeMultipleUsersBN(torch.nn.Module):
     ops_before_pass = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
     ops_after_pass = {
-        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }

@@ -122,7 +122,7 @@ def forward(self, x):
         z = self.conv2d2(x)
         a = self.batch_norm2d(
             y
-        )  # Can't be fused since parameters of conv2d2 have multiple users.
+        )  # Can be fused despite parameters of conv2d2 having multiple users.

         return z, a

@@ -131,7 +131,7 @@ def forward(self, x):
     "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
     "merge_one_of_two_bn": MergeOneOfTwoBN(False),
     "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
-    "merge_no_bn_affine": MergeNoBN(True),
+    "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
 }

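The renamed test documents why the multiple-users restriction could be dropped: the pass now materializes new "_fused_bn" placeholders for the fused convolution and leaves the original shared parameters in place for their remaining users, so one of the two batchnorms is now fused (hence ops_after_pass dropping from 2 to 1). A hypothetical module of the shape the test exercises (not the test's actual code) would be:

import torch

class SharedParamConvBN(torch.nn.Module):
    """Illustration only: conv2d and conv2d2 share one weight tensor, so the
    weight placeholder has multiple users in the exported graph."""

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.conv2d2 = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.conv2d2.weight = self.conv2d.weight  # shared -> multiple users
        self.batch_norm2d = torch.nn.BatchNorm2d(3).eval()

    def forward(self, x):
        y = self.batch_norm2d(self.conv2d(x))  # fuseable: gets fresh fused params
        z = self.conv2d2(x)  # keeps the original shared weight untouched
        return y, z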
backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 10 additions & 12 deletions
@@ -73,7 +73,7 @@
 from executorch.examples.models.mobilenet_v3 import MV3Model
 from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel

-# from executorch.examples.models.wav2letter import Wav2LetterModel
+from executorch.examples.models.wav2letter import Wav2LetterModel
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import disable_validation
 from executorch.exir.passes import PassManager
@@ -907,8 +907,7 @@ def test_qnn_backend_example_models(self):
             # Fail during lowering Reopen once resolved
             # MobileBertModelExample(),
             # TorchVisionViTModel(),
-            # Encountered undefined symbol in mainline. Reopen once resolved.
-            # Wav2LetterModel(),
+            Wav2LetterModel(),
         ]
         expected_partitions = [
             1,
@@ -917,8 +916,8 @@ def test_qnn_backend_example_models(self):
             1,
             1,
             1,
-            1,
-            1,
+            # 1,
+            # 1,
             1,
         ]
         # TODO: Due to trigger maximum recursion depth exceeded, need to check it.
@@ -1962,12 +1961,11 @@ def test_qnn_backend_example_models(self):
                 QCOM_ANNOTATION: (),
                 QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
             },
-            # Encountered undefined symbol in mainline. Reopen once resolved.
-            # {
-            #     QCOM_MODULE: Wav2LetterModel(),
-            #     QCOM_ANNOTATION: (),
-            #     QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
-            # },
+            {
+                QCOM_MODULE: Wav2LetterModel(),
+                QCOM_ANNOTATION: (),
+                QCOM_QUANT_DTYPE: QuantDtype.use_8a8w,
+            },
         ]
         expected_partitions = [
             1,
@@ -1979,7 +1977,7 @@ def test_qnn_backend_example_models(self):
             # For MobileBertModelExample
             # 1,
             1,
-            # 1, For Wav2LetterModel
+            1,
         ]
         # TODO: Due to trigger maximum recursion depth exceeded, need to check it.
         disable_validation()
