diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index e8c3bf24a612f..4e03450077718 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -122,6 +122,7 @@ class NchwcTransformerImpl { void TransformConv(Node& node); void TransformPool(Node& node); void TransformBinary(Node& node, bool add_node); + void TransformMul(Node& node); void TransformConcat(Node& node); void TransformActivation(Node& node); void TransformBatchNormalization(Node& node); @@ -734,6 +735,89 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { } } +void NchwcTransformerImpl::TransformMul(Node& node) { + auto& input_defs = node.MutableInputDefs(); + auto& output_defs = node.MutableOutputDefs(); + + if (input_defs.size() != 2 || output_defs.size() != 1) { + return; + } + + auto* nchwc_input_0 = LookupNchwcArgument(input_defs[0]); + auto* nchwc_input_1 = LookupNchwcArgument(input_defs[1]); + + // If both inputs are already NCHWc arguments, use the regular binary path. + if (nchwc_input_0 != nullptr && nchwc_input_1 != nullptr) { + TransformBinary(node, false); + return; + } + + // Exactly one input must be NCHWc and the other must be a static scale tensor. + if ((nchwc_input_0 == nullptr) == (nchwc_input_1 == nullptr)) { + return; + } + + const int nchwc_input_index = (nchwc_input_0 != nullptr) ? 0 : 1; + const int scale_input_index = nchwc_input_index ^ 1; + auto* nchwc_input = (nchwc_input_index == 0) ? 
nchwc_input_0 : nchwc_input_1; + + const auto* mul_scale_tensor_proto = graph_utils::GetConstantInitializer(graph_, input_defs[scale_input_index]->Name()); + if (mul_scale_tensor_proto == nullptr || + mul_scale_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return; + } + + const int64_t channels = nchwc_input->channels_; + + Initializer mul_scale{graph_, *mul_scale_tensor_proto, graph_.ModelPath()}; + const auto scale_dims = mul_scale.dims(); + + bool channel_broadcast_shape = false; + if (scale_dims.size() == 3) { + channel_broadcast_shape = (scale_dims[0] == channels && scale_dims[1] == 1 && scale_dims[2] == 1); + } else if (scale_dims.size() == 4) { + channel_broadcast_shape = (scale_dims[0] == 1 && scale_dims[1] == channels && scale_dims[2] == 1 && scale_dims[3] == 1); + } + + if (!channel_broadcast_shape) { + return; + } + + const size_t nchwc_block_size = MlasNchwcGetBlockSize(); + const int64_t nchwc_channels = (channels + static_cast<int64_t>(nchwc_block_size) - 1) & ~static_cast<int64_t>(nchwc_block_size - 1); + + InlinedVector<float> padded_scale(gsl::narrow<size_t>(nchwc_channels), 1.0f); + std::copy_n(mul_scale.data<float>(), channels, padded_scale.data()); + + ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto; + nchwc_conv_W_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + nchwc_conv_W_tensor_proto.set_name(graph_.GenerateNodeArgName("mul_scale")); + utils::SetRawDataInTensorProto(nchwc_conv_W_tensor_proto, padded_scale.data(), + gsl::narrow<size_t>(nchwc_channels) * sizeof(float)); + nchwc_conv_W_tensor_proto.add_dims(nchwc_channels); + nchwc_conv_W_tensor_proto.add_dims(1); + nchwc_conv_W_tensor_proto.add_dims(1); + nchwc_conv_W_tensor_proto.add_dims(1); + + auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithOrtValue(graph_, nchwc_conv_W_tensor_proto); + + std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_mul_nchwc"); + Node& nchwc_node = graph_.AddNode(nchwc_node_name, + "Conv", + nchwc_node_name, + 
std::array{nchwc_input->nchwc_arg_, nchwc_conv_W_arg}, + output_defs, + nullptr, + kMSNchwcDomain); + nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); + nchwc_node.AddAttribute("group", nchwc_channels); + + nchwc_input->remaining_original_uses_--; + + CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_); + removed_nodes_.push_front(node.Index()); +} + void NchwcTransformerImpl::TransformConcat(Node& node) { auto& input_defs = node.MutableInputDefs(); auto& output_defs = node.MutableOutputDefs(); @@ -794,7 +878,12 @@ void NchwcTransformerImpl::TransformActivation(Node& node) { // Check if this is a single use NCHWc convolution that hasn't already // been fused with another activation. auto& nchwc_node = nchwc_input->output_node_; + + const bool can_fuse_activation = (node.OpType() == "Relu") || + (node.OpType() == "Sigmoid") || + (node.OpType() == "Tanh"); if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) && + can_fuse_activation && (nchwc_input->starting_original_uses_ == 1) && (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) { nchwc_node.AddAttribute("activation", node.OpType()); @@ -1159,6 +1248,11 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10, 11, 12, 22}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "AveragePool", {1, 7, 10, 11, 19, 22})) { TransformPool(node); + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { + // Mul can be converted when exactly one input is already NCHWc and the + // other input is a static channel-scale tensor, so it does not need to + // wait for all input edges to be removed. + TransformMul(node); } else if (node.GetInputEdgesCount() == 0 && node.InputDefs().size() != 0) { // The following transforms only run when the input edge count has already // been decremented to zero by earlier transforms. 
This is a hint that the @@ -1168,13 +1262,13 @@ void NchwcTransformerImpl::Transform(Node& node) { if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sum", {6, 8, 13})) { TransformBinary(node, true); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14})) { - TransformBinary(node, false); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11, 13})) { TransformConcat(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13})) { + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6, 13}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9, 14, 15})) { TransformBatchNormalization(node); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 7dfa5c7812f6e..6078660bf0d6e 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -85,12 +85,15 @@ struct NchwcTestHelper { Node& AddNode(const std::string& op_type, const std::vector<NodeArg*>& input_args, - const std::vector<NodeArg*>& output_args) { + const std::vector<NodeArg*>& output_args, + const std::string& domain = kOnnxDomain) { return graph_.AddNode(graph_.GenerateNodeName("node"), op_type, "description", input_args, - output_args); + output_args, + nullptr, + domain); } Node& AddConvNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector<int64_t>& weights_shape, bool no_bias = false) { @@ -168,7 +171,8 @@ struct NchwcTestHelper { void 
NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& build_test_case, const std::function<void(InferenceSessionWrapper& session)>& check_nchwc_graph, - int opset_version = 13) { + int opset_version = 13, + const std::function<void(const Graph&)>& check_pre_optimization_graph = nullptr) { // Ignore the test if NCHWc is not supported by the platform. if (MlasNchwcGetBlockSize() <= 1) { return; } @@ -177,12 +181,17 @@ void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& bu // Build the model for this test. std::unordered_map<std::string, int> domain_to_version; domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; Model model("nchwc", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); NchwcTestHelper helper(model.MainGraph()); build_test_case(helper); ASSERT_STATUS_OK(model.MainGraph().Resolve()); + if (check_pre_optimization_graph) { + check_pre_optimization_graph(model.MainGraph()); + } + // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); @@ -706,6 +715,62 @@ TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { } } +TEST(NchwcOptimizerTests, ConvMulChannelScale) { + const int64_t input_channels = static_cast<int64_t>(MlasNchwcGetBlockSize()) * 2; + + auto test_case = [&](int64_t output_channels, bool use_explicit_batch_dim, bool scale_first) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, input_channels, 25, 21}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* quickgelu_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {output_channels, input_channels, 3, 3}); + helper.AddNode("QuickGelu", {conv_output_arg}, {quickgelu_output_arg}, kMSDomain); + const std::vector<int64_t> scale_shape = use_explicit_batch_dim + ? 
std::vector<int64_t>{1, output_channels, 1, 1} + : std::vector<int64_t>{output_channels, 1, 1}; + auto* scale_arg = helper.MakeInitializer(scale_shape, helper.FillRandomData(scale_shape)); + if (scale_first) { + helper.AddNode("Mul", {scale_arg, quickgelu_output_arg}, {output_arg}); + } else { + helper.AddNode("Mul", {quickgelu_output_arg, scale_arg}, {output_arg}); + } + }; + + auto check_pre_optimization_graph = [&](const Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1); + EXPECT_EQ(op_to_count["Mul"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 0); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 0); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.QuickGelu"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Mul"], 0); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph, 13, check_pre_optimization_graph); + }; + + // Keep Conv input channels aligned so the initial NCHWc transform still + // applies, and vary only the logical output channel count to cover both the + // aligned path and the padded scale path in the Mul rewrite. 
+ for (int64_t output_channels : {input_channels, input_channels + 1}) { + test_case(output_channels, false, false); + test_case(output_channels, false, true); + test_case(output_channels, true, false); + test_case(output_channels, true, true); + } +} + TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) { @@ -1342,7 +1407,7 @@ TEST(NchwcOptimizerTests, UpsampleLinear) { } TEST(NchwcOptimizerTests, Activation) { - auto test_case = [&](const std::string& activation_op_type) { + auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 48, 11, 15}); auto* conv1_output_arg = helper.MakeIntermediate(); @@ -1351,29 +1416,75 @@ TEST(NchwcOptimizerTests, Activation) { auto* output_arg = helper.MakeOutput(); helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); - helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}); + helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain); helper.AddNode("Add", {conv1_output_arg, activation_output_arg}, {mul_output_arg}); helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); }; auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); + const std::string activation_key = domain.empty() ? activation_op_type : domain + "." 
+ activation_op_type; EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); - EXPECT_EQ(op_to_count[activation_op_type], 1); + EXPECT_EQ(op_to_count[activation_key], 1); EXPECT_EQ(op_to_count["Add"], 1); }; NchwcOptimizerTester(build_test_case, check_nchwc_graph); }; - // Verify that the optimizer doesn't add reorders for these activations that - // cannot be fused with a convolution. std::vector<std::string> activation_op_types{"Relu", "Sigmoid", "Tanh"}; - for (auto& activation_op_type : activation_op_types) { - test_case(activation_op_type); - } + // Verify that the optimizer doesn't add reorders for these activations in + // this pattern. Relu/Sigmoid/Tanh are generally fusable with a + // preceding convolution, but not here because the Conv output is consumed + // both by the activation node and directly by the Add node. Gelu/QuickGelu + // are also expected to remain as separate nodes. 
+ test_case("Relu"); + test_case("Sigmoid"); + test_case("Tanh"); + test_case("Gelu", kMSDomain); + test_case("QuickGelu", kMSDomain); +} + +TEST(NchwcOptimizerTests, ActivationSingleConsumerConvGuard) { + auto test_case = [&](const std::string& activation_op_type, const std::string& domain = kOnnxDomain) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv1_output_arg = helper.MakeIntermediate(); + auto* activation_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); + helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}, domain); + helper.AddConvNode(activation_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto& graph = session.GetGraph(); + auto op_to_count = CountOpsInGraph(graph); + const std::string activation_key = domain.empty() ? activation_op_type : domain + "." + activation_op_type; + + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count[activation_key], 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Conv" && node.Domain() == kMSNchwcDomain) { + EXPECT_EQ(node.GetAttributes().count("activation"), 0U) + << activation_op_type << " should not fuse into a single-consumer NCHWc Conv"; + } + } + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + // Gelu/QuickGelu must remain separate even with a single-consumer Conv input, + // because the NCHWc Conv activation fuse guard only allows a fixed subset of + // activations. + test_case("Gelu", kMSDomain); + test_case("QuickGelu", kMSDomain); } TEST(NchwcOptimizerTests, MaxPoolTypeCheck) {