Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 101 additions & 92 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5099,33 +5099,29 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionTest) {
builder.AddNode("Identity", {concat_out}, {output});
};

auto pre_graph_checker = [get_op_count](Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
return Status::OK();
};

auto post_graph_checker = [get_op_count](Graph& graph) {
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
const Graph& graph = session.GetGraph();
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);

for (const auto& node : graph.Nodes()) {
if (node.OpType() == "SpaceToDepth") {
const auto* blocksize_attr = graph_utils::GetNodeAttribute(node, "blocksize");
TEST_RETURN_IF_NOT(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
ASSERT_TRUE(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
}
}

return Status::OK();
};

auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
TransformerTester(build_test_case,
check_transformed_graph,
TransformerLevel::Default,
TransformerLevel::Level1,
13,
0.0,
0.0,
std::make_unique<SliceConcatToSpaceToDepthFusion>());
}

TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNodesTest) {
Expand Down Expand Up @@ -5178,26 +5174,22 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNode
builder.AddNode("Identity", {concat_out}, {output});
};

auto pre_graph_checker = [get_op_count](Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
TEST_RETURN_IF_NOT(op_to_count.at("Constant") == 7);
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
return Status::OK();
};

auto post_graph_checker = [get_op_count](Graph& graph) {
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
const Graph& graph = session.GetGraph();
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
return Status::OK();
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
};

auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
TransformerTester(build_test_case,
check_transformed_graph,
TransformerLevel::Default,
TransformerLevel::Level1,
13,
0.0,
0.0,
std::make_unique<SliceConcatToSpaceToDepthFusion>());
}

TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBlockOrderTest) {
Expand Down Expand Up @@ -5234,27 +5226,23 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBloc
builder.AddNode("Identity", {concat_out}, {output});
};

auto pre_graph_checker = [get_op_count](Graph& graph) {
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
const Graph& graph = session.GetGraph();
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0);
return Status::OK();
};

auto post_graph_checker = [get_op_count](Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "Gather") == 1);
return Status::OK();
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
ASSERT_EQ(get_op_count(op_to_count, "Gather"), 1);
};

auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
TransformerTester(build_test_case,
check_transformed_graph,
TransformerLevel::Default,
TransformerLevel::Level1,
13,
0.0,
0.0,
std::make_unique<SliceConcatToSpaceToDepthFusion>());
}

TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForDynamicChannelPermutedBlockOrderTest) {
Expand Down Expand Up @@ -6107,7 +6095,7 @@ static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder,
builder.AddNode("Add", std::vector<NodeArg*>{input_skip, layer_scale_out}, std::vector<NodeArg*>{output});
}

static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
static Status CheckMobileClipAttentionFusedGraph(const Graph& graph) {
auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1);
TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1);
Expand All @@ -6116,14 +6104,13 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
TEST_RETURN_IF_NOT(op_to_count["Split"] == 1);
TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1);
TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2);
TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4);
TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);

int mha_nodes = 0;
int gemm_nodes = 0;
int split_nodes = 0;
for (Node& node : graph.Nodes()) {
for (const Node& node : graph.Nodes()) {
if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) {
++mha_nodes;
TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16);
Expand Down Expand Up @@ -6170,16 +6157,24 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
return Status::OK();
}

static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) {
static Status CheckMobileClipAttentionFusedGraphOnProvider(const Graph& graph, const char* provider) {
ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph));

for (Node& node : graph.Nodes()) {
for (const Node& node : graph.Nodes()) {
TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider);
}

return Status::OK();
}

static void CheckMobileClipAttentionFusedSession(InferenceSessionWrapper& session) {
ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraph(session.GetGraph()));
}

static void CheckMobileClipAttentionFusedCudaSession(InferenceSessionWrapper& session) {
ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraphOnProvider(session.GetGraph(), kCudaExecutionProvider));
}

static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) {
auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0);
Expand Down Expand Up @@ -6230,61 +6225,75 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) {
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(),
TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
TransformerTester(build_test_case,
CheckMobileClipAttentionFusedSession,
TransformerLevel::Level1,
TransformerLevel::Level2,
14,
1e-3,
0.0,
std::make_unique<AttentionFusion>());
}

TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(),
TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
TransformerTester(build_test_case,
CheckMobileClipAttentionFusedSession,
TransformerLevel::Level1,
TransformerLevel::Level2,
14,
1e-3,
0.0,
std::make_unique<AttentionFusion>());
}

TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) {
auto cuda_ep = DefaultCudaExecutionProvider();
if (!cuda_ep) {
GTEST_SKIP() << "CUDA execution provider is not available";
}

auto build_test_case = [](ModelTestBuilder& builder) {
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
};

auto pre_graph_checker = [](Graph& graph) {
for (Node& node : graph.Nodes()) {
node.SetExecutionProviderType(kCudaExecutionProvider);
}

return Status::OK();
};

auto post_graph_checker = [](Graph& graph) {
return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
};

ASSERT_STATUS_OK(TestGraphTransformer(
build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
TransformerTester(build_test_case,
CheckMobileClipAttentionFusedCudaSession,
TransformerLevel::Level1,
TransformerLevel::Level2,
14,
1e-3,
0.0,
std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
{},
{},
std::move(cuda_ep));
}

TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) {
auto cuda_ep = DefaultCudaExecutionProvider();
if (!cuda_ep) {
GTEST_SKIP() << "CUDA execution provider is not available";
}

auto build_test_case = [](ModelTestBuilder& builder) {
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
};

auto pre_graph_checker = [](Graph& graph) {
for (Node& node : graph.Nodes()) {
node.SetExecutionProviderType(kCudaExecutionProvider);
}

return Status::OK();
};

auto post_graph_checker = [](Graph& graph) {
return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
};

ASSERT_STATUS_OK(TestGraphTransformer(
build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
TransformerTester(build_test_case,
CheckMobileClipAttentionFusedCudaSession,
TransformerLevel::Level1,
TransformerLevel::Level2,
14,
1e-3,
0.0,
std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
{},
{},
std::move(cuda_ep));
}

TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) {
Expand Down
Loading