diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 2aae3383a1072..950355742193c 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -5099,33 +5099,29 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionTest) {
     builder.AddNode("Identity", {concat_out}, {output});
   };
 
-  auto pre_graph_checker = [get_op_count](Graph& graph) {
-    const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
-    TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
-    TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [get_op_count](Graph& graph) {
+  auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
+    const Graph& graph = session.GetGraph();
     const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
-    TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
-    TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
+    ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
+    ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
+    ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
 
     for (const auto& node : graph.Nodes()) {
      if (node.OpType() == "SpaceToDepth") {
        const auto* blocksize_attr = graph_utils::GetNodeAttribute(node, "blocksize");
-        TEST_RETURN_IF_NOT(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
+        ASSERT_TRUE(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
      }
    }
-
-    return Status::OK();
   };
 
-  auto transformer = std::make_unique();
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
-                                        1, pre_graph_checker, post_graph_checker));
+  TransformerTester(build_test_case,
+                    check_transformed_graph,
+                    TransformerLevel::Default,
+                    TransformerLevel::Level1,
+                    13,
+                    0.0,
+                    0.0,
+                    std::make_unique());
 }
 
 TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNodesTest) {
@@ -5178,26 +5174,22 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNode
     builder.AddNode("Identity", {concat_out}, {output});
   };
 
-  auto pre_graph_checker = [get_op_count](Graph& graph) {
-    const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
-    TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
-    TEST_RETURN_IF_NOT(op_to_count.at("Constant") == 7);
-    TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [get_op_count](Graph& graph) {
+  auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
+    const Graph& graph = session.GetGraph();
     const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
-    TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
-    TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
-    return Status::OK();
+    ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
+    ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
+    ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
   };
 
-  auto transformer = std::make_unique();
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
-                                        1, pre_graph_checker, post_graph_checker));
+  TransformerTester(build_test_case,
+                    check_transformed_graph,
+                    TransformerLevel::Default,
+                    TransformerLevel::Level1,
+                    13,
+                    0.0,
+                    0.0,
+                    std::make_unique());
 }
 
 TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBlockOrderTest) {
@@ -5234,27 +5226,23 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBloc
     builder.AddNode("Identity", {concat_out}, {output});
   };
 
-  auto pre_graph_checker = [get_op_count](Graph& graph) {
+  auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
+    const Graph& graph = session.GetGraph();
     const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
-    TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
-    TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
-    TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0);
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [get_op_count](Graph& graph) {
-    const auto op_to_count = CountOpsInGraph(graph);
-    TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
-    TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
-    TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
-    TEST_RETURN_IF_NOT(get_op_count(op_to_count, "Gather") == 1);
-    return Status::OK();
+    ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
+    ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
+    ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
+    ASSERT_EQ(get_op_count(op_to_count, "Gather"), 1);
   };
 
-  auto transformer = std::make_unique();
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
-                                        1, pre_graph_checker, post_graph_checker));
+  TransformerTester(build_test_case,
+                    check_transformed_graph,
+                    TransformerLevel::Default,
+                    TransformerLevel::Level1,
+                    13,
+                    0.0,
+                    0.0,
+                    std::make_unique());
 }
 
 TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForDynamicChannelPermutedBlockOrderTest) {
@@ -6107,7 +6095,7 @@ static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder,
   builder.AddNode("Add", std::vector{input_skip, layer_scale_out}, std::vector{output});
 }
 
-static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
+static Status CheckMobileClipAttentionFusedGraph(const Graph& graph) {
   auto op_to_count = CountOpsInGraph(graph);
   TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1);
   TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1);
@@ -6116,14 +6104,13 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
   TEST_RETURN_IF_NOT(op_to_count["Split"] == 1);
   TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1);
   TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2);
-  TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4);
   TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
   TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
 
   int mha_nodes = 0;
   int gemm_nodes = 0;
   int split_nodes = 0;
-  for (Node& node : graph.Nodes()) {
+  for (const Node& node : graph.Nodes()) {
     if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) {
       ++mha_nodes;
       TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16);
@@ -6170,16 +6157,24 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
   return Status::OK();
 }
 
-static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) {
+static Status CheckMobileClipAttentionFusedGraphOnProvider(const Graph& graph, const char* provider) {
   ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph));
 
-  for (Node& node : graph.Nodes()) {
+  for (const Node& node : graph.Nodes()) {
     TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider);
   }
 
   return Status::OK();
 }
 
+static void CheckMobileClipAttentionFusedSession(InferenceSessionWrapper& session) {
+  ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraph(session.GetGraph()));
+}
+
+static void CheckMobileClipAttentionFusedCudaSession(InferenceSessionWrapper& session) {
+  ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraphOnProvider(session.GetGraph(), kCudaExecutionProvider));
+}
+
 static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) {
   auto op_to_count = CountOpsInGraph(graph);
   TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0);
@@ -6230,8 +6225,14 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) {
     BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
   };
 
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(),
-                                        TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
+  TransformerTester(build_test_case,
+                    CheckMobileClipAttentionFusedSession,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    14,
+                    1e-3,
+                    0.0,
+                    std::make_unique());
 }
 
 TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) {
@@ -6239,52 +6240,60 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest)
     BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
   };
 
-  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(),
-                                        TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
+  TransformerTester(build_test_case,
+                    CheckMobileClipAttentionFusedSession,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    14,
+                    1e-3,
+                    0.0,
+                    std::make_unique());
 }
 
 TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (!cuda_ep) {
+    GTEST_SKIP() << "CUDA execution provider is not available";
+  }
+
   auto build_test_case = [](ModelTestBuilder& builder) {
     BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
   };
 
-  auto pre_graph_checker = [](Graph& graph) {
-    for (Node& node : graph.Nodes()) {
-      node.SetExecutionProviderType(kCudaExecutionProvider);
-    }
-
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [](Graph& graph) {
-    return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
-  };
-
-  ASSERT_STATUS_OK(TestGraphTransformer(
-      build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}),
-      TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
+  TransformerTester(build_test_case,
+                    CheckMobileClipAttentionFusedCudaSession,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    14,
+                    1e-3,
+                    0.0,
+                    std::make_unique(InlinedHashSet{kCudaExecutionProvider}),
+                    {},
+                    {},
+                    std::move(cuda_ep));
 }
 
 TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (!cuda_ep) {
+    GTEST_SKIP() << "CUDA execution provider is not available";
+  }
+
   auto build_test_case = [](ModelTestBuilder& builder) {
     BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
   };
 
-  auto pre_graph_checker = [](Graph& graph) {
-    for (Node& node : graph.Nodes()) {
-      node.SetExecutionProviderType(kCudaExecutionProvider);
-    }
-
-    return Status::OK();
-  };
-
-  auto post_graph_checker = [](Graph& graph) {
-    return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
-  };
-
-  ASSERT_STATUS_OK(TestGraphTransformer(
-      build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}),
-      TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
+  TransformerTester(build_test_case,
+                    CheckMobileClipAttentionFusedCudaSession,
+                    TransformerLevel::Level1,
+                    TransformerLevel::Level2,
+                    14,
+                    1e-3,
+                    0.0,
+                    std::make_unique(InlinedHashSet{kCudaExecutionProvider}),
+                    {},
+                    {},
+                    std::move(cuda_ep));
 }
 
 TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) {