Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion onnxruntime/core/optimizer/concat_slice_elimination.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <numeric>

#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/concat_slice_elimination.h"
Expand Down Expand Up @@ -141,6 +143,19 @@ static bool GetSliceInfo(const Graph& graph,
} else {
return false;
}
// Materialize defaults for the optional axes/steps so callers can safely index them.
// This matches the ONNX Slice defaults in the common case where starts/ends are
// provided for the leading axes.
// Opset 1: the `axes` attribute is optional; if it is absent, `axes` is empty.
// Opset >= 10: if the `axes` input doesn't exist, `axes` stays empty.
if (axes.empty()) {
axes.resize(starts.size());
std::iota(axes.begin(), axes.end(), 0LL);
}

if (steps.empty()) {
steps.assign(starts.size(), 1LL);
}
return true;
}

Expand Down Expand Up @@ -219,7 +234,13 @@ bool ConcatSliceElimination::FuseConcatSliceSubgraph(Node& concat, Graph& graph,
for (auto slice : concat_outputs) {
InlinedVector<int64_t> starts, ends, axes, steps;
if (!GetSliceInfo(graph, *slice, logger, starts, ends, axes, steps)) return false;
if (starts.size() > 1) return false;
// The code above already enforces starts.size() == ends.size() (for both opset 1 and opset >= 10).
assert(starts.size() == ends.size());
// This check must come before any axes/steps indexing
// Other starts sizes are valid for the Slice operator,
// but they are intentionally out of scope for this specific fusion.
// FuseConcatSliceSubgraph() is a very narrow, pattern-based optimization, not a general Slice normalizer.
if (starts.size() != 1) return false;
if (axes[0] != 0) return false;
if (steps[0] != 1) return false;
auto iter = std::find(cumulative_input_len.begin(), cumulative_input_len.end(), starts[0]);
Expand Down
34 changes: 22 additions & 12 deletions onnxruntime/core/optimizer/unsqueeze_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/graph/graph_utils.h"
#include "core/graph/graph.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
Expand All @@ -30,32 +31,41 @@ Status UnsqueezeElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect&
return Status::OK();
}

auto num_axes = axes.size();
auto output_rank = num_axes + tensor_proto.dims().size();
const int64_t output_rank = narrow<int64_t>(axes.size() + tensor_proto.dims().size());

// handle any negative axis values
// handle any negative axis values and validate range
for (auto& axis : axes) {
if (!IsAxisInRange(axis, output_rank)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"'axes' has an out of range axis value ", axis,
" for output rank ", output_rank,
". This is an invalid model. Node: ", node.Name());
}
if (axis < 0) {
axis += output_rank;
}
}

// Generate new dims.
InlinedVector<int64_t> new_dims(output_rank, 0);
// Generate new dims. Mark axes positions with 1, fill the rest from input dims.
InlinedVector<int64_t> new_dims(narrow<size_t>(output_rank), 0);
for (int64_t axis : axes) {
if (static_cast<size_t>(axis) >= new_dims.size()) {
LOGS(logger, WARNING) << "UnsqueezeElimination cannot remove node due to invalid axes" << node.Name();
return Status::OK();
const size_t idx = narrow<size_t>(axis);
if (new_dims[idx] != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"'axes' has a duplicate axis value ", axis,
". This is an invalid model. Node: ", node.Name());
}
new_dims[static_cast<size_t>(axis)] = 1;
new_dims[idx] = 1;
}

auto begin = tensor_proto.dims().cbegin();
for (auto& axis : new_dims) {
if (axis == 0) {
axis = *begin++;
for (auto& dim : new_dims) {
if (dim == 0) {
assert(begin != tensor_proto.dims().cend());
dim = *begin++;
}
}
assert(begin == tensor_proto.dims().cend());

Initializer initializer(graph, tensor_proto, graph.ModelPath(), /*check_outer_scope=*/false);
ONNX_NAMESPACE::TensorProto new_tensor_proto;
Expand Down
Loading
Loading