diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc
index a02aa309a0bd2..20eec3a8e2e7a 100644
--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -43,6 +43,21 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
   return data_type == actual_data_type;
 }
 
+bool ConvFusionDataTypeCheck(const Node& conv_node) {
+  // TODO(hasesh): The CPU and CUDA EPs only support the float type for the
+  // Conv+Activation and the Conv+Add+Relu fusions.
+  // Assess the support level of the other compatible EPs and, if they also
+  // only support float, remove the EP check altogether.
+  const std::string_view node_ep = conv_node.GetExecutionProviderType();
+  if (node_ep == kCudaExecutionProvider || node_ep == kCpuExecutionProvider) {
+    if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 class ConvActivation : public NodeSelector {
  public:
   ConvActivation() = default;
@@ -74,12 +89,12 @@ class ConvActivation : public NodeSelector {
       return false;
     };
 
+    if (!ConvFusionDataTypeCheck(node)) {
+      return std::nullopt;
+    }
+
     // check EP type and activation
     if (node_ep == kCudaExecutionProvider) {
-      if (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
-        return std::nullopt;
-      }
-
       if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
         return std::nullopt;
       }
@@ -112,6 +127,10 @@ class ConvAddRelu : public NodeSelector {
       return std::nullopt;
     }
 
+    if (!ConvFusionDataTypeCheck(node)) {
+      return std::nullopt;
+    }
+
     const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
     if (!add_node ||
         !graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) ||
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 1e9b65a9911b7..24cfa5e8122e9 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -714,6 +714,37 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
   ASSERT_TRUE(op_to_count["Relu"] == 0);  // Relu removed from graph
 }
 
+// Currently, the ConvAddRelu fusion is only backed by a float kernel for the
+// CUDA EP. When the corresponding pattern appears with the fp16 data type,
+// the fusion should not be triggered, as there is no kernel to back the
+// fused pattern.
+// TODO(hasesh): The test is limited to the CUDA EP for now, as the level of
+// data type support in the other compatible EPs is yet to be ascertained.
+// TODO(hasesh): If fp16 support is added for this fusion, adjust/remove
+// this test.
+TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) {
+  auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+  for (auto& node : p_model->MainGraph().Nodes()) {
+    node.SetExecutionProviderType(kCudaExecutionProvider);
+  }
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Add"], 1);
+  ASSERT_EQ(op_to_count["Relu"], 1);
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+  op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["Add"], 1);   // Add not removed from graph (fusion not triggered)
+  ASSERT_EQ(op_to_count["Relu"], 1);  // Relu not removed from graph (fusion not triggered)
+}
+
 // Conv->Add->Relu will be left intact since there is Identity depend on Add
 TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
   auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
diff --git a/onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx b/onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx
new file mode 100644
index 0000000000000..131b2bbec69eb
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx differ
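
Note on the test model: conv_add_relu_fp16.onnx is checked in as a binary ONNX file, so its contents are not reviewable in the diff above. From the raw protobuf, the graph is Conv(X, W, B) -> C, Add(C, A) -> S, Relu(S) -> Y, with fp16 tensors throughout. Below is a minimal sketch of how such a model could be generated with the onnx Python helper API; the tensor shapes, node names, and opset version are assumptions (they are not recoverable from the binary diff), and only the op pattern and the fp16 element type matter to the test.

    # generate_conv_add_relu_fp16.py -- hypothetical generator script
    import onnx
    from onnx import TensorProto, helper

    # Conv -> Add -> Relu pattern, all tensors fp16
    conv = helper.make_node("Conv", ["X", "W", "B"], ["C"])
    add = helper.make_node("Add", ["C", "A"], ["S"])
    relu = helper.make_node("Relu", ["S"], ["Y"])

    graph = helper.make_graph(
        [conv, add, relu],
        "graph",
        inputs=[
            # Shapes are illustrative: 3x3 conv over a 5x5 input, no padding
            helper.make_tensor_value_info("X", TensorProto.FLOAT16, [1, 1, 5, 5]),
            helper.make_tensor_value_info("W", TensorProto.FLOAT16, [1, 1, 3, 3]),
            helper.make_tensor_value_info("B", TensorProto.FLOAT16, [1]),
            helper.make_tensor_value_info("A", TensorProto.FLOAT16, [1, 1, 3, 3]),
        ],
        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [1, 1, 3, 3])],
    )

    # Opset 14 is an assumption; any opset covered by the selectors works
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])
    onnx.checker.check_model(model)
    onnx.save(model, "conv_add_relu_fp16.onnx")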
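For manual verification outside the gtest suite, one way to see whether the fusion fires is to dump the optimized graph through the onnxruntime Python API. This is a sketch assuming a CUDA-enabled onnxruntime build; "optimized.onnx" is a placeholder path. ConvActivationFusion runs as a Level 2 transformer, which the "extended" optimization level enables.

    import onnx
    import onnxruntime as ort

    so = ort.SessionOptions()
    # Extended level runs the Level 2 graph transformers
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    # Serialize the post-optimization graph for inspection
    so.optimized_model_filepath = "optimized.onnx"
    ort.InferenceSession("conv_add_relu_fp16.onnx", so,
                         providers=["CUDAExecutionProvider"])

    # For the fp16 model, separate Conv/Add/Relu nodes should survive;
    # for the float variant they would collapse into a single fused node.
    print([n.op_type for n in onnx.load("optimized.onnx").graph.node])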