Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,12 @@ static std::unique_ptr<api::NodeRef> MakeDequantizeOp(api::GraphRef& graph, std:
return node;
}

// Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once)
// Returns whether perm is a non-empty valid permutation (rank > 0 and contains each value from 0 to perm.size() - 1 exactly once)
static bool IsValidPerm(const std::vector<int64_t>& perm) {
size_t rank = perm.size();
if (rank == 0) {
return false;
Comment thread
prathikr marked this conversation as resolved.
}
int64_t rank_int = gsl::narrow_cast<int64_t>(rank);
std::vector<bool> used_dims(rank);
for (size_t i = 0; i < rank; ++i) {
Expand Down
49 changes: 49 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4501,6 +4501,55 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151_NegativeDQAxis) {
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}

// Regression test for a division-by-zero in Permute1DConstant when a Transpose node carries an empty perm
// attribute (rank-0 / scalar tensor). perm.size() == 0 caused a divide-by-zero before the fix.
// Verifies that session initialization completes without crashing when the optimizer encounters this graph.
TEST(TransposeOptimizerTests, RegressionTest_Permute1DConstantEmptyPerm) {
// Graph: scalar_input → Transpose(perm=[]) → Pad(pads const, shape=[0]) → output
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = 13; // opset 13: Pad accepts pads as a named input
Model model("RegressionTest_Permute1DConstantEmptyPerm", false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
DefaultLoggingManager().DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder builder(graph);

// Rank-0 scalar float input
auto* scalar_input = MakeInput<float>(builder, std::vector<int64_t>{}, std::vector<int64_t>{}, 0.0f, 1.0f);

// Transpose with empty perm — valid ONNX identity on a scalar
auto* transpose_out = builder.MakeIntermediate(); auto& transpose_node = builder.AddNode("Transpose", {scalar_input}, {transpose_out});
Comment thread
prathikr marked this conversation as resolved.
Outdated
transpose_node.AddAttribute("perm", std::vector<int64_t>{});

// Pad: empty pads for a rank-0 input
auto* pads_init = builder.MakeInitializer<int64_t>({0}, std::vector<int64_t>{});
auto* pad_out = builder.MakeOutput();
builder.AddNode("Pad", {transpose_out, pads_init}, {pad_out});

builder.SetGraphOutputs();
ASSERT_STATUS_OK(graph.Resolve());

// Serialize to an in-memory buffer so we can load it into a session
std::string model_data;
model.ToProto().SerializeToString(&model_data);

// Run with Level1 optimizations (transpose optimizer is active at Level1)
SessionOptions so;
so.graph_optimization_level = TransformerLevel::Level1;
so.session_logid = "TransposeOptimizerTests.RegressionTest_Permute1DConstantEmptyPerm";

InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast<int>(model_data.size())));

// The critical property is that Initialize() completes without crashing.
// It may succeed or return a graceful error — either is acceptable.
Status init_status = session.Initialize();
// Log the result so CI output is informative, but do not assert IsOK().
if (!init_status.IsOK()) {
GTEST_LOG_(INFO) << "Session initialization returned (non-crash) error: " << init_status.ErrorMessage();
}
Comment thread
prathikr marked this conversation as resolved.
Outdated
}

// These tests use the internal testing EP with static kernels which requires a full build and contrib ops,
// and the NHWC Conv which requires contrib ops
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS)
Expand Down
Loading