
Commit 4399aa8

Merge branch 'main' into node_visitors_dtype_validation
2 parents: 243bb6c + 9952aef


44 files changed (+2237 / −464 lines)

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ jobs:
   strategy:
     fail-fast: false
     matrix:
-      model: [linear, add, add_mul, resnet18, conv1d]
+      model: [linear, add, add_mul, resnet18, conv1d, sdpa]
 with:
   timeout: 90
   runner: linux.g5.4xlarge.nvidia.gpu

backends/aoti/aoti_partitioner.py

Lines changed: 17 additions & 3 deletions
@@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         partition_tags: Dict[str, DelegationSpec] = {}
         tag = "tag0"

+        # Tag torch.cond and other control flow operations
+        def is_control_flow(node: torch.fx.Node) -> bool:
+            return node.op == "call_function" and node.target in [
+                torch.ops.higher_order.cond,
+                torch.ops.higher_order.map_impl,
+                torch.ops.higher_order.while_loop,
+            ]
+
         for node in exported_program.graph.nodes:
-            if node.op != "call_function":
-                continue
-            node.meta["delegation_tag"] = tag
+            if node.op == "call_function":
+                node.meta["delegation_tag"] = tag
+            # Tag get_attr nodes that are used by control flow operations
+            elif node.op == "get_attr":
+                # Check if any user is a control flow operation
+                for user in node.users:
+                    if is_control_flow(user):
+                        node.meta["delegation_tag"] = tag
+                        break

         partition_tags[tag] = self.delegation_spec
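
For context, a minimal sketch (not part of the commit, assuming a recent PyTorch) of the kind of graph this change targets: exporting a module that uses torch.cond produces get_attr nodes holding the branch submodules, and those nodes previously went untagged. The toy module below is hypothetical.

import torch

class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        # Lowers to torch.ops.higher_order.cond in the exported graph
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep = torch.export.export(CondModel(), (torch.randn(4),))
for node in ep.graph.nodes:
    # The branch submodules show up as get_attr nodes whose users are the
    # higher-order cond node; both now receive the delegation tag.
    print(node.op, node.target)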

backends/arm/operator_support/clone_dim_order_support.py

Lines changed: 14 additions & 0 deletions
@@ -2,6 +2,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+"""Declare operator support for dim-order clone in TOSA.
+
+This module registers a support check for ``dim_order_ops._clone_dim_order``
+ensuring input/output dtypes match and the value types are FakeTensors.
+
+"""

 import logging

@@ -19,6 +25,8 @@

 @register_tosa_support_check
 class CloneSupported(SupportedTOSAOperatorCheck):
+    """Provide TOSA support check for ``_clone_dim_order``."""
+
     targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]

     tosa_specs = [
@@ -29,6 +37,12 @@ class CloneSupported(SupportedTOSAOperatorCheck):
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
+        """Return True if the node is supported by TOSA.
+
+        Verify the operator target, the number and types of inputs/outputs, and
+        check that input and output dtypes match.
+
+        """
         if node.target not in self.targets:
             self.reporter.report_reject(node, f"Target {node.target} is not supported.")
             return False
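
As an illustration only, the dtype comparison described by the new docstring could look like the sketch below inside is_node_tosa_supported. The actual body is outside this hunk; reading FakeTensors from node.meta["val"] is the standard exir convention assumed here.

# Hypothetical sketch; not the code from this commit.
input_val = node.all_input_nodes[0].meta["val"]
output_val = node.meta["val"]
if input_val.dtype != output_val.dtype:
    self.reporter.report_reject(
        node, f"Input dtype {input_val.dtype} does not match output dtype {output_val.dtype}."
    )
    return False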

backends/arm/tosa/dialect/lib.py

Lines changed: 23 additions & 0 deletions
@@ -15,6 +15,17 @@


 def register_tosa_dialect_op(op_schema, func) -> Callable:
+    """Register a TOSA dialect operator with the backend op library.
+
+    Args:
+        op_schema (str): Operator schema without namespace or overload name.
+        func (Callable): Fake implementation used for registration.
+
+    Returns:
+        Callable: Backend dialect operator handle exposed via ``exir_ops`` and
+            marked ``not_callable`` for runtime use.
+
+    """
     if tosa_lib.ns not in _BACKEND_OP_LIB:
         _BACKEND_OP_LIB.append(tosa_lib.ns)

@@ -43,6 +54,7 @@ def register_tosa_dialect_op(op_schema, func) -> Callable:
     # the op doesn't need to be callable. This can be changed in the future if needed to support
     # execution of TOSA ops directly.
     def not_callable():
+        """Raise when the dialect op handle is invoked at runtime."""
         raise RuntimeError("TOSA dialect op is not callable")

     op.__equvalent_callable__ = not_callable
@@ -51,11 +63,22 @@ def not_callable():


 class TosaValueError(ValueError):
+    """Error type that annotates failures with the originating TOSA op."""
+
     def __init__(self, message="A TOSA value error occurred", *args, op=None):
+        """Initialise the error with optional operator metadata.
+
+        Args:
+            message (str): Human-readable error message.
+            *args: Additional arguments forwarded to ``ValueError``.
+            op: Optional operator identifier included in the string output.
+
+        """
         super().__init__(message, *args)
         self.op = op

     def __str__(self):
+        """Return the base message, appending the operator when provided."""
         base_message = super().__str__()
         if self.op is not None:
             return f"{base_message} (TOSA op: {self.op})"
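
A quick usage sketch of the annotated error; the call site is hypothetical, and the import path is inferred from the file location.

from executorch.backends.arm.tosa.dialect.lib import TosaValueError

try:
    raise TosaValueError("unsupported rescale configuration", op="RESCALE")
except TosaValueError as err:
    print(err)  # unsupported rescale configuration (TOSA op: RESCALE)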

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ def convert_pt2(
 # It is however useful for unit tests to separate the converted model from the
 # fused model, to be able to get reference numerics.
 # If this does not apply, please use quantize_pt2 instead.
-def fuse_pt2(
+def apply_pre_edge_transform_passes(
     converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
 ) -> ExportedProgram:
@@ -229,7 +229,7 @@ def quantize_pt2(

     # Apply quant fusion to the exported program
     program = torch.export.export(converted_gm, inputs, strict=True)
-    fused_program = fuse_pt2(program, quantizer)
+    fused_program = apply_pre_edge_transform_passes(program, quantizer)

     if dump_graphs:
         logging.info("Graph after quantization and fusion:")
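
To make the rename concrete, a hedged sketch of the unit-test flow the comment above describes; converted_model, example_inputs, and quantizer are assumed to come from the prepare/convert steps, and only apply_pre_edge_transform_passes is from this commit.

# Keep the converted and fused programs separate to compare reference numerics.
ep = torch.export.export(converted_model, example_inputs, strict=True)
fused_ep = apply_pre_edge_transform_passes(ep, quantizer)
# When separate programs are not needed, quantize_pt2 covers the whole flow.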

backends/cadence/aot/export_example.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@

 from executorch.backends.cadence.aot.compiler import (
     _lower_ep_to_cadence_gen_etrecord,
+    apply_pre_edge_transform_passes,
     convert_pt2,
-    fuse_pt2,
     prepare_pt2,
 )

@@ -66,7 +66,7 @@ def export_model(
     ep = torch.export.export(converted_model, example_inputs, strict=True)

     # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2)
-    ep = fuse_pt2(ep, quantizer)
+    ep = apply_pre_edge_transform_passes(ep, quantizer)

     # Get edge program after Cadence specific passes
     exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord(

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
 )

 # Generate C++ bindings to register kernels into Executorch
backends/cortex_m/ops/op_transpose.cpp

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+#include <array>
+#include <limits>
+#include <vector>
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+namespace {
+
+constexpr size_t kMaxSupportedDims = 4;
+
+} // namespace
+
+Tensor& transpose_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    const IntArrayRef perm,
+    Tensor& out) {
+  if (input.scalar_type() != ScalarType::Char ||
+      out.scalar_type() != ScalarType::Char) {
+    ET_LOG(
+        Error,
+        "transpose_out: only int8 tensors are supported (input=%d, out=%d)",
+        static_cast<int>(input.scalar_type()),
+        static_cast<int>(out.scalar_type()));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const size_t rank = input.dim();
+  if (rank == 0 || rank > kMaxSupportedDims) {
+    ET_LOG(
+        Error,
+        "transpose_out: expected tensor rank in [1, %zu], got %zu",
+        kMaxSupportedDims,
+        rank);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  if (perm.size() != static_cast<int64_t>(rank)) {
+    ET_LOG(
+        Error,
+        "transpose_out: permutation length %zd does not match tensor rank %zu",
+        perm.size(),
+        rank);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  // Pad unused dimensions with 1 so lower-rank tensors reuse the 4-D path.
+  std::array<int32_t, kMaxSupportedDims> input_dims_arr{1, 1, 1, 1};
+  std::array<int32_t, kMaxSupportedDims> output_dims_arr{1, 1, 1, 1};
+  for (size_t i = 0; i < rank; ++i) {
+    const auto in_size = input.size(i);
+    const auto out_size = out.size(i);
+    if (in_size > std::numeric_limits<int32_t>::max() ||
+        out_size > std::numeric_limits<int32_t>::max()) {
+      ET_LOG(
+          Error,
+          "transpose_out: dimension size exceeds int32_t range (input=%lld, output=%lld)",
+          static_cast<long long>(in_size),
+          static_cast<long long>(out_size));
+      context.fail(Error::InvalidArgument);
+      return out;
+    }
+    input_dims_arr[i] = static_cast<int32_t>(in_size);
+    output_dims_arr[i] = static_cast<int32_t>(out_size);
+  }
+
+  cmsis_nn_dims input_dims = {
+      input_dims_arr[0],
+      input_dims_arr[1],
+      input_dims_arr[2],
+      input_dims_arr[3]};
+  cmsis_nn_dims output_dims = {
+      output_dims_arr[0],
+      output_dims_arr[1],
+      output_dims_arr[2],
+      output_dims_arr[3]};
+
+  std::array<uint32_t, kMaxSupportedDims> perm_buffer{0, 1, 2, 3};
+  for (size_t i = 0; i < rank; ++i) {
+    perm_buffer[i] = static_cast<uint32_t>(perm[i]);
+  }
+
+  const cmsis_nn_transpose_params transpose_params{
+      static_cast<int32_t>(rank), perm_buffer.data()};
+
+  const int8_t* input_data = input.const_data_ptr<int8_t>();
+  int8_t* output_data = out.mutable_data_ptr<int8_t>();
+
+  const arm_cmsis_nn_status status = arm_transpose_s8(
+      input_data, output_data, &input_dims, &output_dims, &transpose_params);
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "transpose_out: arm_transpose_s8 failed with status [%d]",
+        static_cast<int>(status));
+    context.fail(Error::Internal);
+    return out;
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m

backends/cortex_m/ops/operators.py

Lines changed: 18 additions & 0 deletions
@@ -349,3 +349,21 @@ def quantized_linear_impl(
     output += output_offset
     output = torch.clamp(output, activation_min, activation_max).to(torch.int8)
     return output
+
+
+# ===================================================================
+# TRANSPOSE OPERATION DEFINITION
+# ===================================================================
+lib.define("transpose(Tensor input, int[] perm) -> Tensor")
+lib.define("transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@register_fake("cortex_m::transpose")
+def transpose_meta(input: torch.Tensor, perm) -> torch.Tensor:
+    output_shape = [input.shape[idx] for idx in perm]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+@impl(lib, "transpose", "CompositeExplicitAutograd")
+def transpose_impl(input: torch.Tensor, perm) -> torch.Tensor:
+    return input.permute(tuple(perm)).contiguous()
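
A small sanity check of the new op (assumes this operators module has been imported so the cortex_m library definitions are registered):

import torch

x = torch.arange(6, dtype=torch.int8).reshape(2, 3)
y = torch.ops.cortex_m.transpose(x, [1, 0])
assert y.shape == (3, 2)
assert torch.equal(y, x.permute(1, 0).contiguous())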

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
@@ -34,3 +34,9 @@
   kernels:
     - arg_meta: null
       kernel_name: cortex_m::quantized_linear_out
+
+- func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::transpose_out
