diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 84020d84c9e73..00ca25d0a6367 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -25,6 +25,7 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel; public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc; public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + public IntPtr ModelCompilationOptions_SetInputModel; } internal class NativeMethods @@ -136,6 +137,12 @@ public DOrtModelCompilationOptions_SetOutputModelWriteFunc public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModel( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* const OrtModel* */ inputModel); + public DOrtModelCompilationOptions_SetInputModel OrtModelCompilationOptions_SetInputModel; + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { @@ -217,6 +224,11 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)); + OrtModelCompilationOptions_SetInputModel = + (DOrtModelCompilationOptions_SetInputModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModel, + typeof(DOrtModelCompilationOptions_SetInputModel)); + } } } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 77c2ff795e800..66e4443de06ad 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -8037,6 +8037,29 @@ struct OrtCompileApi { ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, _In_ OrtModelCompilationOptions* model_compile_options, _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); + + /** \brief Sets the OrtModel to compile. + * + * Sets an OrtModel created via the Model Editor API as the input for compilation. + * + * The input model's source (file path, memory buffer, or OrtModel) must be set with + * one of: ModelCompilationOptions_SetInputModelPath, ModelCompilationOptions_SetInputModelFromBuffer, + * or ModelCompilationOptions_SetInputModel. + * + * The OrtModel must have a complete graph with inputs, outputs, and nodes defined. + * The caller retains ownership of the OrtModel and must not release it until after + * CompileModel returns. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] model The OrtModel to compile. The model is borrowed (not copied or owned). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* model); }; /** diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 2c1d52894e7f3..8dae24a3bffe7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1612,6 +1612,8 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel + + ModelCompilationOptions& SetInputModel(const OrtModel* model); ///< Wraps OrtCompileApi::ModelCompilationOptions_SetInputModel }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 745128fe6c7b4..bce2aa97d47cd 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1180,6 +1180,11 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLev return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const OrtModel* model) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, model)); + return *this; +} + namespace detail { template diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 12127e9708255..54d26021d8c99 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -306,6 +306,27 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationL API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ const OrtModel* model) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: OrtModel pointer is null"); + } + + model_compile_options->SetInputModel(model); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(model); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -343,6 +364,9 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, // End of Version 23 - DO NOT MODIFY ABOVE + + &OrtCompileAPI::ModelCompilationOptions_SetInputModel, + // End of Version 24 - DO NOT MODIFY ABOVE }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned @@ -350,6 +374,8 @@ static_assert(offsetof(OrtCompileApi, CompileModel) / sizeof(void*) == 8, "Size of version 22 Api cannot change"); // initial version in ORT 1.22 static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc) / sizeof(void*) == 13, "Size of version 23 of Api cannot change"); +static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetInputModel) / sizeof(void*) == 14, + "Size of version 24 of Api cannot change"); ORT_API(const OrtCompileApi*, OrtCompileAPI::GetCompileApi) { return &ort_compile_api; diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 34fa06340a7f9..e8f171ee24295 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -41,5 +41,8 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc, ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, _In_ OrtModelCompilationOptions* model_compile_options, _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* model); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 468dacc30c054..efaf28fbeefc0 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -45,7 +45,16 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da input_model_data_size_ = input_model_data_size; } +void ModelCompilationOptions::SetInputModel(const OrtModel* model) { + ResetInputModelSettings(); + input_model_ = model; +} + Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { + if (output_model_path.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Output model path must not be empty."); + } + ConfigOptions& config_options = session_options_.value.config_options; epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; @@ -186,10 +195,19 @@ size_t ModelCompilationOptions::GetInputModelDataSize() const { return input_model_data_size_; } +bool ModelCompilationOptions::InputModelComesFromOrtModel() const { + return input_model_ != nullptr; +} + +const OrtModel* ModelCompilationOptions::GetInputModel() const { + return input_model_; +} + void ModelCompilationOptions::ResetInputModelSettings() { input_model_path_.clear(); input_model_data_ = nullptr; input_model_data_size_ = 0; + input_model_ = nullptr; } Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { @@ -229,16 +247,21 @@ Status ModelCompilationOptions::Check() const { // Check input model settings. const bool input_from_file = !input_model_path_.empty(); const bool input_from_memory = input_model_data_ != nullptr; + const bool input_from_model = input_model_ != nullptr; + + int input_source_count = (input_from_file ? 1 : 0) + + (input_from_memory ? 1 : 0) + + (input_from_model ? 1 : 0); - if (!input_from_file && !input_from_memory) { + if (input_source_count == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer"); + "Input model to compile must be specified via file path, memory buffer, or OrtModel"); } - if (input_from_file && input_from_memory) { + if (input_source_count > 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer, ", - "but not both."); + "Input model to compile must be specified via exactly one of: ", + "file path, memory buffer, or OrtModel"); } if (input_from_file && !std::filesystem::exists(input_model_path_)) { @@ -249,12 +272,45 @@ Status ModelCompilationOptions::Check() const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } + // Validate OrtModel input + if (input_from_model) { + if (input_model_->graph == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel has no graph. Call AddGraphToModel before compilation."); + } + + if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel graph must have at least one input and one output defined."); + } + + if (input_model_->domain_to_version.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel must specify at least one opset domain/version."); + } + + // Note: Additional validation (node connections, schema) happens during + // Model::LoadFromModelEditorApiModel -> Graph::Resolve() + } + // Check output model settings. const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; bool has_no_output_model_location = std::holds_alternative( ep_context_gen_options.output_model_location); - if (has_no_output_model_location && input_from_file) { + // Determine if we can derive an output path from the input + bool can_derive_output_path = input_from_file; + + // For OrtModel input, check if model_path is set in the graph using the virtual GetModelPath() method + // (avoids dynamic_cast which requires RTTI) + if (input_from_model && input_model_->graph) { + const ORTCHAR_T* model_path_cstr = input_model_->graph->GetModelPath(); + if (model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0')) { + can_derive_output_path = true; + } + } + + if (has_no_output_model_location && can_derive_output_path) { // User did not specify an output file, an output buffer, or an output write function. We default to generating an // output file with a name based on the input file name, so do not return an error. return Status::OK(); @@ -294,7 +350,13 @@ Status ModelCompilationOptions::Check() const { } std::string ModelCompilationOptions::GetInputSourceForTelemetry() const { - return InputModelComesFromFile() ? "file" : "buffer"; + if (InputModelComesFromFile()) { + return "file"; + } + if (InputModelComesFromOrtModel()) { + return "ort_model"; + } + return "buffer"; } std::string ModelCompilationOptions::GetOutputTargetForTelemetry() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 4ba8712a6c9c7..47529e794677e 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -10,6 +10,7 @@ #include "core/common/status.h" #include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -45,6 +46,14 @@ class ModelCompilationOptions { /// The size in bytes of the input model's buffer void SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size); + /// + /// Sets the OrtModel to compile. + /// The OrtModel is borrowed (not copied) - caller must keep it alive until CompileModel returns. + /// Overrides any previous call to SetInputModelPath(), SetInputModelFromBuffer(), or SetInputModel(). + /// + /// The OrtModel to compile + void SetInputModel(const OrtModel* model); + /// /// Sets the file path to store the output/compiled ONNX model. /// Overrides any previous call to SetOutputModelPath() or SetOutputModelBuffer(). @@ -132,6 +141,18 @@ class ModelCompilationOptions { /// true if input model comes from a file bool InputModelComesFromFile() const; + /// + /// Returns true if the input model comes from an OrtModel pointer. + /// + /// true if input model comes from an OrtModel + bool InputModelComesFromOrtModel() const; + + /// + /// Returns the OrtModel to compile, or nullptr if not set. + /// + /// pointer to the OrtModel or nullptr + const OrtModel* GetInputModel() const; + /// /// Returns the buffer that contains the bytes for the input ONNX model. /// Returns nullptr if the input model is not stored in a buffer. @@ -162,9 +183,9 @@ class ModelCompilationOptions { // Telemetry helper methods /// - /// Returns a string describing the input source type: "file" or "buffer". + /// Returns a string describing the input source type: "file", "buffer", or "ort_model". /// - /// "file" or "buffer" + /// "file", "buffer", or "ort_model" std::string GetInputSourceForTelemetry() const; /// @@ -205,6 +226,7 @@ class ModelCompilationOptions { std::filesystem::path input_model_path_; const void* input_model_data_ = nullptr; size_t input_model_data_size_ = 0; + const OrtModel* input_model_ = nullptr; // Borrowed pointer }; } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 7ad09e1a2cd5e..4fa26efb53ceb 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -22,6 +22,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/model_editor_api_types.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -288,6 +289,90 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op return nullptr; } +#if !defined(ORT_MINIMAL_BUILD) +// Overload of CreateSessionAndLoadModelImpl that takes an OrtModel* directly. +// This ensures load-path parity with file/buffer inputs by running the same checks +// (ORT_LOAD_CONFIG_FROM_MODEL, EP-context output validation, custom domain wiring). +static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* options, + const onnxruntime::Environment& env, + _In_ const OrtModel* model, + std::unique_ptr& sess) { + if (model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtModel pointer is null"); + } + + // Check EPContext model generation options - OrtModel has no file path by default, + // so we need explicit output location or embedded model path. + if (options) { + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + + if (ep_ctx_gen_options.enable) { + auto* output_model_path = ep_ctx_gen_options.TryGetOutputModelPath(); + + // Check if OrtModel has a model_path set + bool has_model_path = false; + if (model->graph) { + const ORTCHAR_T* model_path_cstr = model->graph->GetModelPath(); + has_model_path = model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0'); + } + + // If there's no model path and no output location, fail early + if (!has_model_path && + (!ep_ctx_gen_options.HasOutputModelLocation() || + (output_model_path != nullptr && output_model_path->empty()))) { + return OrtApis::CreateStatus(ORT_FAIL, + "OrtModel has no model_path set and no valid output location was specified " + "for EPContext model generation. " + "SetOutputModelPath/SetOutputModelBuffer, or set the model_path on the " + "OrtGraph before adding it to OrtModel."); + } + } + } + + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env); + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + +#if !defined(ORT_MINIMAL_BUILD) + // Add custom domains for all OrtEpDevice instances to inference session. + // The custom domains should be registered before model load for ORT to validate the custom ops. + // This mirrors the same block in the file/buffer overload to maintain load-path parity. + if (options != nullptr && + options->provider_factories.empty() && + options->value.ep_selection_policy.enable) { + InlinedVector all_ep_custom_op_domains; + + for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) { + InlinedVector domains; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, options->custom_op_domains_)) { + all_ep_custom_op_domains.push_back(domain); + } + } + } + + if (!all_ep_custom_op_domains.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); + } + } +#endif // !defined(ORT_MINIMAL_BUILD) + + // Load from OrtModel + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); + + return nullptr; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // Creates an InferenceSession and loads the model. // Caller should provide either model_path, or modal_data + model_data_length. OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, @@ -491,6 +576,12 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session)); + } else if (model_compile_options.InputModelComesFromOrtModel()) { + // Use the OrtModel overload of CreateSessionAndLoadModelImpl to maintain load-path parity + // with file/buffer inputs (same checks for ORT_LOAD_CONFIG_FROM_MODEL, EP-context output, etc.) + const OrtModel* input_model = model_compile_options.GetInputModel(); + status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, + input_model, session)); } else { status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, nullptr, model_compile_options.GetInputModelData(), diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 018204bd1dfb0..ea5e889ad67a4 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -18,6 +18,7 @@ #include "test/shared_lib/test_fixture.h" #include "test/shared_lib/utils.h" +#include "test/util/include/api_asserts.h" #include "test/util/include/test_allocator.h" #include "onnxruntime_config.h" // generated file in build output dir @@ -725,3 +726,359 @@ TEST(ModelEditorAPITest, CreateTypeInfo) { api.ReleaseTypeInfo(base_tensor_type_info); } + +// +// Tests for Model Editor API + Compile API integration +// + +namespace { +// Helper to create a simple model for testing with Model Editor API +// Creates a model with a Gemm operation: Z = X * Y where X is input and Y is initializer +Ort::Model CreateSimpleGemmModel(std::vector>>& weights) { + Ort::Graph graph; + + std::vector graph_inputs; + std::vector graph_outputs; + + // Input: X is 3x4 + std::vector input_dims({3, 4}); + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back("X", input_type_info.GetConst()); + + // Output: Z is 3x8 + std::vector output_dims = {3, 8}; + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", output_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + // Gemm node with alpha=2.0 + std::vector attributes; + float alpha_value = 2.0; + attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); + + Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); + graph.AddNode(node); + + // Y initializer: 4x8 + std::vector y_dims = {4, 8}; + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + return model; +} + +// Helper to run inference on the simple Gemm model and verify all output values. +// Model is Z = 2.0 * X * Y where X is 3x4 (all ones) and Y is 4x8 (iota 1..32). +// Expected output: each row is 2 * column_sums_of_Y = {104, 112, 120, 128, 136, 144, 152, 160}. +void RunAndVerifySimpleGemmModel(const Ort::Model& model) { + Ort::SessionOptions session_options; + Ort::Session session(*ort_env, model, session_options); + ASSERT_EQ(session.GetInputCount(), 1u); + ASSERT_EQ(session.GetOutputCount(), 1u); + + std::vector input_data(3 * 4, 1.0f); + std::vector input_dims = {3, 4}; + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto input_tensor = Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(), + input_dims.data(), input_dims.size()); + + const char* input_names[] = {"X"}; + const char* output_names[] = {"Z"}; + auto outputs = session.Run(Ort::RunOptions{}, input_names, &input_tensor, 1, output_names, 1); + ASSERT_EQ(outputs.size(), 1u); + ASSERT_TRUE(outputs[0].IsTensor()); + + auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape(); + ASSERT_EQ(output_shape, (std::vector{3, 8})); + + const float* output_data = outputs[0].GetTensorData(); + // alpha=2.0, X is all ones, so each output row = 2 * sum of each column of Y (iota 1..32 in 4x8) + const std::vector expected_row = {104.0f, 112.0f, 120.0f, 128.0f, 136.0f, 144.0f, 152.0f, 160.0f}; + for (int row = 0; row < 3; ++row) { + for (int col = 0; col < 8; ++col) { + EXPECT_FLOAT_EQ(output_data[row * 8 + col], expected_row[col]) + << "Mismatch at row=" << row << " col=" << col; + } + } +} +} // namespace + +// Test basic compilation from OrtModel +TEST(ModelEditorCompileAPITest, BasicCompileFromOrtModel) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + // Set the OrtModel as input + compile_options.SetInputModel(static_cast(model)); + + // Set output to buffer - use embed mode for simplicity + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Compile should succeed (note: may not produce EPContext nodes without specific EP, but validation passes) + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + // Verify output was produced + EXPECT_NE(output_buffer, nullptr); + EXPECT_GT(output_size, 0u); + + // Cleanup + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } + + // Verify the model still produces correct inference results after compilation + RunAndVerifySimpleGemmModel(model); +} +TEST(ModelEditorCompileAPITest, CompileFromNullModel_Fails) { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + try { + compile_options.SetInputModel(nullptr); + FAIL() << "Expected exception for null model pointer"; + } catch (const Ort::Exception& e) { + EXPECT_THAT(e.what(), ::testing::HasSubstr("null")); + } +} + +// Test validation: model with no graph +TEST(ModelEditorCompileAPITest, CompileFromModelWithNoGraph_Fails) { + // Create a model but don't add a graph + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with no graph"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("graph")); +} + +// Test validation: model with empty inputs/outputs +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { + // Create a model with a graph that has no inputs or outputs + Ort::Graph graph; + // Don't set inputs or outputs + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with empty inputs/outputs"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("input")); +} + +// Test: model can be reused after compilation. +// NOTE: This is not an explicit API guarantee. It documents current behavior so that if a future change +// breaks model reuse, the regression is surfaced and can be evaluated. +TEST(ModelEditorCompileAPITest, ModelCanBeReusedAfterCompilation) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + // First compilation + { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } + } + + // Second compilation with same model + { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } + } + + // Model should still be usable for creating a session and running inference + RunAndVerifySimpleGemmModel(model); +} + +// Test: SetInputModel overrides previous input source (file path) +TEST(ModelEditorCompileAPITest, SetInputModelOverridesPreviousInputPath) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + // First set a file path (doesn't need to exist since we'll override it) + compile_options.SetInputModelPath(ORT_TSTR("nonexistent_file.onnx")); + + // Then override with OrtModel + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Should use the OrtModel, not the nonexistent file + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } +} + +// Test: SetInputModelPath overrides previous OrtModel setting +TEST(ModelEditorCompileAPITest, SetInputModelPathOverridesPreviousModel) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + // First set an OrtModel + compile_options.SetInputModel(static_cast(model)); + + // Then override with a real file path + compile_options.SetInputModelPath(ORT_TSTR("testdata/matmul_1.onnx")); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Should use the file path, not the OrtModel + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } +} + +// Test: Compile with output to file +TEST(ModelEditorCompileAPITest, CompileFromOrtModelToFile) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + auto output_path = ORT_TSTR("test_compile_from_ortmodel_output.onnx"); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetOutputModelPath(output_path); + compile_options.SetEpContextEmbedMode(true); + + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + // Verify output file exists + EXPECT_TRUE(std::filesystem::exists(output_path)); + + // Verify the output model can be loaded + Ort::Session session(*ort_env, output_path, Ort::SessionOptions()); + EXPECT_GE(session.GetInputCount(), 1u); + EXPECT_GE(session.GetOutputCount(), 1u); + + // Cleanup + std::filesystem::remove(output_path); +} + +// Test: Validation error for OrtModel with no model_path, no output location, and no embed mode. +TEST(ModelEditorCompileAPITest, NoOutputLocationNoModelPathFails) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + // Intentionally do NOT call SetEpContextEmbedMode, SetOutputModelPath, or SetOutputModelBuffer + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("output")); +} + +// Test: Setting embed mode with buffer output satisfies the output location requirement +// for OrtModel with no model_path. +TEST(ModelEditorCompileAPITest, EmbedModeWithBufferOutputSatisfiesValidation) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } +}