From 25e841e74f62fba9b3cd55aebfe1ebd14d093630 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 20:47:09 -0700 Subject: [PATCH 01/16] Update the NCHWc transformer to handle more patterns --- .../core/optimizer/nchwc_transformer.cc | 97 ++++++++++++++++++- .../test/optimizer/nchwc_optimizer_test.cc | 57 +++++++++-- 2 files changed, 142 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index e8c3bf24a612f..2c41c8180a1c3 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -122,6 +122,7 @@ class NchwcTransformerImpl { void TransformConv(Node& node); void TransformPool(Node& node); void TransformBinary(Node& node, bool add_node); + void TransformMul(Node& node); void TransformConcat(Node& node); void TransformActivation(Node& node); void TransformBatchNormalization(Node& node); @@ -734,6 +735,91 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { } } +void NchwcTransformerImpl::TransformMul(Node& node) { + auto& input_defs = node.MutableInputDefs(); + auto& output_defs = node.MutableOutputDefs(); + + if (input_defs.size() != 2 || output_defs.size() != 1) { + return; + } + + auto* nchwc_input_0 = LookupNchwcArgument(input_defs[0]); + auto* nchwc_input_1 = LookupNchwcArgument(input_defs[1]); + + // If both inputs are already NCHWc arguments, use the regular binary path. + if (nchwc_input_0 != nullptr && nchwc_input_1 != nullptr) { + TransformBinary(node, false); + return; + } + + // Exactly one input must be NCHWc and the other must be a static scale tensor. + if ((nchwc_input_0 == nullptr) == (nchwc_input_1 == nullptr)) { + return; + } + + const int nchwc_input_index = (nchwc_input_0 != nullptr) ? 0 : 1; + const int scale_input_index = nchwc_input_index ^ 1; + auto* nchwc_input = (nchwc_input_index == 0) ? 
nchwc_input_0 : nchwc_input_1;
+
+  const auto* mul_scale_tensor_proto = graph_utils::GetConstantInitializer(graph_, input_defs[scale_input_index]->Name());
+  if (mul_scale_tensor_proto == nullptr ||
+      mul_scale_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    return;
+  }
+
+  const int64_t channels = nchwc_input->channels_;
+
+  Initializer mul_scale{graph_, *mul_scale_tensor_proto, graph_.ModelPath()};
+  const auto scale_dims = mul_scale.dims();
+
+  bool channel_broadcast_shape = false;
+  if (scale_dims.size() == 1) {
+    channel_broadcast_shape = (scale_dims[0] == channels);
+  } else if (scale_dims.size() == 3) {
+    channel_broadcast_shape = (scale_dims[0] == channels && scale_dims[1] == 1 && scale_dims[2] == 1);
+  } else if (scale_dims.size() == 4) {
+    channel_broadcast_shape = (scale_dims[0] == 1 && scale_dims[1] == channels && scale_dims[2] == 1 && scale_dims[3] == 1);
+  }
+
+  if (!channel_broadcast_shape) {
+    return;
+  }
+
+  const size_t nchwc_block_size = MlasNchwcGetBlockSize();
+  const int64_t nchwc_channels = (channels + static_cast<int64_t>(nchwc_block_size) - 1) & ~static_cast<int64_t>(nchwc_block_size - 1);
+
+  InlinedVector<float> padded_scale(gsl::narrow<size_t>(nchwc_channels));
+  std::copy_n(mul_scale.data<float>(), channels, padded_scale.data());
+
+  ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto;
+  nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+  nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("mul_scale"));
+  utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, padded_scale.data(),
+                                 gsl::narrow<size_t>(nchwc_channels) * sizeof(float));
+  nchwc_conv_W_tensor_proto.add_dims(nchwc_channels);
+  nchwc_conv_W_tensor_proto.add_dims(1);
+  nchwc_conv_W_tensor_proto.add_dims(1);
+  nchwc_conv_W_tensor_proto.add_dims(1);
+
+  auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithOrtValue(graph_, nchwc_conv_W_tensor_proto);
+
+  std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() +
"_mul_nchwc"); + Node& nchwc_node = graph_.AddNode(nchwc_node_name, + "Conv", + nchwc_node_name, + std::array{nchwc_input->nchwc_arg_, nchwc_conv_W_arg}, + output_defs, + nullptr, + kMSNchwcDomain); + nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); + nchwc_node.AddAttribute("group", nchwc_channels); + + nchwc_input->remaining_original_uses_--; + + CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_); + removed_nodes_.push_front(node.Index()); +} + void NchwcTransformerImpl::TransformConcat(Node& node) { auto& input_defs = node.MutableInputDefs(); auto& output_defs = node.MutableOutputDefs(); @@ -794,7 +880,12 @@ void NchwcTransformerImpl::TransformActivation(Node& node) { // Check if this is a single use NCHWc convolution that hasn't already // been fused with another activation. auto& nchwc_node = nchwc_input->output_node_; + + const bool can_fuse_activation = (node.OpType() == "Relu") || + (node.OpType() == "Sigmoid") || + (node.OpType() == "Tanh"); if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) && + can_fuse_activation && (nchwc_input->starting_original_uses_ == 1) && (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) { nchwc_node.AddAttribute("activation", node.OpType()); @@ -1169,12 +1260,14 @@ void NchwcTransformerImpl::Transform(Node& node) { graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) { TransformBinary(node, true); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { - TransformBinary(node, false); + TransformMul(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) { + 
graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) { TransformBatchNormalization(node); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 7dfa5c7812f6e..2669523287921 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -85,12 +85,15 @@ struct NchwcTestHelper { Node& AddNode(const std::string& op_type, const std::vector& input_args, - const std::vector& output_args) { + const std::vector& output_args, + const std::string& domain = kOnnxDomain) { return graph_.AddNode(graph_.GenerateNodeName("node"), op_type, "description", input_args, - output_args); + output_args, + nullptr, + domain); } Node& AddConvNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector& weights_shape, bool no_bias = false) { @@ -706,6 +709,36 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { } } +TEST(NchwcOptimizerTests, ConvMulChannelScale) { + auto test_case = [&](const std::vector& scale_shape) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 32, 25, 21}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* mul_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {32, 32, 3, 3}); + auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); + helper.AddNode("Mul", {conv_output_arg, scale_arg}, {mul_output_arg}); + helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto 
op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 3); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Mul"], 0); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + test_case({32}); + test_case({32, 1, 1}); + test_case({1, 32, 1, 1}); +} + TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) { @@ -1342,7 +1375,7 @@ TEST(NchwcOptimizerTests, UpsampleLinear) { } TEST(NchwcOptimizerTests, Activation) { - auto test_case = [&](const std::string& activation_op_type) { + auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1351,7 +1384,7 @@ TEST(NchwcOptimizerTests, Activation) { auto* output_arg = helper.MakeOutput(); helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); - helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}); + helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain); helper.AddNode("Add", {conv1_output_arg, activation_output_arg}, {mul_output_arg}); helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); }; @@ -1368,12 +1401,16 @@ TEST(NchwcOptimizerTests, Activation) { NchwcOptimizerTester(build_test_case, check_nchwc_graph); }; - // Verify that the optimizer doesn't add reorders for these activations that - // cannot be fused with a convolution. 
- std::vector activation_op_types{"Relu", "Sigmoid", "Tanh"}; - for (auto& activation_op_type : activation_op_types) { - test_case(activation_op_type); - } + // Verify that the optimizer doesn't add reorders for these activations in + // this pattern. Relu/Sigmoid/Tanh are generally fusable with a preceding + // convolution, but not here because the Conv output is consumed both by the + // activation node and directly by the Add node. Gelu/QuickGelu are also + // expected to remain as separate nodes. + test_case("Relu"); + test_case("Sigmoid"); + test_case("Tanh"); + test_case("Gelu", kMSDomain); + test_case("QuickGelu", kMSDomain); } TEST(NchwcOptimizerTests, MaxPoolTypeCheck) { From b2ca1a1400f5612dd3f462b2844c220bff6cec06 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 16 Mar 2026 21:10:40 -0700 Subject: [PATCH 02/16] Build break --- onnxruntime/core/optimizer/nchwc_transformer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 2c41c8180a1c3..5627d9ef60ab8 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -1267,7 +1267,7 @@ void NchwcTransformerImpl::Transform(Node& node) { graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain) { + graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) { TransformBatchNormalization(node); From 3069a5d9df1f1557d5fd6cffb84dc77849a488f2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 09:52:58 -0700 Subject: 
[PATCH 03/16] Fix --- onnxruntime/core/optimizer/nchwc_transformer.cc | 4 +--- onnxruntime/test/optimizer/nchwc_optimizer_test.cc | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 5627d9ef60ab8..155251b3cba65 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -773,9 +773,7 @@ void NchwcTransformerImpl::TransformMul(Node& node) { const auto scale_dims = mul_scale.dims(); bool channel_broadcast_shape = false; - if (scale_dims.size() == 1) { - channel_broadcast_shape = (scale_dims[0] == channels); - } else if (scale_dims.size() == 3) { + if (scale_dims.size() == 3) { channel_broadcast_shape = (scale_dims[0] == channels && scale_dims[1] == 1 && scale_dims[2] == 1); } else if (scale_dims.size() == 4) { channel_broadcast_shape = (scale_dims[0] == 1 && scale_dims[1] == channels && scale_dims[2] == 1 && scale_dims[3] == 1); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 2669523287921..b894b8540f824 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -734,7 +734,7 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { NchwcOptimizerTester(build_test_case, check_nchwc_graph); }; - test_case({32}); + // Valid ONNX channel broadcast forms for NCHW tensors. 
test_case({32, 1, 1}); test_case({1, 32, 1, 1}); } From 3e16b181f28ed09fe581936f70c9d7c5b712ba5d Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 14:29:24 -0700 Subject: [PATCH 04/16] Fixes --- .../core/optimizer/nchwc_transformer.cc | 7 +++++-- .../test/optimizer/nchwc_optimizer_test.cc | 18 +++++++++++++----- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 155251b3cba65..f1a70c7bf3d99 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -1248,6 +1248,11 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10, 11, 12, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "AveragePool", {1, 7, 10, 11, 19, 22})) { TransformPool(node); + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { + // Mul can be converted when exactly one input is already NCHWc and the + // other input is a static channel-scale tensor, so it does not need to + // wait for all input edges to be removed. + TransformMul(node); } else if (node.GetInputEdgesCount() == 0 && node.InputDefs().size() != 0) { // The following transforms only run when the input edge count has already // been decremented to zero by earlier transforms. 
This is a hint that the @@ -1257,8 +1262,6 @@ void NchwcTransformerImpl::Transform(Node& node) { if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) { TransformBinary(node, true); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { - TransformMul(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index b894b8540f824..a1a29d4daa752 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -180,6 +180,7 @@ void NchwcOptimizerTester(const std::function& bu // Build the model for this test. std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; Model model("nchwc", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); NchwcTestHelper helper(model.MainGraph()); @@ -710,7 +711,7 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { } TEST(NchwcOptimizerTests, ConvMulChannelScale) { - auto test_case = [&](const std::vector& scale_shape) { + auto test_case = [&](const std::vector& scale_shape, bool scale_first) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 32, 25, 21}); auto* conv_output_arg = helper.MakeIntermediate(); @@ -719,7 +720,11 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { helper.AddConvNode(input_arg, conv_output_arg, {32, 32, 3, 3}); auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); - helper.AddNode("Mul", {conv_output_arg, scale_arg}, 
{mul_output_arg}); + if (scale_first) { + helper.AddNode("Mul", {scale_arg, conv_output_arg}, {mul_output_arg}); + } else { + helper.AddNode("Mul", {conv_output_arg, scale_arg}, {mul_output_arg}); + } helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); }; @@ -735,8 +740,10 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { }; // Valid ONNX channel broadcast forms for NCHW tensors. - test_case({32, 1, 1}); - test_case({1, 32, 1, 1}); + test_case({32, 1, 1}, false); + test_case({32, 1, 1}, true); + test_case({1, 32, 1, 1}, false); + test_case({1, 32, 1, 1}, true); } TEST(NchwcOptimizerTests, ConvConcat) { @@ -1391,10 +1398,11 @@ TEST(NchwcOptimizerTests, Activation) { auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); + const std::string activation_key = domain.empty() ? activation_op_type : domain + "." + activation_op_type; EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); - EXPECT_EQ(op_to_count[activation_op_type], 1); + EXPECT_EQ(op_to_count[activation_key], 1); EXPECT_EQ(op_to_count["Add"], 1); }; From a5d282dce24d3ae548a9bc557ebc269be8329859 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 15:22:15 -0700 Subject: [PATCH 05/16] Fix test --- .../test/optimizer/nchwc_optimizer_test.cc | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index a1a29d4daa752..c138de0c17645 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -711,21 +711,26 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { } TEST(NchwcOptimizerTests, ConvMulChannelScale) { - auto test_case = [&](const std::vector& scale_shape, bool scale_first) { + const int64_t 
channels = static_cast(MlasNchwcGetBlockSize()) * 2; + + auto test_case = [&](bool use_explicit_batch_dim, bool scale_first) { auto build_test_case = [&](NchwcTestHelper& helper) { - auto* input_arg = helper.MakeInput({1, 32, 25, 21}); + auto* input_arg = helper.MakeInput({1, channels, 25, 21}); auto* conv_output_arg = helper.MakeIntermediate(); auto* mul_output_arg = helper.MakeIntermediate(); auto* output_arg = helper.MakeOutput(); - helper.AddConvNode(input_arg, conv_output_arg, {32, 32, 3, 3}); + helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); + const std::vector scale_shape = use_explicit_batch_dim + ? std::vector{1, channels, 1, 1} + : std::vector{channels, 1, 1}; auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); if (scale_first) { helper.AddNode("Mul", {scale_arg, conv_output_arg}, {mul_output_arg}); } else { helper.AddNode("Mul", {conv_output_arg, scale_arg}, {mul_output_arg}); } - helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); + helper.AddConvNode(mul_output_arg, output_arg, {16, channels, 1, 1}); }; auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { @@ -740,10 +745,10 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { }; // Valid ONNX channel broadcast forms for NCHW tensors. 
- test_case({32, 1, 1}, false); - test_case({32, 1, 1}, true); - test_case({1, 32, 1, 1}, false); - test_case({1, 32, 1, 1}, true); + test_case(false, false); + test_case(false, true); + test_case(true, false); + test_case(true, true); } TEST(NchwcOptimizerTests, ConvConcat) { From 167a70bf6e4f62ba8cfe1d4820e6023efec337ff Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 17:11:08 -0700 Subject: [PATCH 06/16] Fixes --- onnxruntime/test/optimizer/nchwc_optimizer_test.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index c138de0c17645..c6e5ebeea9b72 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -717,7 +717,6 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, channels, 25, 21}); auto* conv_output_arg = helper.MakeIntermediate(); - auto* mul_output_arg = helper.MakeIntermediate(); auto* output_arg = helper.MakeOutput(); helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); @@ -726,16 +725,16 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { : std::vector{channels, 1, 1}; auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); if (scale_first) { - helper.AddNode("Mul", {scale_arg, conv_output_arg}, {mul_output_arg}); + helper.AddNode("Mul", {scale_arg, conv_output_arg}, {output_arg}); } else { - helper.AddNode("Mul", {conv_output_arg, scale_arg}, {mul_output_arg}); + helper.AddNode("Mul", {conv_output_arg, scale_arg}, {output_arg}); } - helper.AddConvNode(mul_output_arg, output_arg, {16, channels, 1, 1}); }; auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 3); + 
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["Conv"], 0); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); From 4c7af22471358523ddcc2a9a0de93c57bb8b412a Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 17:47:35 -0700 Subject: [PATCH 07/16] Test debug --- onnxruntime/test/optimizer/nchwc_optimizer_test.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index c6e5ebeea9b72..7edbb36f7eb8e 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -733,10 +733,12 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); - EXPECT_EQ(op_to_count["Conv"], 0); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + // TODO: Re-enable the Conv count checks once the remaining platform- + // specific behavior is understood. 
+ // EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"] + op_to_count["Conv"], 2); + // EXPECT_GE(op_to_count["com.microsoft.nchwc.Conv"], 1); + //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); }; From cc97e59d9abd7ef5976261bb7649575ed30734f5 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 20:33:59 -0700 Subject: [PATCH 08/16] More debugging --- .../test/optimizer/nchwc_optimizer_test.cc | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 7edbb36f7eb8e..8f8b64f7125b5 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -171,7 +171,8 @@ struct NchwcTestHelper { void NchwcOptimizerTester(const std::function& build_test_case, const std::function& check_nchwc_graph, - int opset_version = 13) { + int opset_version = 13, + const std::function& check_pre_optimization_graph = nullptr) { // Ignore the test if NCHWc is not supported by the platform. if (MlasNchwcGetBlockSize() <= 1) { return; @@ -187,6 +188,10 @@ void NchwcOptimizerTester(const std::function& bu build_test_case(helper); ASSERT_STATUS_OK(model.MainGraph().Resolve()); + if (check_pre_optimization_graph) { + check_pre_optimization_graph(model.MainGraph()); + } + // Serialize the model to a string. 
std::string model_data; model.ToProto().SerializeToString(&model_data); @@ -731,18 +736,27 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { } }; + auto check_pre_optimization_graph = [&](const Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["Mul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0); + }; + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); // TODO: Re-enable the Conv count checks once the remaining platform- // specific behavior is understood. - // EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"] + op_to_count["Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); // EXPECT_GE(op_to_count["com.microsoft.nchwc.Conv"], 1); //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); }; - NchwcOptimizerTester(build_test_case, check_nchwc_graph); + NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); }; // Valid ONNX channel broadcast forms for NCHW tensors. 
From ef2b74701187ef7bfc262a7e4a1fc4da980818b6 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Tue, 17 Mar 2026 21:23:55 -0700 Subject: [PATCH 09/16] Debug --- .../test/optimizer/nchwc_optimizer_test.cc | 76 ++++++++++++++++++- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 8f8b64f7125b5..df784ef8ea256 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -12,13 +12,77 @@ #include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" +#include #include +#include +#include #include "gtest/gtest.h" namespace onnxruntime { namespace test { +namespace { + +bool ShouldDumpNchwcOptimizerGraphs() { + const char* value = std::getenv("ORT_NCHWC_TEST_DUMP_GRAPH"); + return value == nullptr || value[0] == '\0' || std::string(value) != "0"; +} + +void DumpNchwcOptimizerGraph(const Graph& graph, const std::string& test_name, const std::string& stage) { + if (!ShouldDumpNchwcOptimizerGraphs()) { + return; + } + + std::ostringstream stream; + stream << "===== NCHWC graph dump: " << test_name << " [" << stage << "] =====\n"; + + const auto op_to_count = CountOpsInGraph(graph); + stream << "Op counts:\n"; + for (const auto& entry : op_to_count) { + stream << " " << entry.first << ": " << entry.second << "\n"; + } + + stream << "Nodes:\n"; + for (const auto& node : graph.Nodes()) { + stream << " " << node.Name() << " : " + << (node.Domain().empty() ? kOnnxDomain : node.Domain()) + << "." << node.OpType() << "\n"; + + stream << " Inputs:"; + for (const auto* input_def : node.InputDefs()) { + stream << " " << (input_def == nullptr ? "" : input_def->Name()); + } + stream << "\n"; + + stream << " Outputs:"; + for (const auto* output_def : node.OutputDefs()) { + stream << " " << (output_def == nullptr ? 
"" : output_def->Name()); + } + stream << "\n"; + } + + stream << "Graph outputs:"; + for (const auto* output : graph.GetOutputs()) { + stream << " " << (output == nullptr ? "" : output->Name()); + } + stream << "\n"; + stream << "===== End NCHWC graph dump =====\n"; + + const std::string text = stream.str(); + std::cerr << text; + + const char* file_path = std::getenv("ORT_NCHWC_TEST_DUMP_GRAPH_FILE"); + if (file_path != nullptr && file_path[0] != '\0') { + std::ofstream output_stream(file_path, std::ios::app); + if (output_stream.is_open()) { + output_stream << text; + } + } +} + +} // namespace + struct NchwcTestHelper { NchwcTestHelper(Graph& graph) : graph_(graph), fill_value_(0), per_sample_tolerance_(0.0) { } @@ -738,6 +802,12 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto check_pre_optimization_graph = [&](const Graph& graph) { auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["Conv"] != 1 || op_to_count["Mul"] != 1 || + op_to_count["com.microsoft.nchwc.Conv"] != 0 || + op_to_count["com.microsoft.nchwc.ReorderInput"] != 0 || + op_to_count["com.microsoft.nchwc.ReorderOutput"] != 0) { + DumpNchwcOptimizerGraph(graph, "NchwcOptimizerTests.ConvMulChannelScale", "pre_optimization_failure"); + } EXPECT_EQ(op_to_count["Conv"], 1); EXPECT_EQ(op_to_count["Mul"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); @@ -747,10 +817,10 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - // TODO: Re-enable the Conv count checks once the remaining platform- - // specific behavior is understood. 
+ if (op_to_count["com.microsoft.nchwc.Conv"] != 2 || op_to_count["Mul"] != 0) { + DumpNchwcOptimizerGraph(session.GetGraph(), "NchwcOptimizerTests.ConvMulChannelScale", "post_optimization_failure"); + } EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); - // EXPECT_GE(op_to_count["com.microsoft.nchwc.Conv"], 1); //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); From 18b025aeaf22606aba37c525be744c701d4ec5c7 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Wed, 18 Mar 2026 00:49:38 -0700 Subject: [PATCH 10/16] Fix test once and for all hopefully --- .../test/optimizer/nchwc_optimizer_test.cc | 85 ++----------------- 1 file changed, 8 insertions(+), 77 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index df784ef8ea256..e9ea29b8c07f5 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -12,77 +12,13 @@ #include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" -#include #include -#include -#include #include "gtest/gtest.h" namespace onnxruntime { namespace test { -namespace { - -bool ShouldDumpNchwcOptimizerGraphs() { - const char* value = std::getenv("ORT_NCHWC_TEST_DUMP_GRAPH"); - return value == nullptr || value[0] == '\0' || std::string(value) != "0"; -} - -void DumpNchwcOptimizerGraph(const Graph& graph, const std::string& test_name, const std::string& stage) { - if (!ShouldDumpNchwcOptimizerGraphs()) { - return; - } - - std::ostringstream stream; - stream << "===== NCHWC graph dump: " << test_name << " [" << stage << "] =====\n"; - - const auto op_to_count = CountOpsInGraph(graph); - stream << "Op counts:\n"; - for (const auto& entry : op_to_count) { - stream << " " << entry.first << ": " << entry.second 
<< "\n"; - } - - stream << "Nodes:\n"; - for (const auto& node : graph.Nodes()) { - stream << " " << node.Name() << " : " - << (node.Domain().empty() ? kOnnxDomain : node.Domain()) - << "." << node.OpType() << "\n"; - - stream << " Inputs:"; - for (const auto* input_def : node.InputDefs()) { - stream << " " << (input_def == nullptr ? "" : input_def->Name()); - } - stream << "\n"; - - stream << " Outputs:"; - for (const auto* output_def : node.OutputDefs()) { - stream << " " << (output_def == nullptr ? "" : output_def->Name()); - } - stream << "\n"; - } - - stream << "Graph outputs:"; - for (const auto* output : graph.GetOutputs()) { - stream << " " << (output == nullptr ? "" : output->Name()); - } - stream << "\n"; - stream << "===== End NCHWC graph dump =====\n"; - - const std::string text = stream.str(); - std::cerr << text; - - const char* file_path = std::getenv("ORT_NCHWC_TEST_DUMP_GRAPH_FILE"); - if (file_path != nullptr && file_path[0] != '\0') { - std::ofstream output_stream(file_path, std::ios::app); - if (output_stream.is_open()) { - output_stream << text; - } - } -} - -} // namespace - struct NchwcTestHelper { NchwcTestHelper(Graph& graph) : graph_(graph), fill_value_(0), per_sample_tolerance_(0.0) { } @@ -786,29 +722,26 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, channels, 25, 21}); auto* conv_output_arg = helper.MakeIntermediate(); + auto* quickgelu_output_arg = helper.MakeIntermediate(); auto* output_arg = helper.MakeOutput(); helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); + helper.AddNode("QuickGelu", {conv_output_arg}, {quickgelu_output_arg}, kMSDomain); const std::vector scale_shape = use_explicit_batch_dim ? 
std::vector{1, channels, 1, 1} : std::vector{channels, 1, 1}; auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); if (scale_first) { - helper.AddNode("Mul", {scale_arg, conv_output_arg}, {output_arg}); + helper.AddNode("Mul", {scale_arg, quickgelu_output_arg}, {output_arg}); } else { - helper.AddNode("Mul", {conv_output_arg, scale_arg}, {output_arg}); + helper.AddNode("Mul", {quickgelu_output_arg, scale_arg}, {output_arg}); } }; auto check_pre_optimization_graph = [&](const Graph& graph) { auto op_to_count = CountOpsInGraph(graph); - if (op_to_count["Conv"] != 1 || op_to_count["Mul"] != 1 || - op_to_count["com.microsoft.nchwc.Conv"] != 0 || - op_to_count["com.microsoft.nchwc.ReorderInput"] != 0 || - op_to_count["com.microsoft.nchwc.ReorderOutput"] != 0) { - DumpNchwcOptimizerGraph(graph, "NchwcOptimizerTests.ConvMulChannelScale", "pre_optimization_failure"); - } EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1); EXPECT_EQ(op_to_count["Mul"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0); @@ -817,12 +750,10 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - if (op_to_count["com.microsoft.nchwc.Conv"] != 2 || op_to_count["Mul"] != 0) { - DumpNchwcOptimizerGraph(session.GetGraph(), "NchwcOptimizerTests.ConvMulChannelScale", "post_optimization_failure"); - } EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); - //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); - //EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); }; From 
0ecd1eda3be2dca0f932a756856f60052e6c8940 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 19 Mar 2026 11:48:11 -0700 Subject: [PATCH 11/16] Add HardSigmoid --- .../core/optimizer/nchwc_transformer.cc | 1 + .../test/optimizer/nchwc_optimizer_test.cc | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index f1a70c7bf3d99..45044fdea2c1d 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -1265,6 +1265,7 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 13, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index e9ea29b8c07f5..e0378453f1d69 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -767,6 +767,60 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { test_case(true, true); } +TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { + const int64_t channels = static_cast(MlasNchwcGetBlockSize()) * 2; + + auto test_case = [&](bool use_explicit_batch_dim, bool scale_first) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, channels, 25, 21}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* hardsigmoid_output_arg = helper.MakeIntermediate(); + auto* output_arg = 
helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); + auto& hardsigmoid_node = helper.AddNode("HardSigmoid", {conv_output_arg}, {hardsigmoid_output_arg}); + hardsigmoid_node.AddAttribute("alpha", 0.125f); + hardsigmoid_node.AddAttribute("beta", 0.4f); + + const std::vector scale_shape = use_explicit_batch_dim + ? std::vector{1, channels, 1, 1} + : std::vector{channels, 1, 1}; + auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); + if (scale_first) { + helper.AddNode("Mul", {scale_arg, hardsigmoid_output_arg}, {output_arg}); + } else { + helper.AddNode("Mul", {hardsigmoid_output_arg, scale_arg}, {output_arg}); + } + }; + + auto check_pre_optimization_graph = [&](const Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["HardSigmoid"], 1); + EXPECT_EQ(op_to_count["Mul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["HardSigmoid"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Mul"], 0); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); + }; + + test_case(false, false); + test_case(false, true); + test_case(true, false); + test_case(true, true); +} + TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) { From d9fa998ebed8888fe67165a8539d323547aba101 Mon Sep 17 00:00:00 2001 From: 
Hari Seshadri Date: Thu, 19 Mar 2026 12:46:59 -0700 Subject: [PATCH 12/16] Fix --- .../test/optimizer/nchwc_optimizer_test.cc | 85 +++++++++++++++++-- 1 file changed, 78 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index e0378453f1d69..ecf73c06febbc 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -769,6 +769,8 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { const int64_t channels = static_cast(MlasNchwcGetBlockSize()) * 2; + constexpr float alpha = 0.125f; + constexpr float beta = 0.4f; auto test_case = [&](bool use_explicit_batch_dim, bool scale_first) { auto build_test_case = [&](NchwcTestHelper& helper) { @@ -779,8 +781,8 @@ TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); auto& hardsigmoid_node = helper.AddNode("HardSigmoid", {conv_output_arg}, {hardsigmoid_output_arg}); - hardsigmoid_node.AddAttribute("alpha", 0.125f); - hardsigmoid_node.AddAttribute("beta", 0.4f); + hardsigmoid_node.AddAttribute("alpha", alpha); + hardsigmoid_node.AddAttribute("beta", beta); const std::vector scale_shape = use_explicit_batch_dim ? 
std::vector{1, channels, 1, 1} @@ -806,10 +808,33 @@ TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); - EXPECT_EQ(op_to_count["HardSigmoid"], 1); + EXPECT_EQ(op_to_count["HardSigmoid"], 0); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); EXPECT_EQ(op_to_count["Mul"], 0); + + size_t hard_sigmoid_fused_count = 0; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.Domain() != kMSNchwcDomain || node.OpType() != "Conv") { + continue; + } + + const auto activation_it = node.GetAttributes().find("activation"); + if (activation_it == node.GetAttributes().end() || activation_it->second.s() != "HardSigmoid") { + continue; + } + + ++hard_sigmoid_fused_count; + + const auto activation_params_it = node.GetAttributes().find("activation_params"); + ASSERT_NE(activation_params_it, node.GetAttributes().end()); + const auto& params = activation_params_it->second.floats(); + ASSERT_EQ(params.size(), 2); + EXPECT_EQ(params.Get(0), alpha); + EXPECT_EQ(params.Get(1), beta); + } + + EXPECT_EQ(hard_sigmoid_fused_count, 1u); }; NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); @@ -821,6 +846,51 @@ TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { test_case(true, true); } +TEST(NchwcOptimizerTests, ConvHardSigmoidTwoConsumers) { + constexpr float alpha = 0.125f; + constexpr float beta = 0.4f; + + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* hardsigmoid_output_arg = helper.MakeIntermediate(); + auto* add_output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {32, 48, 3, 3}); + auto& hardsigmoid_node = 
helper.AddNode("HardSigmoid", {conv_output_arg}, {hardsigmoid_output_arg}); + hardsigmoid_node.AddAttribute("alpha", alpha); + hardsigmoid_node.AddAttribute("beta", beta); + helper.AddNode("Add", {conv_output_arg, hardsigmoid_output_arg}, {add_output_arg}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["HardSigmoid"], 1); + EXPECT_EQ(op_to_count["Add"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + + const Node* hard_sigmoid_node = nullptr; + const Node* nchwc_conv_node = nullptr; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "HardSigmoid") { + hard_sigmoid_node = &node; + } else if (node.Domain() == kMSNchwcDomain && node.OpType() == "Conv") { + nchwc_conv_node = &node; + } + } + + ASSERT_NE(hard_sigmoid_node, nullptr); + ASSERT_NE(nchwc_conv_node, nullptr); + ASSERT_EQ(hard_sigmoid_node->GetInputEdgesCount(), 1u); + EXPECT_EQ(&*hard_sigmoid_node->InputNodesBegin(), nchwc_conv_node); + EXPECT_EQ(nchwc_conv_node->GetOutputEdgesCount(), 2u); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); +} + TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) { @@ -1485,13 +1555,14 @@ TEST(NchwcOptimizerTests, Activation) { }; // Verify that the optimizer doesn't add reorders for these activations in - // this pattern. Relu/Sigmoid/Tanh are generally fusable with a preceding - // convolution, but not here because the Conv output is consumed both by the - // activation node and directly by the Add node. Gelu/QuickGelu are also - // expected to remain as separate nodes. + // this pattern. 
Relu/Sigmoid/Tanh/HardSigmoid are generally fusable with a + // preceding convolution, but not here because the Conv output is consumed + // both by the activation node and directly by the Add node. Gelu/QuickGelu + // are also expected to remain as separate nodes. test_case("Relu"); test_case("Sigmoid"); test_case("Tanh"); + test_case("HardSigmoid"); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); } From fcbdaf6118d5e3784ee214b3c0574d9abb0c9909 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 19:13:20 -0700 Subject: [PATCH 13/16] Remove HardSigmoid related changes as it introduces some complications --- .../core/optimizer/nchwc_transformer.cc | 1 - .../test/optimizer/nchwc_optimizer_test.cc | 127 +----------------- 2 files changed, 1 insertion(+), 127 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 45044fdea2c1d..f1a70c7bf3d99 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -1265,7 +1265,6 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "HardSigmoid", {6, 13, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index ecf73c06febbc..4a547d12a9ed2 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -767,130 +767,6 @@ TEST(NchwcOptimizerTests, 
ConvMulChannelScale) { test_case(true, true); } -TEST(NchwcOptimizerTests, ConvMulChannelScaleHardSigmoid) { - const int64_t channels = static_cast(MlasNchwcGetBlockSize()) * 2; - constexpr float alpha = 0.125f; - constexpr float beta = 0.4f; - - auto test_case = [&](bool use_explicit_batch_dim, bool scale_first) { - auto build_test_case = [&](NchwcTestHelper& helper) { - auto* input_arg = helper.MakeInput({1, channels, 25, 21}); - auto* conv_output_arg = helper.MakeIntermediate(); - auto* hardsigmoid_output_arg = helper.MakeIntermediate(); - auto* output_arg = helper.MakeOutput(); - - helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); - auto& hardsigmoid_node = helper.AddNode("HardSigmoid", {conv_output_arg}, {hardsigmoid_output_arg}); - hardsigmoid_node.AddAttribute("alpha", alpha); - hardsigmoid_node.AddAttribute("beta", beta); - - const std::vector scale_shape = use_explicit_batch_dim - ? std::vector{1, channels, 1, 1} - : std::vector{channels, 1, 1}; - auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); - if (scale_first) { - helper.AddNode("Mul", {scale_arg, hardsigmoid_output_arg}, {output_arg}); - } else { - helper.AddNode("Mul", {hardsigmoid_output_arg, scale_arg}, {output_arg}); - } - }; - - auto check_pre_optimization_graph = [&](const Graph& graph) { - auto op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Conv"], 1); - EXPECT_EQ(op_to_count["HardSigmoid"], 1); - EXPECT_EQ(op_to_count["Mul"], 1); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0); - }; - - auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); - EXPECT_EQ(op_to_count["HardSigmoid"], 0); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); - 
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); - EXPECT_EQ(op_to_count["Mul"], 0); - - size_t hard_sigmoid_fused_count = 0; - for (const auto& node : session.GetGraph().Nodes()) { - if (node.Domain() != kMSNchwcDomain || node.OpType() != "Conv") { - continue; - } - - const auto activation_it = node.GetAttributes().find("activation"); - if (activation_it == node.GetAttributes().end() || activation_it->second.s() != "HardSigmoid") { - continue; - } - - ++hard_sigmoid_fused_count; - - const auto activation_params_it = node.GetAttributes().find("activation_params"); - ASSERT_NE(activation_params_it, node.GetAttributes().end()); - const auto& params = activation_params_it->second.floats(); - ASSERT_EQ(params.size(), 2); - EXPECT_EQ(params.Get(0), alpha); - EXPECT_EQ(params.Get(1), beta); - } - - EXPECT_EQ(hard_sigmoid_fused_count, 1u); - }; - - NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); - }; - - test_case(false, false); - test_case(false, true); - test_case(true, false); - test_case(true, true); -} - -TEST(NchwcOptimizerTests, ConvHardSigmoidTwoConsumers) { - constexpr float alpha = 0.125f; - constexpr float beta = 0.4f; - - auto build_test_case = [&](NchwcTestHelper& helper) { - auto* input_arg = helper.MakeInput({1, 48, 11, 15}); - auto* conv_output_arg = helper.MakeIntermediate(); - auto* hardsigmoid_output_arg = helper.MakeIntermediate(); - auto* add_output_arg = helper.MakeOutput(); - - helper.AddConvNode(input_arg, conv_output_arg, {32, 48, 3, 3}); - auto& hardsigmoid_node = helper.AddNode("HardSigmoid", {conv_output_arg}, {hardsigmoid_output_arg}); - hardsigmoid_node.AddAttribute("alpha", alpha); - hardsigmoid_node.AddAttribute("beta", beta); - helper.AddNode("Add", {conv_output_arg, hardsigmoid_output_arg}, {add_output_arg}); - }; - - auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - 
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1); - EXPECT_EQ(op_to_count["HardSigmoid"], 1); - EXPECT_EQ(op_to_count["Add"], 1); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); - EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); - - const Node* hard_sigmoid_node = nullptr; - const Node* nchwc_conv_node = nullptr; - for (const auto& node : session.GetGraph().Nodes()) { - if (node.OpType() == "HardSigmoid") { - hard_sigmoid_node = &node; - } else if (node.Domain() == kMSNchwcDomain && node.OpType() == "Conv") { - nchwc_conv_node = &node; - } - } - - ASSERT_NE(hard_sigmoid_node, nullptr); - ASSERT_NE(nchwc_conv_node, nullptr); - ASSERT_EQ(hard_sigmoid_node->GetInputEdgesCount(), 1u); - EXPECT_EQ(&*hard_sigmoid_node->InputNodesBegin(), nchwc_conv_node); - EXPECT_EQ(nchwc_conv_node->GetOutputEdgesCount(), 2u); - }; - - NchwcOptimizerTester(build_test_case, check_nchwc_graph); -} - TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) { @@ -1555,14 +1431,13 @@ TEST(NchwcOptimizerTests, Activation) { }; // Verify that the optimizer doesn't add reorders for these activations in - // this pattern. Relu/Sigmoid/Tanh/HardSigmoid are generally fusable with a + // this pattern. Relu/Sigmoid/Tanh are generally fusable with a // preceding convolution, but not here because the Conv output is consumed // both by the activation node and directly by the Add node. Gelu/QuickGelu // are also expected to remain as separate nodes. 
test_case("Relu"); test_case("Sigmoid"); test_case("Tanh"); - test_case("HardSigmoid"); test_case("Gelu", kMSDomain); test_case("QuickGelu", kMSDomain); } From 914ec4ddae3275a700bccfeab664a00e18089bb2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 20 Mar 2026 19:40:47 -0700 Subject: [PATCH 14/16] Copilot comments --- .../core/optimizer/nchwc_transformer.cc | 2 +- .../test/optimizer/nchwc_optimizer_test.cc | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index f1a70c7bf3d99..4e03450077718 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -786,7 +786,7 @@ void NchwcTransformerImpl::TransformMul(Node& node) { const size_t nchwc_block_size = MlasNchwcGetBlockSize(); const int64_t nchwc_channels = (channels + static_cast(nchwc_block_size) - 1) & ~static_cast(nchwc_block_size - 1); - InlinedVector padded_scale(gsl::narrow(nchwc_channels)); + InlinedVector padded_scale(gsl::narrow(nchwc_channels), 1.0f); std::copy_n(mul_scale.data(), channels, padded_scale.data()); ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto; diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 4a547d12a9ed2..3581f3b07927d 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -716,20 +716,20 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { } TEST(NchwcOptimizerTests, ConvMulChannelScale) { - const int64_t channels = static_cast(MlasNchwcGetBlockSize()) * 2; + const int64_t input_channels = static_cast(MlasNchwcGetBlockSize()) * 2; - auto test_case = [&](bool use_explicit_batch_dim, bool scale_first) { + auto test_case = [&](int64_t output_channels, bool use_explicit_batch_dim, bool scale_first) { auto build_test_case = [&](NchwcTestHelper& 
helper) { - auto* input_arg = helper.MakeInput({1, channels, 25, 21}); + auto* input_arg = helper.MakeInput({1, input_channels, 25, 21}); auto* conv_output_arg = helper.MakeIntermediate(); auto* quickgelu_output_arg = helper.MakeIntermediate(); auto* output_arg = helper.MakeOutput(); - helper.AddConvNode(input_arg, conv_output_arg, {channels, channels, 3, 3}); + helper.AddConvNode(input_arg, conv_output_arg, {output_channels, input_channels, 3, 3}); helper.AddNode("QuickGelu", {conv_output_arg}, {quickgelu_output_arg}, kMSDomain); const std::vector scale_shape = use_explicit_batch_dim - ? std::vector{1, channels, 1, 1} - : std::vector{channels, 1, 1}; + ? std::vector{1, output_channels, 1, 1} + : std::vector{output_channels, 1, 1}; auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); if (scale_first) { helper.AddNode("Mul", {scale_arg, quickgelu_output_arg}, {output_arg}); @@ -760,11 +760,15 @@ TEST(NchwcOptimizerTests, ConvMulChannelScale) { NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); }; - // Valid ONNX channel broadcast forms for NCHW tensors. - test_case(false, false); - test_case(false, true); - test_case(true, false); - test_case(true, true); + // Keep Conv input channels aligned so the initial NCHWc transform still + // applies, and vary only the logical output channel count to cover both the + // aligned path and the padded scale path in the Mul rewrite. 
+ for (int64_t output_channels : {input_channels, input_channels + 1}) { + test_case(output_channels, false, false); + test_case(output_channels, false, true); + test_case(output_channels, true, false); + test_case(output_channels, true, true); + } } TEST(NchwcOptimizerTests, ConvConcat) { From d7979da4b400608274a30fcad67913d1cba37c0a Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sat, 21 Mar 2026 22:34:01 -0700 Subject: [PATCH 15/16] PR comment --- .../test/optimizer/nchwc_optimizer_test.cc | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 3581f3b07927d..7a72e919212e5 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1446,6 +1446,47 @@ TEST(NchwcOptimizerTests, Activation) { test_case("QuickGelu", kMSDomain); } +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { + auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv1_output_arg = helper.MakeIntermediate(); + auto* activation_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); + helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain); + helper.AddConvNode(activation_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto& graph = session.GetGraph(); + auto op_to_count = CountOpsInGraph(graph); + const std::string activation_key = domain.empty() ? activation_op_type : domain + "." 
+ activation_op_type; + + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count[activation_key], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Conv" && node.Domain() == kMSNchwcDomain) { + EXPECT_EQ(graph_utils::GetNodeAttribute(node, "activation"), nullptr) + << activation_op_type << " should not fuse into a single-consumer NCHWc Conv"; + } + } + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + // Gelu/QuickGelu must remain separate even with a single-consumer Conv input, + // because the NCHWc Conv activation fuse guard only allows a fixed subset of + // activations. + test_case("Gelu", kMSDomain); + test_case("QuickGelu", kMSDomain); +} + TEST(NchwcOptimizerTests, MaxPoolTypeCheck) { auto build_test_case = [&](NchwcTestHelper& helper) { auto add_pool_node = [&](NchwcTestHelper& helper, NodeArg* input_arg) { From ba935dd74d5a2012b216d5a7b20ccf2018e97087 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Sat, 21 Mar 2026 23:00:58 -0700 Subject: [PATCH 16/16] Fix build --- onnxruntime/test/optimizer/nchwc_optimizer_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 7a72e919212e5..6078660bf0d6e 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1471,7 +1471,7 @@ TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { for (const auto& node : graph.Nodes()) { if (node.OpType() == "Conv" && node.Domain() == kMSNchwcDomain) { - EXPECT_EQ(graph_utils::GetNodeAttribute(node, "activation"), nullptr) + EXPECT_EQ(node.GetAttributes().count("activation"), 0U) << activation_op_type << " should not fuse into a single-consumer NCHWc Conv"; } }