diff --git a/onnxruntime/core/optimizer/concat_slice_elimination.cc b/onnxruntime/core/optimizer/concat_slice_elimination.cc index b49bcc186e93d..245618744f03d 100644 --- a/onnxruntime/core/optimizer/concat_slice_elimination.cc +++ b/onnxruntime/core/optimizer/concat_slice_elimination.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/concat_slice_elimination.h" @@ -141,6 +143,19 @@ static bool GetSliceInfo(const Graph& graph, } else { return false; } + // Materialize defaults for optional axes/steps so callers can safely index them. + // This aligns with ONNX Slice defaults in the common case where starts/ends are + // provided for leading axes. + // Opset v1 : `axes` attribute is optional if absent it is empty + // Opset >= 10: if axes input doesn't exist `axes` stays empty + if (axes.empty()) { + axes.resize(starts.size()); + std::iota(axes.begin(), axes.end(), 0LL); + } + + if (steps.empty()) { + steps.assign(starts.size(), 1LL); + } return true; } @@ -219,7 +234,13 @@ bool ConcatSliceElimination::FuseConcatSliceSubgraph(Node& concat, Graph& graph, for (auto slice : concat_outputs) { InlinedVector starts, ends, axes, steps; if (!GetSliceInfo(graph, *slice, logger, starts, ends, axes, steps)) return false; - if (starts.size() > 1) return false; + // The code already enforces starts.size() == ends.size() (opset == 1 and opset >=10) + assert(starts.size() == ends.size()); + // This check must come before any axes/steps indexing + // Other starts sizes are valid for the Slice operator, + // but they are intentionally out of scope for this specific fusion. + // FuseConcatSliceSubgraph() is a very narrow, pattern-based optimization, not a general Slice normalizer. + if (starts.size() != 1) return false; if (axes[0] != 0) return false; if (steps[0] != 1) return false; auto iter = std::find(cumulative_input_len.begin(), cumulative_input_len.end(), starts[0]); diff --git a/onnxruntime/core/optimizer/unsqueeze_elimination.cc b/onnxruntime/core/optimizer/unsqueeze_elimination.cc index 1cfca99ebc031..b1a122dc1fe1c 100644 --- a/onnxruntime/core/optimizer/unsqueeze_elimination.cc +++ b/onnxruntime/core/optimizer/unsqueeze_elimination.cc @@ -8,6 +8,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph.h" #include "core/optimizer/initializer.h" +#include "core/providers/common.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -30,32 +31,41 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& return Status::OK(); } - auto num_axes = axes.size(); - auto output_rank = num_axes + tensor_proto.dims().size(); + const int64_t output_rank = narrow(axes.size() + tensor_proto.dims().size()); - // handle any negative axis values + // handle any negative axis values and validate range for (auto& axis : axes) { + if (!IsAxisInRange(axis, output_rank)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "'axes' has an out of range axis value ", axis, + " for output rank ", output_rank, + ". This is an invalid model. Node: ", node.Name()); + } if (axis < 0) { axis += output_rank; } } - // Generate new dims. - InlinedVector new_dims(output_rank, 0); + // Generate new dims. Mark axes positions with 1, fill the rest from input dims. + InlinedVector new_dims(narrow(output_rank), 0); for (int64_t axis : axes) { - if (static_cast(axis) >= new_dims.size()) { - LOGS(logger, WARNING) << "UnsqueezeElimination cannot remove node due to invalid axes" << node.Name(); - return Status::OK(); + const size_t idx = narrow(axis); + if (new_dims[idx] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "'axes' has a duplicate axis value ", axis, + ". This is an invalid model. Node: ", node.Name()); } - new_dims[static_cast(axis)] = 1; + new_dims[idx] = 1; } auto begin = tensor_proto.dims().cbegin(); - for (auto& axis : new_dims) { - if (axis == 0) { - axis = *begin++; + for (auto& dim : new_dims) { + if (dim == 0) { + assert(begin != tensor_proto.dims().cend()); + dim = *begin++; } } + assert(begin == tensor_proto.dims().cend()); Initializer initializer(graph, tensor_proto, graph.ModelPath(), /*check_outer_scope=*/false); ONNX_NAMESPACE::TensorProto new_tensor_proto; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index d7780da36626c..d77cad48c618a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2534,6 +2534,407 @@ TEST_F(GraphTransformationTests, NegativeFuseConvAddNoBias) { ASSERT_TRUE(op_to_count["Unsqueeze"] != 0); } +// Basic test: Unsqueeze with a single axis on a constant initializer is eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_BasicConstantInput) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{0}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Input shape [2, 3] with axes {0} => output shape [1, 2, 3]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 3); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze with multiple axes on a constant initializer is eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_MultipleAxes) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({4}, {1.0f, 2.0f, 3.0f, 4.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{0, 2}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Input shape [4] with axes {0, 2} => output shape [1, 4, 1]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 4); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze with negative axis on a constant initializer is eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_NegativeAxis) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + // Input rank 2, axes {-1}: output rank = 3, -1 maps to axis 2. + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{-1}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Input shape [2, 3] with axes {-1} => output shape [2, 3, 1]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 3); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze on a scalar constant initializer is eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_ScalarInput) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({}, {42.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{0, 1}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Scalar with axes {0, 1} => output shape [1, 1]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 2); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze whose input is a graph input (not constant) is NOT eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_NonConstantInputNotEliminated) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, 3}}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {input_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{0}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, checker, checker)); +} + +// Boundary test: axes at the valid extremes are correctly handled. +TEST_F(GraphTransformationTests, UnsqueezeElimination_AxisBoundaryValues) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2}, {1.0f, 2.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + // Input rank 1, axes {-3, 2}: output rank = 3. + // -3 is the minimum valid negative axis (maps to 0), 2 is the maximum valid positive axis. + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{-3, 2}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Input [2] with axes {-3, 2} => axes {0, 2} => output shape [1, 2, 1]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze whose output is directly a graph output is NOT eliminated +// because the generated initializer name won't match the graph output name. +TEST_F(GraphTransformationTests, UnsqueezeElimination_OutputIsGraphOutput) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({3}, {1.0f, 2.0f, 3.0f}); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {output_arg}); + unsqueeze_node.AddAttribute("axes", std::vector{0}); + }; + + auto checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, checker, checker)); +} + +// Unsqueeze with non-float data type (int32) constant initializer is eliminated. +TEST_F(GraphTransformationTests, UnsqueezeElimination_Int32Initializer) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2, 2}, {10, 20, 30, 40}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{1}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + // Input shape [2, 2] with axes {1} => output shape [2, 1, 2]. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 2); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Only the Unsqueeze with a constant initializer input is eliminated in a mixed graph. +TEST_F(GraphTransformationTests, UnsqueezeElimination_MixedConstantAndNonConstant) { + auto build_test_case = [&](ModelTestBuilder& builder) { + // This one should be eliminated. + auto* initializer_arg = builder.MakeInitializer({3}, {1.0f, 2.0f, 3.0f}); + auto* unsqueeze_out_1 = builder.MakeIntermediate(); + auto& unsqueeze1 = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out_1}); + unsqueeze1.AddAttribute("axes", std::vector{0}); + + // This one should NOT be eliminated. + auto* graph_input = builder.MakeInput({{3}}); + auto* unsqueeze_out_2 = builder.MakeIntermediate(); + auto& unsqueeze2 = builder.AddNode("Unsqueeze", {graph_input}, {unsqueeze_out_2}); + unsqueeze2.AddAttribute("axes", std::vector{0}); + + auto* output_arg = builder.MakeOutput(); + builder.AddNode("Add", {unsqueeze_out_1, unsqueeze_out_2}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 2); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze with all negative axes. +TEST_F(GraphTransformationTests, UnsqueezeElimination_AllNegativeAxes) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2}, {1.0f, 2.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + // Input rank 1, axes {-1, -3}: output rank = 3. + // -1 maps to axis 2, -3 maps to axis 0. => output shape [1, 2, 1]. + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{-1, -3}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 3); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Unsqueeze inserting dimensions at multiple positions on a rank-1 input. +TEST_F(GraphTransformationTests, UnsqueezeElimination_ManyAxes) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* initializer_arg = builder.MakeInitializer({2}, {1.0f, 2.0f}); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + + // Input rank 1, axes {0, 2, 3}: output rank = 4 => shape [1, 2, 1, 1]. + auto& unsqueeze_node = builder.AddNode("Unsqueeze", {initializer_arg}, {unsqueeze_out}); + unsqueeze_node.AddAttribute("axes", std::vector{0, 2, 3}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Unsqueeze"] == 0); + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Identity") { + auto* shape = node.InputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape != nullptr); + TEST_RETURN_IF_NOT(shape->dim_size() == 4); + TEST_RETURN_IF_NOT(shape->dim(0).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(1).dim_value() == 2); + TEST_RETURN_IF_NOT(shape->dim(2).dim_value() == 1); + TEST_RETURN_IF_NOT(shape->dim(3).dim_value() == 1); + } + } + return Status::OK(); + }; + + auto rule_transformer = std::make_unique("RuleTransformer"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(rule_transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// NOTE: Duplicate-axis and out-of-range axis error paths in UnsqueezeElimination::Apply +// are defense-in-depth guards. They cannot be exercised through ModelTestBuilder because +// ONNX schema validation during graph.Resolve() rejects such invalid models before the +// optimizer runs. + static void TestFuseConvAddMul(logging::Logger& logger, const PathChar* model_uri) { std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger)); @@ -4402,6 +4803,141 @@ TEST_F(GraphTransformationTests, ConcatSliceEliminationTest) { ASSERT_TRUE(op_to_count["Slice"] == 0); } +// Verifies that ConcatSliceElimination correctly defaults axes to {0} and steps to {1} +// when Slice nodes omit the optional axes/steps inputs (opset >= 10). +// Before the fix, GetSliceInfo returned empty axes/steps vectors leading to undefined behavior. +// After the fix, the defaults allow the fusion to succeed deterministically. +TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetGte10_MissingAxesAndSteps) { + auto build_test_case = [&](ModelTestBuilder& builder) { + // Three 1-D constant initializer inputs to Concat, each of size 2. + auto* init0 = builder.MakeInitializer({2}, {1.0f, 2.0f}); + auto* init1 = builder.MakeInitializer({2}, {3.0f, 4.0f}); + auto* init2 = builder.MakeInitializer({2}, {5.0f, 6.0f}); + + auto* concat_out = builder.MakeIntermediate(); + auto& concat_node = builder.AddNode("Concat", {init0, init1, init2}, {concat_out}); + concat_node.AddAttribute("axis", int64_t{0}); + + // Three Slice nodes with only starts and ends - axes and steps intentionally omitted. + // Per ONNX spec, missing axes defaults to [0,...,ndim-1] and missing steps defaults to [1,...,1]. + auto* starts0 = builder.MakeInitializer({1}, {int64_t{0}}); + auto* ends0 = builder.MakeInitializer({1}, {int64_t{2}}); + auto* slice0_out = builder.MakeIntermediate(); + builder.AddNode("Slice", {concat_out, starts0, ends0}, {slice0_out}); + + auto* starts1 = builder.MakeInitializer({1}, {int64_t{2}}); + auto* ends1 = builder.MakeInitializer({1}, {int64_t{4}}); + auto* slice1_out = builder.MakeIntermediate(); + builder.AddNode("Slice", {concat_out, starts1, ends1}, {slice1_out}); + + auto* starts2 = builder.MakeInitializer({1}, {int64_t{4}}); + auto* ends2 = builder.MakeInitializer({1}, {int64_t{6}}); + auto* slice2_out = builder.MakeIntermediate(); + builder.AddNode("Slice", {concat_out, starts2, ends2}, {slice2_out}); + + auto* lhs0 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add0_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs0, slice0_out}, {add0_out}); + + auto* lhs1 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add1_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs1, slice1_out}, {add1_out}); + + auto* lhs2 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add2_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs2, slice2_out}, {add2_out}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Concat"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + // With the fix, axes defaults to {0} and steps defaults to {1}, so the + // fusion pattern matches and Concat + all Slices are eliminated. + TEST_RETURN_IF_NOT(op_to_count["Concat"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 3); + return Status::OK(); + }; + + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + +// Same test for opset v1, where axes is an optional attribute on the Slice node. +// When the "axes" attribute is absent, GetSliceInfo must default axes to {0}. +TEST_F(GraphTransformationTests, ConcatSliceElimination_OpsetV1_MissingAxesAttribute) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* init0 = builder.MakeInitializer({2}, {1.0f, 2.0f}); + auto* init1 = builder.MakeInitializer({2}, {3.0f, 4.0f}); + auto* init2 = builder.MakeInitializer({2}, {5.0f, 6.0f}); + + auto* concat_out = builder.MakeIntermediate(); + auto& concat_node = builder.AddNode("Concat", {init0, init1, init2}, {concat_out}); + concat_node.AddAttribute("axis", int64_t{0}); + + // Opset v1 Slice: starts/ends are attributes, axes attribute intentionally omitted. + auto* slice0_out = builder.MakeIntermediate(); + auto& slice0 = builder.AddNode("Slice", {concat_out}, {slice0_out}); + slice0.AddAttribute("starts", std::vector{0}); + slice0.AddAttribute("ends", std::vector{2}); + // No "axes" attribute - triggers the vulnerable path. + + auto* slice1_out = builder.MakeIntermediate(); + auto& slice1 = builder.AddNode("Slice", {concat_out}, {slice1_out}); + slice1.AddAttribute("starts", std::vector{2}); + slice1.AddAttribute("ends", std::vector{4}); + + auto* slice2_out = builder.MakeIntermediate(); + auto& slice2 = builder.AddNode("Slice", {concat_out}, {slice2_out}); + slice2.AddAttribute("starts", std::vector{4}); + slice2.AddAttribute("ends", std::vector{6}); + + auto* lhs0 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add0_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs0, slice0_out}, {add0_out}); + + auto* lhs1 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add1_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs1, slice1_out}, {add1_out}); + + auto* lhs2 = builder.MakeInput({2}, {0.0f, 0.0f}); + auto* add2_out = builder.MakeOutput(); + builder.AddNode("Add", {lhs2, slice2_out}, {add2_out}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Concat"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Concat"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Slice"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 3); + return Status::OK(); + }; + + // Choose an opset < 10 to use the attribute-based Slice (v1-style) + // while also meeting Concat's opset requirements. + // Opset 4 satisfies this by providing Concat v4 and + // attribute-based Slice. + auto transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 4, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, ExpandElimination) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "expand_elimination.onnx"; std::shared_ptr model;