diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 5ab196fdf4980..c4aa2c522b6d8 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -45,6 +45,8 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/selector_action_transformer.cc" "${ONNXRUNTIME_ROOT}/core/optimizer/selectors_actions/selector_action_transformer.h" + "${ONNXRUNTIME_ROOT}/core/optimizer/slice_concat_to_space_to_depth_fusion.cc" + "${ONNXRUNTIME_ROOT}/core/optimizer/slice_concat_to_space_to_depth_fusion.h" # files required for layout transformation "${ONNXRUNTIME_ROOT}/core/optimizer/layout_transformation/layout_transformation.h" "${ONNXRUNTIME_ROOT}/core/optimizer/layout_transformation/layout_transformation.cc" @@ -136,4 +138,4 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() \ No newline at end of file +endif() diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 640848d47fe93..cee8f650d021b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -81,6 +81,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" +#include "core/optimizer/slice_concat_to_space_to_depth_fusion.h" #include "core/optimizer/transpose_optimizer.h" #include "core/optimizer/unsqueeze_elimination.h" #ifdef ENABLE_TRAINING @@ -260,7 +261,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique( session_options.free_dimension_overrides)); - + transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique()); diff --git a/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc new file mode 100644 index 0000000000000..f72f74e3b4a5c --- /dev/null +++ b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.cc @@ -0,0 +1,596 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/slice_concat_to_space_to_depth_fusion.h" + +#include +#include +#include +#include + +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace { + +using IntValues = InlinedVector; + +struct SlicePhase { + int64_t h_offset; + int64_t w_offset; +}; + +struct NormalizedSliceParams { + std::array starts; + std::array ends; + std::array steps; +}; + +constexpr int64_t kRank = 4; +constexpr int64_t kChannelAxis = 1; +constexpr int64_t kHeightAxis = 2; +constexpr int64_t kWidthAxis = 3; +// This fusion currently only recognizes the common blocksize=2 pattern used by +// YOLO-style focus layers: 4 Slice nodes with offsets in {0,1}x{0,1}, step=2, +// followed by channel-axis Concat. The same idea generalizes to arbitrary +// blocksize b by matching b^2 Slice nodes with offsets in {0..b-1}x{0..b-1}, +// step=b, and extending the phase-permutation/channel-reorder logic +// accordingly. For now we intentionally keep the implementation limited to the +// blocksize=2 case. +constexpr int64_t kBlockSize = 2; + +int64_t NormalizeAxis(int64_t axis, int64_t rank) { + return axis < 0 ? axis + rank : axis; +} + +bool GetInitializerIntValues(const Graph& graph, const TensorProto* initializer, IntValues& values) { + if (initializer == nullptr || initializer->dims_size() != 1) { + return false; + } + + Initializer init(graph, *initializer, graph.ModelPath()); + if (initializer->data_type() == TensorProto::INT32) { + const int32_t* init_data = init.data(); + values.assign(init_data, init_data + init.size()); + return true; + } + + if (initializer->data_type() == TensorProto::INT64) { + const int64_t* init_data = init.data(); + values.assign(init_data, init_data + init.size()); + return true; + } + + return false; +} + +const Node* GetInputProducerNode(const Node& node, size_t input_index) { + const int input_arg_index = onnxruntime::narrow(input_index); + for (auto edge_it = node.InputEdgesBegin(), edge_end = node.InputEdgesEnd(); edge_it != edge_end; ++edge_it) { + if (edge_it->GetDstArgIndex() == input_arg_index) { + return &edge_it->GetNode(); + } + } + + return nullptr; +} + +Node* GetMutableInputProducerNode(Graph& graph, Node& node, size_t input_index) { + const Node* producer = GetInputProducerNode(node, input_index); + return producer == nullptr ? nullptr : graph.GetNode(producer->Index()); +} + +bool HasSingleOutputEdgeToNode(const Node& node, const Node& consumer) { + if (node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto edge_it = node.OutputEdgesBegin(); + return edge_it != node.OutputEdgesEnd() && &edge_it->GetNode() == &consumer; +} + +bool GetConstantInputIntValues(const Graph& graph, const Node& node, size_t input_index, IntValues& values) { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = input_defs.size() > input_index ? input_defs[input_index] : nullptr; + if (input == nullptr || !input->Exists()) { + return false; + } + + if (const TensorProto* initializer = graph_utils::GetConstantInitializer(graph, input->Name()); initializer != nullptr) { + return GetInitializerIntValues(graph, initializer, values); + } + + const Node* producer = GetInputProducerNode(node, input_index); + if (producer == nullptr || producer->OpType() != "Constant" || producer->Domain() != kOnnxDomain) { + return false; + } + + const auto& attributes = producer->GetAttributes(); + const auto attr_it = attributes.find("value"); + if (attr_it == attributes.end() || attr_it->second.type() != AttributeProto_AttributeType_TENSOR) { + return false; + } + + return GetInitializerIntValues(graph, &attr_it->second.t(), values); +} + +bool GetSliceInfo(const Graph& graph, + const Node& node, + const logging::Logger& logger, + IntValues& starts, + IntValues& ends, + IntValues& axes, + IntValues& steps) { + ORT_UNUSED_PARAMETER(logger); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {10, 11, 13}, kOnnxDomain) || + graph.NodeProducesGraphOutput(node)) { + return false; + } + + auto get_input_if_exists = [&node](size_t input_idx) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = input_defs.size() > input_idx ? input_defs[input_idx] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + if (!GetConstantInputIntValues(graph, node, 1, starts) || + !GetConstantInputIntValues(graph, node, 2, ends) || + starts.empty() || starts.size() != ends.size()) { + return false; + } + + axes.clear(); + steps.clear(); + + if (const NodeArg* axes_input = get_input_if_exists(3); axes_input != nullptr) { + if (!GetConstantInputIntValues(graph, node, 3, axes) || axes.size() != starts.size()) { + return false; + } + } else { + axes.resize(starts.size()); + std::iota(axes.begin(), axes.end(), int64_t{0}); + } + + if (const NodeArg* steps_input = get_input_if_exists(4); steps_input != nullptr) { + if (!GetConstantInputIntValues(graph, node, 4, steps) || steps.size() != starts.size()) { + return false; + } + } else { + steps.assign(starts.size(), int64_t{1}); + } + + return true; +} + +bool IsSupportedSpaceToDepthInputType(const NodeArg& input) { + const auto* type_proto = input.TypeAsProto(); + if (type_proto == nullptr || !type_proto->has_tensor_type()) { + return false; + } + + const auto& tensor_type = type_proto->tensor_type(); + if (!tensor_type.has_shape() || tensor_type.shape().dim_size() != kRank) { + return false; + } + + const int32_t elem_type = tensor_type.elem_type(); + + // TODO(hasesh): Consider supporting float16 too ? + if (elem_type != TensorProto::FLOAT && elem_type != TensorProto::DOUBLE) { + return false; + } + + return true; +} + +bool TryGetStaticChannelCount(const NodeArg& input, int64_t& channel_count) { + const auto* type_proto = input.TypeAsProto(); + if (type_proto == nullptr || !type_proto->has_tensor_type()) { + return false; + } + + const auto& tensor_type = type_proto->tensor_type(); + if (!tensor_type.has_shape() || tensor_type.shape().dim_size() != kRank) { + return false; + } + + const auto& channel_dim = tensor_type.shape().dim(onnxruntime::narrow(kChannelAxis)); + if (!utils::HasDimValue(channel_dim) || channel_dim.dim_value() <= 0) { + return false; + } + + channel_count = channel_dim.dim_value(); + return true; +} + +bool TryGetStaticInputDim(const NodeArg& input, int64_t axis, int64_t& dim_value) { + const auto* type_proto = input.TypeAsProto(); + if (type_proto == nullptr || !type_proto->has_tensor_type()) { + return false; + } + + const auto& tensor_type = type_proto->tensor_type(); + if (!tensor_type.has_shape() || tensor_type.shape().dim_size() != kRank || axis < 0 || axis >= kRank) { + return false; + } + + const auto& dim = tensor_type.shape().dim(onnxruntime::narrow(axis)); + if (!utils::HasDimValue(dim) || dim.dim_value() <= 0) { + return false; + } + + dim_value = dim.dim_value(); + return true; +} + +bool IsFullExtentEnd(const NodeArg& input, int64_t axis, int64_t end) { + if (end == std::numeric_limits::max()) { + return true; + } + + if (end < 0) { + return false; + } + + int64_t dim_value = 0; + return TryGetStaticInputDim(input, axis, dim_value) && end >= dim_value; +} + +TypeProto MakeSpaceToDepthOutputTypeProto(const NodeArg& input) { + TypeProto output_type; + + const auto* input_type_proto = input.TypeAsProto(); + if (input_type_proto == nullptr) { + return output_type; + } + + output_type = *input_type_proto; + + if (!output_type.has_tensor_type()) { + return output_type; + } + + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + if (output_shape == nullptr || output_shape->dim_size() != kRank) { + return output_type; + } + + auto* channel_dim = output_shape->mutable_dim(onnxruntime::narrow(kChannelAxis)); + if (utils::HasDimValue(*channel_dim) && channel_dim->dim_value() > 0) { + channel_dim->set_dim_value(channel_dim->dim_value() * kBlockSize * kBlockSize); + } else { + channel_dim->clear_dim_value(); + channel_dim->clear_dim_param(); + } + + for (const int64_t axis : {kHeightAxis, kWidthAxis}) { + auto* dim = output_shape->mutable_dim(onnxruntime::narrow(axis)); + if (utils::HasDimValue(*dim) && dim->dim_value() > 0 && dim->dim_value() % kBlockSize == 0) { + dim->set_dim_value(dim->dim_value() / kBlockSize); + } else { + dim->clear_dim_value(); + dim->clear_dim_param(); + } + } + + return output_type; +} + +bool TryMatchSlicePhase(const Graph& graph, + const Node& slice, + const NodeArg& common_input, + const logging::Logger& logger, + NormalizedSliceParams& params, + SlicePhase& phase) { + if (slice.InputDefs().empty() || slice.InputDefs()[0] != &common_input) { + return false; + } + + IntValues starts; + IntValues ends; + IntValues axes; + IntValues steps; + if (!GetSliceInfo(graph, slice, logger, starts, ends, axes, steps)) { + return false; + } + + params.starts = {0, 0, 0, 0}; + params.ends = { + std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max()}; + params.steps = {1, 1, 1, 1}; + std::array axis_seen{false, false, false, false}; + + for (size_t i = 0; i < starts.size(); ++i) { + const int64_t axis = NormalizeAxis(axes[i], kRank); + if (axis < 0 || axis >= kRank) { + return false; + } + + if (axis_seen[onnxruntime::narrow(axis)]) { + return false; + } + + axis_seen[onnxruntime::narrow(axis)] = true; + + params.starts[onnxruntime::narrow(axis)] = starts[i]; + params.ends[onnxruntime::narrow(axis)] = ends[i]; + params.steps[onnxruntime::narrow(axis)] = steps[i]; + } + + for (size_t axis = 0; axis < params.starts.size(); ++axis) { + if (params.starts[axis] < 0 || params.steps[axis] <= 0) { + return false; + } + } + + if (params.starts[0] != 0 || params.starts[1] != 0 || + params.steps[0] != 1 || params.steps[1] != 1 || + params.steps[kHeightAxis] != kBlockSize || params.steps[kWidthAxis] != kBlockSize) { + return false; + } + + if (!IsFullExtentEnd(common_input, 0, params.ends[0]) || + !IsFullExtentEnd(common_input, 1, params.ends[1]) || + !IsFullExtentEnd(common_input, kHeightAxis, params.ends[kHeightAxis]) || + !IsFullExtentEnd(common_input, kWidthAxis, params.ends[kWidthAxis])) { + return false; + } + + const int64_t h_offset = params.starts[kHeightAxis]; + const int64_t w_offset = params.starts[kWidthAxis]; + if ((h_offset != 0 && h_offset != 1) || (w_offset != 0 && w_offset != 1)) { + return false; + } + + phase = {h_offset, w_offset}; + return true; +} + +bool IsSingleConsumerOfConcat(const Node& slice, const Node& concat) { + return HasSingleOutputEdgeToNode(slice, concat); +} + +bool TryGetPhasePermutation(const std::array& actual_phases, + std::array& permutation) { + static constexpr std::array kCanonicalPhases{{{0, 0}, {0, 1}, {1, 0}, {1, 1}}}; + std::array used{false, false, false, false}; + + for (size_t i = 0; i < actual_phases.size(); ++i) { + bool matched = false; + for (size_t j = 0; j < kCanonicalPhases.size(); ++j) { + if (!used[j] && actual_phases[i].h_offset == kCanonicalPhases[j].h_offset && + actual_phases[i].w_offset == kCanonicalPhases[j].w_offset) { + permutation[i] = static_cast(j); + used[j] = true; + matched = true; + break; + } + } + + if (!matched) { + return false; + } + } + + return true; +} + +NodeArg* CreateInt64Initializer(Graph& graph, + const std::vector& values, + const std::string& name) { + ONNX_NAMESPACE::TensorProto initializer; + initializer.set_name(name); + initializer.add_dims(onnxruntime::narrow(values.size())); + initializer.set_data_type(TensorProto::INT64); + utils::SetRawDataInTensorProto(initializer, + reinterpret_cast(values.data()), + values.size() * sizeof(int64_t)); + return &graph_utils::AddInitializerWithOrtValue(graph, initializer); +} + +bool FuseSliceConcatToSpaceToDepth(Node& concat, Graph& graph, const logging::Logger& logger) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {4, 11, 13}, kOnnxDomain) || + concat.InputDefs().size() != 4) { + return false; + } + + const auto* axis_attr = graph_utils::GetNodeAttribute(concat, "axis"); + if (axis_attr == nullptr || !utils::HasInt(*axis_attr)) { + return false; + } + + const int64_t concat_axis = NormalizeAxis(axis_attr->i(), kRank); + if (concat_axis != kChannelAxis) { + return false; + } + + Node* slice_nodes[4]{}; + const NodeArg* common_input = nullptr; + const auto& provider_type = concat.GetExecutionProviderType(); + NormalizedSliceParams reference_params{}; + std::array actual_phases{}; + + for (size_t i = 0; i < concat.InputDefs().size(); ++i) { + const NodeArg* concat_input = concat.InputDefs()[i]; + if (concat_input == nullptr || !concat_input->Exists()) { + return false; + } + + Node* slice = GetMutableInputProducerNode(graph, concat, i); + if (slice == nullptr || slice == &concat || slice->GetExecutionProviderType() != provider_type || + !IsSingleConsumerOfConcat(*slice, concat)) { + return false; + } + + if (i == 0) { + common_input = slice->InputDefs()[0]; + if (common_input == nullptr || !IsSupportedSpaceToDepthInputType(*common_input)) { + return false; + } + } + + ORT_ENFORCE(common_input != nullptr); + + NormalizedSliceParams current_params{}; + SlicePhase phase{}; + if (!TryMatchSlicePhase(graph, *slice, *common_input, logger, current_params, phase)) { + return false; + } + + actual_phases[i] = phase; + + if (i == 0) { + reference_params = current_params; + } else if (current_params.ends != reference_params.ends || + current_params.steps != reference_params.steps || + current_params.starts[0] != reference_params.starts[0] || + current_params.starts[1] != reference_params.starts[1]) { + return false; + } + + if (graph.NodeProducesGraphOutput(*slice)) { + return false; + } + + slice_nodes[i] = slice; + } + + std::array phase_permutation{}; + if (!TryGetPhasePermutation(actual_phases, phase_permutation)) { + return false; + } + + const bool is_canonical_order = phase_permutation == std::array{0, 1, 2, 3}; + int64_t channel_count = 0; + if (!is_canonical_order && !TryGetStaticChannelCount(*common_input, channel_count)) { + return false; + } + + InlinedVector space_to_depth_outputs; + if (is_canonical_order) { + space_to_depth_outputs = {}; + } else { + auto space_to_depth_output_type = MakeSpaceToDepthOutputTypeProto(*common_input); + space_to_depth_outputs.push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("space_to_depth_out"), &space_to_depth_output_type)); + } + + NodeArg* space_to_depth_input = graph.GetNodeArg(common_input->Name()); + + Node& space_to_depth = graph.AddNode(graph.GenerateNodeName("SpaceToDepth"), + "SpaceToDepth", + is_canonical_order ? "Fused Slice*4 + Concat into SpaceToDepth" + : "Fused Slice*4 + Concat into SpaceToDepth + channel permutation", + {space_to_depth_input}, + space_to_depth_outputs, + nullptr, + kOnnxDomain); + space_to_depth.AddAttribute("blocksize", kBlockSize); + space_to_depth.SetExecutionProviderType(provider_type); + + Node* replacement_end = &space_to_depth; + if (!is_canonical_order) { + InlinedVector gather_indices; + gather_indices.reserve(onnxruntime::narrow(channel_count * kBlockSize * kBlockSize)); + for (const int64_t source_block_index : phase_permutation) { + for (int64_t c = 0; c < channel_count; ++c) { + gather_indices.push_back(source_block_index * channel_count + c); + } + } + + NodeArg* gather_indices_arg = CreateInt64Initializer( + graph, + std::vector(gather_indices.begin(), gather_indices.end()), + graph.GenerateNodeArgName("space_to_depth_gather_indices")); + + Node& gather = graph.AddNode(graph.GenerateNodeName("Gather"), + "Gather", + "Reorder SpaceToDepth channels to preserve Slice+Concat block order", + {space_to_depth.MutableOutputDefs()[0], gather_indices_arg}, + {}, + nullptr, + kOnnxDomain); + gather.AddAttribute("axis", static_cast(kChannelAxis)); + gather.SetExecutionProviderType(provider_type); + graph.AddEdge(space_to_depth.Index(), gather.Index(), 0, 0); + replacement_end = &gather; + } + + // Explicitly transfer the shared data-input edge from the first Slice to + // SpaceToDepth. This avoids graph_utils::MoveAllNodeInputEdges(), which is + // not defined in extended minimal builds. + { + const auto data_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*slice_nodes[0], 0); + if (!data_input_edges.empty()) { + ORT_ENFORCE(data_input_edges.size() == 1, "Expected a single data input edge for Slice node."); + const auto& data_input_edge = data_input_edges[0]; + graph.AddEdge(data_input_edge.src_node, space_to_depth.Index(), data_input_edge.src_arg_index, 0); + } + } + + auto concat_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(concat); + replacement_end->MutableOutputDefs() = concat.MutableOutputDefs(); + + for (const auto& edge : concat_output_edges) { + graph.AddEdge(replacement_end->Index(), edge.dst_node, 0, edge.dst_arg_index); + } + + for (Node* node : {slice_nodes[0], slice_nodes[1], slice_nodes[2], slice_nodes[3], &concat}) { + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + } + + LOGS(logger, INFO) << "Fused Slice+Concat downsample pattern into " + << (is_canonical_order ? "SpaceToDepth" : "SpaceToDepth + Gather") + << " node sequence starting at: " << space_to_depth.Name(); + return true; +} + +} // namespace + +Status SliceConcatToSpaceToDepthFusion::ApplyImpl(Graph& graph, + bool& modified, + int graph_level, + const logging::Logger& logger) const { + bool local_modified = false; + + do { + local_modified = false; + + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) { + continue; + } + + Node& node = *p_node; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + continue; + } + + if (FuseSliceConcatToSpaceToDepth(node, graph, logger)) { + modified = true; + local_modified = true; + break; + } + } + } while (local_modified); + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.h b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.h new file mode 100644 index 0000000000000..923bc938b08bf --- /dev/null +++ b/onnxruntime/core/optimizer/slice_concat_to_space_to_depth_fusion.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class SliceConcatToSpaceToDepthFusion : public GraphTransformer { + public: + SliceConcatToSpaceToDepthFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("SliceConcatToSpaceToDepthFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index d77cad48c618a..18933e45b8922 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -73,6 +73,7 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" +#include "core/optimizer/slice_concat_to_space_to_depth_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/utils.h" @@ -4805,11 +4806,8 @@ TEST_F(GraphTransformationTests, ConcatSliceEliminationTest) { // Verifies that ConcatSliceElimination correctly defaults axes to {0} and steps to {1} // when Slice nodes omit the optional axes/steps inputs (opset >= 10). -// Before the fix, GetSliceInfo returned empty axes/steps vectors leading to undefined behavior. -// After the fix, the defaults allow the fusion to succeed deterministically. TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetGte10_MissingAxesAndSteps) { auto build_test_case = [&](ModelTestBuilder& builder) { - // Three 1-D constant initializer inputs to Concat, each of size 2. auto* init0 = builder.MakeInitializer({2}, {1.0f, 2.0f}); auto* init1 = builder.MakeInitializer({2}, {3.0f, 4.0f}); auto* init2 = builder.MakeInitializer({2}, {5.0f, 6.0f}); @@ -4818,8 +4816,6 @@ TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetGte10_MissingAxesAn auto& concat_node = builder.AddNode("Concat", {init0, init1, init2}, {concat_out}); concat_node.AddAttribute("axis", int64_t{0}); - // Three Slice nodes with only starts and ends - axes and steps intentionally omitted. - // Per ONNX spec, missing axes defaults to [0,...,ndim-1] and missing steps defaults to [1,...,1]. auto* starts0 = builder.MakeInitializer({1}, {int64_t{0}}); auto* ends0 = builder.MakeInitializer({1}, {int64_t{2}}); auto* slice0_out = builder.MakeIntermediate(); @@ -4858,8 +4854,6 @@ TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetGte10_MissingAxesAn auto post_graph_checker = [&](Graph& graph) { auto op_to_count = CountOpsInGraph(graph); - // With the fix, axes defaults to {0} and steps defaults to {1}, so the - // fusion pattern matches and Concat + all Slices are eliminated. TEST_RETURN_IF_NOT(op_to_count["Concat"] == 0); TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); TEST_RETURN_IF_NOT(op_to_count["Add"] == 3); @@ -4872,7 +4866,6 @@ TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetGte10_MissingAxesAn } // Same test for opset v1, where axes is an optional attribute on the Slice node. -// When the "axes" attribute is absent, GetSliceInfo must default axes to {0}. TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetV1_MissingAxesAttribute) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* init0 = builder.MakeInitializer({2}, {1.0f, 2.0f}); @@ -4883,12 +4876,10 @@ TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetV1_MissingAxesAttri auto& concat_node = builder.AddNode("Concat", {init0, init1, init2}, {concat_out}); concat_node.AddAttribute("axis", int64_t{0}); - // Opset v1 Slice: starts/ends are attributes, axes attribute intentionally omitted. auto* slice0_out = builder.MakeIntermediate(); auto& slice0 = builder.AddNode("Slice", {concat_out}, {slice0_out}); slice0.AddAttribute("starts", std::vector{0}); slice0.AddAttribute("ends", std::vector{2}); - // No "axes" attribute - triggers the vulnerable path. auto* slice1_out = builder.MakeIntermediate(); auto& slice1 = builder.AddNode("Slice", {concat_out}, {slice1_out}); @@ -4929,15 +4920,485 @@ TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetV1_MissingAxesAttri return Status::OK(); }; - // Choose an opset < 10 to use the attribute-based Slice (v1-style) - // while also meeting Concat's opset requirements. - // Opset 4 satisfies this by providing Concat v4 and - // attribute-based Slice. auto transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 4, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({1, 3, 8, 8}, -1.0f, 1.0f); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* ends00 = builder.Make1DInitializer({8, 8}); + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* ends01 = builder.Make1DInitializer({8, 8}); + + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* ends10 = builder.Make1DInitializer({8, 8}); + + auto* starts11 = builder.Make1DInitializer({1, 1}); + auto* ends11 = builder.Make1DInitializer({8, 8}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends00, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends01, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends10, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends11, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0); + TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0); + TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1); + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "SpaceToDepth") { + const auto* blocksize_attr = graph_utils::GetNodeAttribute(node, "blocksize"); + TEST_RETURN_IF_NOT(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2); + } + } + + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNodesTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto make_int64_constant = [&](const std::vector& values) -> NodeArg* { + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + if (!values.empty()) { + tensor_proto.add_dims(gsl::narrow(values.size())); + } + utils::SetRawDataInTensorProto(tensor_proto, + reinterpret_cast(values.data()), + values.size() * sizeof(int64_t)); + + NodeArg* output = builder.MakeIntermediate(); + tensor_proto.set_name(output->Name()); + builder.AddNode("Constant", {}, {output}).AddAttribute("value", tensor_proto); + return output; + }; + + auto* input = builder.MakeInput({1, 3, 8, 8}, -1.0f, 1.0f); + + auto* axes_hw = make_int64_constant({2, 3}); + auto* steps2 = make_int64_constant({2, 2}); + auto* ends = make_int64_constant({8, 8}); + + auto* starts00 = make_int64_constant({0, 0}); + auto* starts01 = make_int64_constant({0, 1}); + auto* starts10 = make_int64_constant({1, 0}); + auto* starts11 = make_int64_constant({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF_NOT(op_to_count.at("Constant") == 7); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0); + TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0); + TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBlockOrderTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({1, 3, 8, 8}, -1.0f, 1.0f); + + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + auto* ends = builder.Make1DInitializer({std::numeric_limits::max(), std::numeric_limits::max()}); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* starts11 = builder.Make1DInitializer({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice10_out, slice01_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0); + TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0); + TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1); + TEST_RETURN_IF_NOT(get_op_count(op_to_count, "Gather") == 1); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForDynamicChannelPermutedBlockOrderTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput(std::optional>{{1, -1, 8, 8}}); + + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + auto* ends = builder.Make1DInitializer({std::numeric_limits::max(), std::numeric_limits::max()}); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* starts11 = builder.Make1DInitializer({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice10_out, slice01_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForSpatialCropTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({1, 3, 8, 8}, -1.0f, 1.0f); + + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + auto* ends = builder.Make1DInitializer({6, 8}); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* starts11 = builder.Make1DInitializer({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForChannelSliceTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({1, 3, 8, 8}, -1.0f, 1.0f); + + auto* axes_chw = builder.Make1DInitializer({1, 2, 3}); + auto* steps = builder.Make1DInitializer({1, 2, 2}); + auto* ends = builder.Make1DInitializer({2, 8, 8}); + + auto* starts00 = builder.Make1DInitializer({0, 0, 0}); + auto* starts01 = builder.Make1DInitializer({0, 0, 1}); + auto* starts10 = builder.Make1DInitializer({0, 1, 0}); + auto* starts11 = builder.Make1DInitializer({0, 1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_chw, steps}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_chw, steps}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_chw, steps}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_chw, steps}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForUnknownRankInputTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput(std::optional>{}); + + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + auto* ends = builder.Make1DInitializer({8, 8}); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* starts11 = builder.Make1DInitializer({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForRank5InputTest) { + auto get_op_count = [](const OpCountMap& op_to_count, std::string_view op_type) { + const auto it = op_to_count.find(std::string(op_type)); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({1, 3, 2, 8, 8}, -1.0f, 1.0f); + + auto* axes_hw = builder.Make1DInitializer({2, 3}); + auto* steps2 = builder.Make1DInitializer({2, 2}); + auto* ends = builder.Make1DInitializer({2, 8}); + + auto* starts00 = builder.Make1DInitializer({0, 0}); + auto* starts01 = builder.Make1DInitializer({0, 1}); + auto* starts10 = builder.Make1DInitializer({1, 0}); + auto* starts11 = builder.Make1DInitializer({1, 1}); + + auto* slice00_out = builder.MakeIntermediate(); + auto* slice01_out = builder.MakeIntermediate(); + auto* slice10_out = builder.MakeIntermediate(); + auto* slice11_out = builder.MakeIntermediate(); + auto* concat_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Slice", {input, starts00, ends, axes_hw, steps2}, {slice00_out}); + builder.AddNode("Slice", {input, starts01, ends, axes_hw, steps2}, {slice01_out}); + builder.AddNode("Slice", {input, starts10, ends, axes_hw, steps2}, {slice10_out}); + builder.AddNode("Slice", {input, starts11, ends, axes_hw, steps2}, {slice11_out}); + builder.AddNode("Concat", {slice00_out, slice01_out, slice10_out, slice11_out}, {concat_out}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Identity", {concat_out}, {output}); + }; + + auto pre_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto post_graph_checker = [get_op_count](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4); + TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1); + TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, ExpandElimination) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "expand_elimination.onnx"; std::shared_ptr model;