Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,57 @@ static int64_t EstimateIdentityOutputSizeInBytes(const Node& node) {
return EstimateTensorSizeInBytes(*input_defs[0]);
}

// Estimates the output size in bytes of a ConstantOfShape node directly from its shape input.
//
// ConstantOfShape's output shape is determined by the values of its first input (a shape tensor).
// When that input is a constant initializer, the output byte size is computed from the
// initializer so we do not have to rely on ONNX shape inference having propagated the shape onto
// the output NodeArg (which is not guaranteed for all opsets, all shapes-as-initializers, or all
// build configurations).
//
// Returns the estimated size in bytes (>= 0), or -1 when the size cannot be derived here:
//   - the node has no usable first input,
//   - the first input is not a constant initializer,
//   - the initializer is not an int64 tensor,
//   - the shape contains a negative dimension (left for the kernel to reject).
//
// The multiplications use SafeInt, so an overflowing element count or byte size raises an
// exception instead of wrapping around; the caller is expected to handle that.
static int64_t EstimateConstantOfShapeOutputSizeInBytes(const Node& node, const Graph& graph) {
  const auto& input_defs = node.InputDefs();
  if (input_defs.empty() || input_defs[0] == nullptr || !input_defs[0]->Exists()) {
    return -1;
  }

  constexpr bool check_outer_scope = true;
  const ONNX_NAMESPACE::TensorProto* shape_init =
      graph.GetConstantInitializer(input_defs[0]->Name(), check_outer_scope);
  if (shape_init == nullptr) {
    return -1;
  }

  // Check the element type on the proto first so a tensor of the wrong type is rejected
  // without unpacking its data into an Initializer.
  if (shape_init->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
    return -1;
  }

  Initializer shape_data{graph, *shape_init, graph.ModelPath()};
  const auto dims = shape_data.DataAsSpan<int64_t>();

  // Validate every dimension before multiplying. A zero anywhere in the shape means the output
  // is empty, and that must win over an overflow caused by large dimensions that precede it.
  bool has_zero_dim = false;
  for (const int64_t dim : dims) {
    if (dim < 0) {
      return -1;  // Invalid shape value; let the kernel reject it.
    }
    if (dim == 0) {
      has_zero_dim = true;
    }
  }
  if (has_zero_dim) {
    return 0;
  }

  SafeInt<int64_t> num_elements = 1;
  for (const int64_t dim : dims) {
    num_elements *= dim;
  }

  // Determine the element size of the output. The ONNX spec for ConstantOfShape defaults the
  // element type to float when the optional 'value' attribute is absent. The float size is also
  // kept when the attribute's element type has no known size (helper returns 0).
  size_t element_size = sizeof(float);
  const auto& attrs = node.GetAttributes();
  const auto value_it = attrs.find("value");
  if (value_it != attrs.end() && value_it->second.type() == ONNX_NAMESPACE::AttributeProto::TENSOR) {
    const auto elem_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
        value_it->second.t().data_type());
    const size_t known_size = GetElementSizeForConstantFolding(elem_type);
    if (known_size != 0) {
      element_size = known_size;
    }
  }

  return num_elements * static_cast<int64_t>(element_size);
}

// Estimate the total output size in bytes for a node using shape inference results.
// Returns -1 if the output size cannot be estimated (e.g., unknown shapes or types).
static int64_t EstimateNodeOutputSizeInBytes(const Node& node) {
static int64_t EstimateNodeOutputSizeInBytes(const Node& node, const Graph& graph) {
if (node.OpType() == "Identity" && node.Domain().empty()) {
return EstimateIdentityOutputSizeInBytes(node);
}
Expand All @@ -260,6 +308,15 @@ static int64_t EstimateNodeOutputSizeInBytes(const Node& node) {
return EstimateUniqueOutputSizeInBytes(node);
}

if (node.OpType() == "ConstantOfShape" && node.Domain().empty()) {
const int64_t size = EstimateConstantOfShapeOutputSizeInBytes(node, graph);
if (size >= 0) {
return size;
}
// Fall through to the generic estimator if we could not derive a size from the input
// initializer (e.g., the shape input is not a recognizable constant initializer).
}

SafeInt<int64_t> total_size = 0;
for (const auto* output_def : node.OutputDefs()) {
if (!output_def->Exists()) {
Expand Down Expand Up @@ -391,7 +448,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (max_output_size > 0) {
int64_t estimated_size = -1;
try {
estimated_size = EstimateNodeOutputSizeInBytes(*node);
estimated_size = EstimateNodeOutputSizeInBytes(*node, graph);
} catch (const std::exception&) {
// SafeInt overflow means the size is astronomically large - definitely skip
LOGS(logger, WARNING) << "Integer overflow while estimating output size of "
Expand Down
47 changes: 47 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,53 @@ TEST_F(GraphTransformationTests, ConstantFoldingConfiguredLimitBlocksLargeConsta
pre_graph_checker, post_graph_checker));
}

// Verify that ConstantOfShape output size is estimated directly from the input shape
// initializer (and not just from shape inference) so that excessive constant-folded
Comment thread
tianleiwu marked this conversation as resolved.
// allocations are blocked before kernel execution. The 'value' attribute uses int64
// elements (8 bytes) so that 100M * 8 = 800 MB exceeds the configured 256 MB cap.
TEST_F(GraphTransformationTests, ConstantFoldingConstantOfShapeUsesInputInitializerForSizeCheck) {
constexpr int64_t kLargeDim = 100 * 1024 * 1024; // 100M elements

auto build_model = [&](ModelTestBuilder& builder) {
auto* shape_data = builder.MakeInitializer<int64_t>({1}, {kLargeDim});
auto* output_arg = builder.MakeOutput();

auto& node = builder.AddNode("ConstantOfShape", {shape_data}, {output_arg});
ONNX_NAMESPACE::AttributeProto value_attr;
value_attr.set_name("value");
value_attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
auto* tensor = value_attr.mutable_t();
tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
tensor->add_dims(1);
tensor->add_int64_data(0);
node.AddAttributeProto(std::move(value_attr));
};

auto pre_graph_checker = [](Graph& graph) -> Status {
auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_to_count["ConstantOfShape"] == 1);
return Status::OK();
};

auto post_graph_checker = [](Graph& graph) -> Status {
auto op_to_count = CountOpsInGraph(graph);
// 800 MB output exceeds the 256 MB cap, so the node must not be folded
// (and crucially, the 800 MB allocation must not have happened during folding).
TEST_RETURN_IF_NOT(op_to_count["ConstantOfShape"] == 1);
return Status::OK();
};

std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
ConfigOptions config_options;
ASSERT_STATUS_OK(config_options.AddConfigEntry(
kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes, "268435456")); // 256 MB

ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_,
std::make_unique<ConstantFolding>(*e.get(), false, config_options),
TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}

// Test that small constant folding still works with the size limit.
TEST_F(GraphTransformationTests, ConstantFoldingSmallOutputAllowed) {
// Build a model with a small Expand: scalar -> [4, 4] = 16 * 4 = 64 bytes.
Expand Down
Loading