diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 3d3f49733e209..2b3dbda98e569 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -177,6 +177,7 @@ public struct OrtApi
         public IntPtr TensorAt;
         public IntPtr CreateAndRegisterAllocator;
         public IntPtr SetLanguageProjection;
+        public IntPtr AddInitializer;
     }
 
     internal static class NativeMethods
@@ -238,7 +239,8 @@ static NativeMethods()
             OrtSetSessionGraphOptimizationLevel = (DOrtSetSessionGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionGraphOptimizationLevel, typeof(DOrtSetSessionGraphOptimizationLevel));
             OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
             OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));
-
+            OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer));
+
             OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
             OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
             OrtRunOptionsSetRunLogVerbosityLevel = (DOrtRunOptionsSetRunLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetRunLogVerbosityLevel, typeof(DOrtRunOptionsSetRunLogVerbosityLevel));
@@ -549,6 +551,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
         public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ library_path, out IntPtr /* (void**) */ library_handle);
         public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary;
 
+        public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*) */ options, string /*(const char*)*/ name, IntPtr /* OrtValue* */ ort_value);
+        public static DOrtAddInitializer OrtAddInitializer;
+
 #endregion
 
 #region RunOptions API
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 138fce5dfec23..50241a17824c2 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -44,7 +44,7 @@ public class SessionOptions : SafeHandle
         /// Constructs an empty SessionOptions
         /// </summary>
         public SessionOptions()
-            :base(IntPtr.Zero, true)
+            : base(IntPtr.Zero, true)
         {
             NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionOptions(out handle));
         }
@@ -175,6 +175,21 @@ public void RegisterCustomOpLibrary(string libraryPath)
             NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, libraryPath, out libraryHandle));
         }
 
+        /// <summary>
+        /// Add a pre-allocated initializer to a session. If a model contains an initializer with a name
+        /// that is the same as the name passed to this API call, ORT will use this initializer instance
+        /// instead of deserializing one from the model file. This is useful when you want to share
+        /// the same initializer across sessions.
+        /// </summary>
+        /// <param name="name">name of the initializer</param>
+        /// <param name="ortValue">OrtValue containing the initializer. The lifetime of 'ortValue' and the underlying
+        /// initializer buffer must be managed by the user (created using the CreateTensorWithDataAsOrtValue API),
+        /// and it must outlive the session object to which it is added.</param>
+        public void AddInitializer(string name, OrtValue ortValue)
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, name, ortValue.Handle));
+        }
+
         public void AddSessionConfigEntry(string configKey, string configValue)
         {
             NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle, configKey, configValue));
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
index b9222c2dd6543..302457a4bc2bf 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
@@ -7,6 +7,7 @@
     <OutputPath>bin\$(Configuration)\</OutputPath>
     <RootNamespace>Microsoft.ML.OnnxRuntime</RootNamespace>
     <IsPackable>false</IsPackable>
+    <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 04637b1581f85..40503ee8f44c7 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -268,6 +268,9 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
                 }
             }
 
+
+            float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");
+            int[] expectedDimensions = { 1, 1000, 1, 1 };  // hardcoded for now for the test data
             // Run inference with named inputs and named outputs
             {
                 // correct pre-allocated outputs
@@ -276,7 +279,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
                     NamedOnnxValue.CreateFromTensor("softmaxout_1", new DenseTensor<float>(expectedOutputDimensions))
                 };
                 session.Run(container, expectedOutputValues);
-                validateRunResultData(expectedOutputValues[0].AsTensor<float>());
+                validateRunResultData(expectedOutputValues[0].AsTensor<float>(), expectedOutput, expectedDimensions);
             }
 
             // Run inference with pinned inputs and named outputs
@@ -291,7 +294,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
                     NamedOnnxValue.CreateFromTensor("softmaxout_1", new DenseTensor<float>(expectedOutputDimensions))
                 };
                 session.Run(inputNames, pinnedInputs, expectedOutputValues);
-                validateRunResultData(expectedOutputValues[0].AsTensor<float>());
+                validateRunResultData(expectedOutputValues[0].AsTensor<float>(), expectedOutput, expectedDimensions);
             }
 
             // Run inference with named inputs and pinned outputs
@@ -302,7 +305,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
                     var outputTensor = new DenseTensor<float>(expectedOutputDimensions);
                     pinnedOutputs.Add(FixedBufferOnnxValue.CreateFromTensor(outputTensor));
                     session.Run(container, expectedOutputNames, pinnedOutputs);
-                    validateRunResultData(outputTensor);
+                    validateRunResultData(outputTensor, expectedOutput, expectedDimensions);
                 }
             }
 
@@ -317,7 +320,7 @@ private void CanRunInferenceOnAModel(GraphOptimizationLevel graphOptimizationLev
                     pinnedOutputs.Add(FixedBufferOnnxValue.CreateFromTensor(outputTensor));
                     session.Run(inputNames, pinnedInputs, expectedOutputNames, pinnedOutputs);
-                    validateRunResultData(outputTensor);
+                    validateRunResultData(outputTensor, expectedOutput, expectedDimensions);
                 }
             }
         }
     }
 
@@ -371,15 +374,14 @@ private void validateRunResults(IReadOnlyCollection<NamedOnnxValue> results)
             Assert.Equal(1, results.Count);
 
             Assert.Equal("softmaxout_1", r.Name);
-            validateRunResultData(r.AsTensor<float>());
+            float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");
+            int[] expectedDimensions = { 1, 1000, 1, 1 };  // hardcoded for now for the test data
+            validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
             }
         }
 
-        private void validateRunResultData(Tensor<float> resultTensor)
+        private void validateRunResultData(Tensor<float> resultTensor, float[] expectedOutput, int[] expectedDimensions)
         {
-            float[] expectedOutput = LoadTensorFromFile(@"bench.expected_out");
-
-            int[] expectedDimensions = { 1, 1000, 1, 1 };  // hardcoded for now for the test data
             Assert.Equal(expectedDimensions.Length, resultTensor.Rank);
 
             var resultDimensions = resultTensor.Dimensions;
@@ -1837,6 +1839,70 @@ private void TestIOBinding()
             }
         }
 
+        [Fact]
+        private void TestWeightSharingBetweenSessions()
+        {
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "mul_1.onnx");
+
+            // create initializer to share
+            var ortCpuMemInfo = OrtMemoryInfo.DefaultInstance;
+            var dims = new long[] { 3, 2 };
+            var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
+            var allocator = OrtAllocator.DefaultInstance;
+            var ortAllocationInput = allocator.Allocate((uint)dataBuffer.Length * sizeof(float));
+            unsafe
+            {
+                float* p = (float*)ortAllocationInput.DangerousGetHandle();
+                for (int i = 0; i < dataBuffer.Length; ++i)
+                {
+                    *p++ = dataBuffer[i];
+                }
+            }
+            var dataBufferNumBytes = (uint)dataBuffer.Length * sizeof(float);
+            var sharedInitializer = OrtValue.CreateTensorValueWithData(ortCpuMemInfo, Tensors.TensorElementType.Float,
+                dims, ortAllocationInput.DangerousGetHandle(), dataBufferNumBytes);
+
+            SessionOptions options = new SessionOptions();
+            options.AddInitializer("W", sharedInitializer);
+
+            float[] expectedOutput = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F };
+            int[] expectedDimensions = { 3, 2 };
+
+            using (var session = new InferenceSession(modelPath, options))
+            using (var session2 = new InferenceSession(modelPath, options))
+            {
+                var inputMeta = session.InputMetadata;
+                var container = new List<NamedOnnxValue>();
+
+                foreach (var name in inputMeta.Keys)
+                {
+                    Assert.Equal(typeof(float), inputMeta[name].ElementType);
+                    Assert.True(inputMeta[name].IsTensor);
+                    var tensor = new DenseTensor<float>(dataBuffer, inputMeta[name].Dimensions);
+                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
+                }
+
+                // Run inference with named inputs and outputs created within Run()
+                using (var results = session.Run(container))  // results is an IReadOnlyList<NamedOnnxValue> container
+                {
+                    foreach (var r in results)
+                    {
+                        validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
+                    }
+                }
+
+                // Run inference with named inputs and outputs created within Run()
+                using (var results2 = session2.Run(container))  // results2 is an IReadOnlyList<NamedOnnxValue> container
+                {
+                    foreach (var r in results2)
+                    {
+                        validateRunResultData(r.AsTensor<float>(), expectedOutput, expectedDimensions);
+                    }
+                }
+            }
+        }
+
         [DllImport("kernel32", SetLastError = true)]
         static extern IntPtr LoadLibrary(string lpFileName);
diff --git a/csharp/testdata/mul_1.onnx b/csharp/testdata/mul_1.onnx
new file mode 100644
index 0000000000000..0b6dc51026132
Binary files /dev/null and b/csharp/testdata/mul_1.onnx differ
diff --git a/docs/C_API.md b/docs/C_API.md
index b7266015c3549..75de23e678478 100644
--- a/docs/C_API.md
+++ b/docs/C_API.md
@@ -28,6 +28,11 @@
   chooses to override this by setting ```session_state.use_env_allocators``` to "0".
   * Set ```session.use_env_allocators``` to "1" for each session that wants to use the env registered allocators.
   * See test ```TestSharedAllocatorUsingCreateAndRegisterAllocator``` in onnxruntime/test/shared_lib/test_inference.cc for an example.
+* **Share initializer(s) between sessions:**
+  * *Description*: This feature allows a user to share the same instance of an initializer across
+multiple sessions.
+  * *Scenario*: You have several models that use the same set of initializers except for the last few layers, and you load these models in the same process. When every model (session) creates a separate instance of the same initializer, it leads to excessive and wasteful memory usage since it is, in fact, the same initializer. You want to optimize memory usage while retaining the flexibility to allocate the initializers yourself (possibly even store them in shared memory).
+  * *Example Usage*: Use the ```AddInitializer``` API to add a pre-allocated initializer to session options before calling ```CreateSession```. Use the same instance of session options to create several sessions, allowing the initializer(s) to be shared between the sessions. See [C API sample usage (TestSharingOfInitializer)](../onnxruntime/test/shared_lib/test_inference.cc) and [C# API sample usage (TestWeightSharingBetweenSessions)](../csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs).
 
 ## Usage Overview
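For orientation alongside the documentation above, here is a minimal sketch of the documented flow through the raw C API. It is illustrative only: `CheckStatus` is a hypothetical helper that aborts on a non-NULL `OrtStatus*`, and the initializer name/shape are placeholders matching the tests in this patch.

```c
#include <onnxruntime_c_api.h>

const OrtApi* g_ort;  /* assume set to OrtGetApiBase()->GetApi(ORT_API_VERSION) */

/* Share one pre-allocated initializer "W" across two sessions.
   'w_data' is caller-owned and must outlive both sessions. */
void CreateTwoSessionsSharingW(OrtEnv* env, const ORTCHAR_T* model_path, float* w_data) {
  OrtMemoryInfo* mem_info = NULL;
  CheckStatus(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info));

  /* Wrap the user buffer; the resulting OrtValue does NOT own the memory. */
  int64_t shape[] = {3, 2};
  OrtValue* w = NULL;
  CheckStatus(g_ort->CreateTensorWithDataAsOrtValue(
      mem_info, w_data, 6 * sizeof(float), shape, 2,
      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &w));

  OrtSessionOptions* so = NULL;
  CheckStatus(g_ort->CreateSessionOptions(&so));
  CheckStatus(g_ort->AddInitializer(so, "W", w));  /* 'w' must outlive the sessions */

  /* Reusing the same session options shares the initializer instance. */
  OrtSession *session1 = NULL, *session2 = NULL;
  CheckStatus(g_ort->CreateSession(env, model_path, so, &session1));
  CheckStatus(g_ort->CreateSession(env, model_path, so, &session2));

  /* ... run inference; release the sessions before releasing 'w' and 'w_data'. */
}
```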
diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h
index ac31611c42a87..e6679f32c77b3 100644
--- a/include/onnxruntime/core/framework/tensor.h
+++ b/include/onnxruntime/core/framework/tensor.h
@@ -178,6 +178,10 @@ class Tensor final {
     return static_cast<char*>(p_data_) + byte_offset_;
   }
 
+  bool OwnsBuffer() const noexcept {
+    return buffer_deleter_ != nullptr;
+  }
+
   /**
    * Resizes the tensor without touching underlying storage.
    * This requires the total size of the tensor to remains constant.
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index feaaa0e2a17d8..3d3240e93852f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1051,6 +1051,19 @@ struct OrtApi {
    * Prefer a value of 0 if your CPU usage is very high.
    */
   ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);
+
+  /**
+   * Add a pre-allocated initializer to a session. If a model contains an initializer with a name
+   * that is the same as the name passed to this API call, ORT will use this initializer instance
+   * instead of deserializing one from the model file. This is useful when you want to share
+   * the same initializer across sessions.
+   * \param name name of the initializer
+   * \param val OrtValue containing the initializer. The lifetime of 'val' and the underlying initializer buffer must
+   * be managed by the user (created using the CreateTensorWithDataAsOrtValue API), and it must outlive the session
+   * object to which it is added.
+   */
+  ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
+                  _In_ const OrtValue* val);
 };
 
 /*
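To make the `OwnsBuffer()` distinction concrete, here is a small sketch using the internal `onnxruntime::Tensor` constructors the way other code in this diff does (illustrative only; not part of the patch):

```cpp
#include "core/framework/tensor.h"

using namespace onnxruntime;

// Illustrative sketch of the ownership distinction OwnsBuffer() encodes.
void OwnershipSketch(AllocatorPtr alloc, float* user_buffer, const OrtMemoryInfo& info) {
  TensorShape shape({3, 2});

  // Allocated through an allocator: the tensor holds a deleter and owns its buffer.
  Tensor owned(DataTypeImpl::GetType<float>(), shape, alloc);
  // owned.OwnsBuffer() == true -> rejected by SessionOptions::AddInitializer below.

  // Wrapped over caller-managed memory: no deleter, the caller keeps ownership.
  Tensor wrapped(DataTypeImpl::GetType<float>(), shape, user_buffer, info);
  // wrapped.OwnsBuffer() == false -> eligible for sharing across sessions,
  // which is why the docs require CreateTensorWithDataAsOrtValue for 'val'.
}
```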
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 9e8f365adc05a..1041faa91905b 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -248,6 +248,8 @@ struct SessionOptions : Base<OrtSessionOptions> {
   SessionOptions& DisablePerSessionThreads();
 
   SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
+
+  SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
 };
 
 struct ModelMetadata : Base<OrtModelMetadata> {
@@ -330,7 +332,6 @@ struct TypeInfo : Base<OrtTypeInfo> {
   Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
   Unowned<MapTypeInfo> GetMapTypeInfo() const;
-
   ONNXType GetONNXType() const;
 };
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 104ae566d5d5c..2da83b6b5fad7 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -450,6 +450,11 @@ inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, co
   return *this;
 }
 
+inline SessionOptions& SessionOptions::AddInitializer(const char* name, const OrtValue* ort_val) {
+  ThrowOnError(GetApi().AddInitializer(p_, name, ort_val));
+  return *this;
+}
+
 inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
   ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
 }
diff --git a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc
index 81637faec2296..65b2c36e680f3 100644
--- a/onnxruntime/core/framework/session_options.cc
+++ b/onnxruntime/core/framework/session_options.cc
@@ -3,6 +3,7 @@
 
 #include "core/framework/session_options.h"
 #include "core/common/logging/logging.h"
+#include "core/framework/ml_value.h"
 
 namespace onnxruntime {
@@ -45,4 +46,31 @@ Status SessionOptions::AddConfigEntry(const char* config_key, const char* config
 
   return Status::OK();
 }
+
+Status SessionOptions::AddInitializer(const char* name, const OrtValue* val) noexcept {
+  // input validation
+  if (name == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received nullptr for name.");
+  }
+
+  if (val == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received nullptr for OrtValue.");
+  }
+
+  if (!val->IsTensor()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Received OrtValue is not a tensor. Only tensors are supported.");
+  }
+
+  if (val->Get<Tensor>().OwnsBuffer()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer containing the initializer must be owned by the user.");
+  }
+
+  // now do the actual work
+  auto rc = initializers_to_share_map.insert({name, val});
+  if (!rc.second) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "An OrtValue for this name has already been added.");
+  }
+
+  return Status::OK();
+}
 }  // namespace onnxruntime
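As a sketch of how these validation paths surface to a caller (illustrative; the argument names are placeholders and the checks simply mirror the implementation above):

```cpp
#include "core/framework/session_options.h"

using namespace onnxruntime;

// Illustrative expectations for SessionOptions::AddInitializer.
void ValidationSketch(const OrtValue* tensor_over_user_buffer,
                      const OrtValue* tensor_from_allocator) {
  SessionOptions so;

  // nullptr name or value is rejected with INVALID_ARGUMENT.
  ORT_ENFORCE(!so.AddInitializer(nullptr, tensor_over_user_buffer).IsOK());
  ORT_ENFORCE(!so.AddInitializer("W", nullptr).IsOK());

  // A tensor that owns its buffer (allocator-backed) is rejected: ORT cannot
  // guarantee such a buffer outlives every session it would be shared with.
  ORT_ENFORCE(!so.AddInitializer("W", tensor_from_allocator).IsOK());

  // First insertion for a name succeeds; a second one with the same name fails.
  ORT_ENFORCE(so.AddInitializer("W", tensor_over_user_buffer).IsOK());
  ORT_ENFORCE(!so.AddInitializer("W", tensor_over_user_buffer).IsOK());
}
```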
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 77cf89cf2c20f..d8136087c5343 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -93,6 +93,10 @@ struct SessionOptions {
   // The configuration keys and value formats are defined in
   // /include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
   std::unordered_map<std::string, std::string> session_configurations;
+  std::unordered_map<std::string, const OrtValue*> initializers_to_share_map;
+
+  // See onnxruntime_c_api.h for detailed documentation.
+  Status AddInitializer(const char* name, const OrtValue* val) noexcept;
 
   // Check if the given SessionOptions has a config using the given config_key.
   // Returns true if found and copies the value into config_value.
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index cfbb43e04ea0a..02436626a55fe 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -852,7 +852,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR
                             return AddInitializedTensor(idx, value, &d, constant);
                           },
-                          logger_, data_transfer_mgr_));
+                          logger_, data_transfer_mgr_, *p_seq_exec_plan_.get(), session_options));
 
   // remove weights from the graph now to save memory but in many cases it won't save memory, if the tensor was
   // preallocated with the some other tensors in a single 'allocate' call, which is very common.
diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc
index 2ae0b80b049ad..1bcdcba04657f 100644
--- a/onnxruntime/core/framework/session_state_utils.cc
+++ b/onnxruntime/core/framework/session_state_utils.cc
@@ -92,20 +92,59 @@ common::Status SaveInitializedTensors(
     const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info,
     const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner,
     const std::function<common::Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
-    const logging::Logger& logger, const DataTransferManager& data_transfer_mgr) {
+    const logging::Logger& logger, const DataTransferManager& data_transfer_mgr,
+    const ExecutionPlanBase& exec_plan,
+    const SessionOptions& session_options) {
   LOGS(logger, INFO) << "Saving initialized tensors.";
   ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated.");
 
+  // Determine if an initializer was supplied by the user for the purpose of sharing and if it requires a cross-device
+  // copy. In case a cross-device copy is required, sharing cannot be accomplished since we allocate our own buffer
+  // for the destination device, which cannot be shared between sessions.
+  auto use_user_supplied_initializer =
+      [&session_options, &exec_plan, &logger, &ort_value_name_idx_map](const std::string& name) -> bool {
+    bool retval = false;
+    auto it = session_options.initializers_to_share_map.find(name);
+    if (it == session_options.initializers_to_share_map.end()) {
+      retval = false;
+    } else {
+      int ort_value_index = -1;
+      if (!ort_value_name_idx_map.GetIdx(name, ort_value_index).IsOK()) {
+        retval = false;
+      } else {
+        auto planned_mem_info = exec_plan.GetLocation(ort_value_index);
+        auto user_mem_info = it->second->Get<Tensor>().Location();
+        retval = user_mem_info.device == planned_mem_info.device;
+        if (!retval) {
+          LOGS(logger, WARNING) << "Cannot use user supplied initializer with name: ("
+                                << name << ") because the ORT planned memory location device ("
+                                << planned_mem_info.ToString()
+                                << ") is different from what is supplied (" << user_mem_info.ToString() << ")";
+        }
+      }
+    }
+
+    return retval;
+  };
+
   //1. first plan the memory
   const onnxruntime::InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors();
   std::unordered_map<int, const ONNX_NAMESPACE::TensorProto*> id_to_initialized_tensor;
+  std::set<int> user_supplied_initializer_ids;  // set containing the ort value ids of all user supplied initializers
   for (const auto& entry : initialized_tensor_set) {
     int ort_value_index;
     ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index));
+    if (use_user_supplied_initializer(entry.first)) {
+      user_supplied_initializer_ids.insert(ort_value_index);
+    }
     id_to_initialized_tensor[ort_value_index] = entry.second;
   }
 
   for (const auto& entry : id_to_initialized_tensor) {
+    // We don't want to trace shared initializers since their memory is provided by the user
+    if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) {
+      continue;
+    }
     ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second));
   }
 
@@ -121,28 +160,34 @@ common::Status SaveInitializedTensors(
                    << i.second << " bytes for " << i.first << std::endl;
   }
 
-  OrtCallback deleter;
+  OrtCallback deleter{nullptr, nullptr};
   //3. create weight tensors based on weights buffer
   for (const auto& entry : id_to_initialized_tensor) {
     int ort_value_index = entry.first;
     const char* name = (entry.second->name().empty()) ? "" : entry.second->name().c_str();
-    const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second);
+    OrtValue ort_value;
+
+    if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) {
+      ort_value = *(session_options.initializers_to_share_map.at(name));
+      LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ").";
+    } else {
+      const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second);
 
-    std::unique_ptr<MemBuffer> m;
-    // TODO: if the tensor need be copied, does it have enough room?
-    ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, m));
+      std::unique_ptr<MemBuffer> m;
+      // TODO: if the tensor need be copied, does it have enough room?
+      ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, m));
 #ifndef NDEBUG
-    ORT_ENFORCE(m != nullptr);
-    ORT_ENFORCE(m->GetBuffer() != nullptr || m->GetLen() == 0);
+      ORT_ENFORCE(m != nullptr);
+      ORT_ENFORCE(m->GetBuffer() != nullptr || m->GetLen() == 0);
 #endif
-    OrtValue ort_value;
-    Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, *m, default_cpu_memory_info, ort_value, deleter,
-                                       data_transfer_mgr);
-    if (!st.IsOK()) {
-      std::ostringstream oss;
-      oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();
-      return Status(st.Category(), st.Code(), oss.str());
+      Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, *m, default_cpu_memory_info, ort_value, deleter,
+                                         data_transfer_mgr);
+      if (!st.IsOK()) {
+        std::ostringstream oss;
+        oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();
+        return Status(st.Category(), st.Code(), oss.str());
+      }
     }
 
   // any outer scope value is shadowed by a local value and can't override it.
diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h
index 1c2110413dad6..b499c8426734a 100644
--- a/onnxruntime/core/framework/session_state_utils.h
+++ b/onnxruntime/core/framework/session_state_utils.h
@@ -33,7 +33,9 @@ common::Status SaveInitializedTensors(
     ITensorAllocator& planner,
     const std::function<common::Status(int idx, const OrtValue& value, const OrtCallback& d, bool constant)>& save_tensor_func,
     const logging::Logger& logger,
-    const DataTransferManager& data_transfer_mgr);
+    const DataTransferManager& data_transfer_mgr,
+    const ExecutionPlanBase& exec_plan,
+    const SessionOptions& session_options);
 
 common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph,
                                                  SessionState& session_state,
diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h
index 793bfb7008a80..5a1235ae611dc 100644
--- a/onnxruntime/core/platform/windows/telemetry.h
+++ b/onnxruntime/core/platform/windows/telemetry.h
@@ -25,11 +25,11 @@ class WindowsTelemetry : public Telemetry {
   void LogProcessInfo() const override;
 
   void LogSessionCreationStart() const override;
-  
+
   void LogEvaluationStop() const override;
 
   void LogEvaluationStart() const override;
-  
+
   void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
                           const std::string& model_producer_version, const std::string& model_domain,
                           const std::unordered_map<std::string, int>& domain_to_version_map,
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index e784cd6ce54e0..ea4be7848c8c9 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -184,3 +184,12 @@ ORT_API_STATUS_IMPL(OrtApis::AddSessionConfigEntry, _Inout_ OrtSessionOptions* o
                     _In_z_ const char* config_key, _In_z_ const char* config_value) {
   return onnxruntime::ToOrtStatus(options->value.AddConfigEntry(config_key, config_value));
 }
+
+ORT_API_STATUS_IMPL(OrtApis::AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
+                    _In_ const OrtValue* val) {
+  auto st = options->value.AddInitializer(name, val);
+  if (!st.IsOK()) {
+    return onnxruntime::ToOrtStatus(st);
+  }
+  return nullptr;
+}
\ No newline at end of file
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 5ebbc62fb0ece..6327c77417779 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -1835,7 +1835,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
 
 In GetApi we now make it return ort_api_3 for version 3.
 */
-static constexpr OrtApi ort_api_1_to_5 = {
+static constexpr OrtApi ort_api_1_to_6 = {
     // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
 
     // Shipped as version 1 - DO NOT MODIFY (see above text for more information)
@@ -1986,8 +1986,6 @@ static constexpr OrtApi ort_api_1_to_5 = {
     &OrtApis::ReleaseAvailableProviders,
     // End of Version 4 - DO NOT MODIFY ABOVE (see above text for more information)
 
-    // Version 5 - In development, feel free to add/remove/rearrange here
-
     &OrtApis::GetStringTensorElementLength,
     &OrtApis::GetStringTensorElement,
     &OrtApis::FillStringTensorElement,
@@ -2014,6 +2012,10 @@ static constexpr OrtApi ort_api_1_to_5 = {
     &OrtApis::SetGlobalIntraOpNumThreads,
     &OrtApis::SetGlobalInterOpNumThreads,
     &OrtApis::SetGlobalSpinControl,
+    // End of Version 5 - DO NOT MODIFY ABOVE (see above text for more information)
+
+    // Version 6 - In development, feel free to add/remove/rearrange here
+    &OrtApis::AddInitializer,
 };
 
 // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
@@ -2021,8 +2023,8 @@ static constexpr OrtApi ort_api_1_to_5 = {
 static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
 
 ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
-  if (version >= 1 && version <= 5)
-    return &ort_api_1_to_5;
+  if (version >= 1 && version <= 6)
+    return &ort_api_1_to_6;
 
   return nullptr;  // Unsupported version
 }
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index fecfb2d083d2f..15ccfef3db4e5 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -237,4 +237,6 @@ ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess,
 ORT_API_STATUS_IMPL(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads);
 ORT_API_STATUS_IMPL(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads);
 ORT_API_STATUS_IMPL(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning);
+ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
+                    _In_ const OrtValue* val);
 }  // namespace OrtApis
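Because `AddInitializer` only exists from API version 6 onward, a client that may run against an older runtime binary can probe for it before relying on it. A minimal sketch (the fallback strategy is a suggestion, not part of the patch):

```c
#include <onnxruntime_c_api.h>

void ProbeForAddInitializer(void) {
  /* Request the version-6 API table; older runtimes return NULL here. */
  const OrtApi* api = OrtGetApiBase()->GetApi(6);
  if (api == NULL || api->AddInitializer == NULL) {
    /* Runtime predates version 6: initializer sharing is unavailable, so let
       each session deserialize its own copy of the weights from the model. */
    return;
  }
  /* Safe to call api->AddInitializer(...) and share initializers here. */
}
```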
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index d0d696f8a38f1..6e676f9e2d596 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -2158,13 +2158,21 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) {
 #endif
 }
 
+struct InferenceSessionExposingSessionState : public InferenceSession {
+  InferenceSessionExposingSessionState(const SessionOptions& session_options,
+                                       const Environment& env)
+      : InferenceSession(session_options, env) {
+  }
+  const SessionState& GetSessionState() const { return InferenceSession::GetSessionState(); }
+};
+
 // Global threadpool related tests
 // We test for 4 combinations
-class InferenceSessionTestGlobalThreadPools : public InferenceSession {
+class InferenceSessionTestGlobalThreadPools : public InferenceSessionExposingSessionState {
  public:
   InferenceSessionTestGlobalThreadPools(const SessionOptions& session_options,
                                         const Environment& env)
-      : InferenceSession(session_options, env) {
+      : InferenceSessionExposingSessionState(session_options, env) {
   }
 
   onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPoolToUse() const {
@@ -2174,8 +2182,6 @@ class InferenceSessionTestGlobalThreadPools : public InferenceSession {
   onnxruntime::concurrency::ThreadPool* GetInterOpThreadPoolToUse() const {
     return InferenceSession::GetInterOpThreadPoolToUse();
   }
-
-  const SessionState& GetSessionState() { return InferenceSession::GetSessionState(); }
 };
 
 // Test 1: env created WITHOUT global tp / use per session tp (default case): in this case per session tps should be in use
@@ -2331,14 +2337,12 @@ TEST(InferenceSessionTests, InvalidSessionEnvCombination) {
 }
 
 // Tests for sharing allocators between sessions
-class InferenceSessionTestSharingAllocator : public InferenceSession {
+class InferenceSessionTestSharingAllocator : public InferenceSessionExposingSessionState {
  public:
   InferenceSessionTestSharingAllocator(const SessionOptions& session_options,
                                        const Environment& env)
-      : InferenceSession(session_options, env) {
+      : InferenceSessionExposingSessionState(session_options, env) {
   }
-
-  const SessionState& GetSessionState() { return InferenceSession::GetSessionState(); }
 };
 
 // Ensure sessions use the same allocator. It uses ORT created allocator.
@@ -2430,5 +2434,86 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsDontUseSameOrtCreated
   ASSERT_NE(sess1.GetSessionState().GetAllocator(mem_info).get(),
             sess2.GetSessionState().GetAllocator(mem_info).get());
 }
+
+class InferenceSessionTestSharingInitializer : public InferenceSessionExposingSessionState {
+ public:
+  InferenceSessionTestSharingInitializer(const SessionOptions& session_options,
+                                         const Environment& env)
+      : InferenceSessionExposingSessionState(session_options, env) {
+  }
+};
+
+TEST(InferenceSessionTests, InitializerSharing_EnsureSessionsUseUserAddedInitializer) {
+  auto logging_manager = onnxruntime::make_unique<logging::LoggingManager>(
+      std::unique_ptr<ISink>(new CLogSink()), logging::Severity::kVERBOSE, false,
+      LoggingManager::InstanceType::Temporal);
+
+  std::unique_ptr<Environment> env;
+  auto st = Environment::Create(std::move(logging_manager), env);
+  ASSERT_TRUE(st.IsOK());
+
+  // create initializer to share between sessions
+  const char* init_name = "W";
+  OrtValue val_to_share_from_allocator;
+  OrtValue val_to_share;
+  std::vector<float> input_data_vec{1., 2., 3., 4., 5., 6.};
+
+  auto allocator = TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault);
+  CreateMLValue<float>(allocator, {3, 2}, input_data_vec, &val_to_share_from_allocator);
+
+  OrtMemoryInfo mem_info{CPU, OrtArenaAllocator};
+  CreateMLValue<float>({3, 2}, input_data_vec.data(), mem_info, &val_to_share);
+
+  // create sessions to share the initializer
+  SessionOptions so1;
+  ASSERT_STATUS_OK(so1.AddInitializer(init_name, &val_to_share));
+
+  // ensure an error is returned when an initializer with the same name is added.
+  ASSERT_FALSE(so1.AddInitializer(init_name, &val_to_share).IsOK());
+
+  // ensure an error is returned when an initializer with a buffer NOT owned by the user is added.
+  ASSERT_FALSE(so1.AddInitializer(init_name, &val_to_share_from_allocator).IsOK());
+
+  InferenceSessionTestSharingInitializer sess1(so1, *env);
+  ASSERT_STATUS_OK(sess1.Load(MODEL_URI));
+  ASSERT_STATUS_OK(sess1.Initialize());
+
+  SessionOptions so2;
+  ASSERT_STATUS_OK(so2.AddInitializer(init_name, &val_to_share));
+  InferenceSessionTestSharingInitializer sess2(so2, *env);
+  ASSERT_STATUS_OK(sess2.Load(MODEL_URI));
+  ASSERT_STATUS_OK(sess2.Initialize());
+
+  SessionOptions so3;
+  InferenceSessionTestSharingInitializer sess3(so3, *env);
+  ASSERT_STATUS_OK(sess3.Load(MODEL_URI));
+  ASSERT_STATUS_OK(sess3.Initialize());
+
+  int so1_idx;
+  ASSERT_STATUS_OK(sess1.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so1_idx));
+  const auto* so1_init_buffer = sess1.GetSessionState().GetInitializedTensors().at(so1_idx).Get<Tensor>().Data<float>();
+
+  int so2_idx;
+  ASSERT_STATUS_OK(sess2.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so2_idx));
+  const auto* so2_init_buffer = sess2.GetSessionState().GetInitializedTensors().at(so2_idx).Get<Tensor>().Data<float>();
+
+  // Ensure session1 stores the same data ptr as the one supplied by the user
+  ASSERT_EQ(so1_init_buffer, val_to_share.Get<Tensor>().Data<float>());
+
+  // Ensure both sessions share the same data ptr
+  ASSERT_EQ(so1_init_buffer, so2_init_buffer);
+
+  int so3_idx;
+  ASSERT_STATUS_OK(sess3.GetSessionState().GetOrtValueNameIdxMap().GetIdx(init_name, so3_idx));
+  const auto* so3_init_buffer = sess3.GetSessionState().GetInitializedTensors().at(so3_idx).Get<Tensor>().Data<float>();
+
+  // Ensure session 3 doesn't share the same data ptr as any other session
+  ASSERT_NE(so3_init_buffer, so1_init_buffer);
+  ASSERT_NE(so3_init_buffer, so2_init_buffer);
+
+  // Ensure session 3 doesn't share the same data ptr as the one supplied by the user for any of the other sessions
+  ASSERT_NE(so3_init_buffer, val_to_share.Get<Tensor>().Data<float>());
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h
index 2119634e3c26b..9827ba2c6e1e7 100644
--- a/onnxruntime/test/framework/test_utils.h
+++ b/onnxruntime/test/framework/test_utils.h
@@ -85,6 +85,21 @@ void CreateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, const s
                   DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
 }
 
+// Lifetime of data_buffer should be managed by the caller.
+template <typename T>
+void CreateMLValue(const std::vector<int64_t>& dims, T* data_buffer, const OrtMemoryInfo& info,
+                   OrtValue* p_mlvalue) {
+  TensorShape shape(dims);
+  auto element_type = DataTypeImpl::GetType<T>();
+  std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(element_type,
+                                                                      shape,
+                                                                      data_buffer,
+                                                                      info);
+  p_mlvalue->Init(p_tensor.release(),
+                  DataTypeImpl::GetType<Tensor>(),
+                  DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
+}
+
 template <typename T>
 void AllocateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, OrtValue* p_mlvalue) {
   TensorShape shape(dims);
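A quick sketch of how this new overload pairs with initializer sharing in the tests above (illustrative; the buffer and shape are placeholders, and the caller must keep the buffer alive):

```cpp
#include "test/framework/test_utils.h"

using namespace onnxruntime;
using namespace onnxruntime::test;

// Wrap a caller-owned buffer in an OrtValue suitable for
// SessionOptions::AddInitializer (the resulting tensor does not own 'data').
void WrapUserBuffer() {
  static float data[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // must outlive the sessions
  OrtMemoryInfo mem_info{CPU, OrtArenaAllocator};
  OrtValue val;
  CreateMLValue<float>({3, 2}, data, mem_info, &val);
  // val.Get<Tensor>().OwnsBuffer() == false, so AddInitializer will accept it.
}
```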
+template +void CreateMLValue(const std::vector& dims, T* data_buffer, const OrtMemoryInfo& info, + OrtValue* p_mlvalue) { + TensorShape shape(dims); + auto element_type = DataTypeImpl::GetType(); + std::unique_ptr p_tensor = onnxruntime::make_unique(element_type, + shape, + data_buffer, + info); + p_mlvalue->Init(p_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); +} + template void AllocateMLValue(AllocatorPtr alloc, const std::vector& dims, OrtValue* p_mlvalue) { TensorShape shape(dims); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 04169d129c537..b696f863fe2f2 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -936,3 +936,47 @@ TEST(CApiTest, TestSharedAllocatorUsingCreateAndRegisterAllocator) { expected_values_y, nullptr); } + +TEST(CApiTest, TestSharingOfInitializer) { + // simple inference test + // prepare inputs + std::vector inputs(1); + Input& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + Ort::SessionOptions session_options; + Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + float data[] = {1., 2., 3., 4., 5., 6.}; + const int data_len = sizeof(data) / sizeof(data[0]); + const int64_t shape[] = {3, 2}; + const size_t shape_len = sizeof(shape) / sizeof(shape[0]); + Ort::Value val = Ort::Value::CreateTensor(mem_info, data, data_len, shape, shape_len); + session_options.AddInitializer("W", val); + + auto default_allocator = onnxruntime::make_unique(); + // create session 1 + Ort::Session session1(*ort_env, MODEL_URI, session_options); + RunSession(default_allocator.get(), + session1, + inputs, + "Y", + expected_dims_y, + expected_values_y, + nullptr); + + // create session 2 + Ort::Session session2(*ort_env, MODEL_URI, session_options); + RunSession(default_allocator.get(), + session2, + inputs, + "Y", + expected_dims_y, + expected_values_y, + nullptr); +} \ No newline at end of file diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index e4e174dcc4f5e..31e1ce01fefca 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -51,6 +51,7 @@ static SessionOptions SESSION_OPTION = { true, //thread_pool_allow_spinning false, //use_deterministic_compute {}, //session_configurations + {}, // initializers_to_share_map }; TrainingRunner::TrainingRunner(Parameters params, const Environment& env) @@ -214,7 +215,7 @@ Status TrainingRunner::Initialize() { pipeline_context_.pipeline_tensor_names = config_result.pipeline_config_result.value().pipeline_tensor_names; // Create a local function to append non-empty name to fetch_names list. 
diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc
index e4e174dcc4f5e..31e1ce01fefca 100644
--- a/orttraining/orttraining/models/runner/training_runner.cc
+++ b/orttraining/orttraining/models/runner/training_runner.cc
@@ -51,6 +51,7 @@ static SessionOptions SESSION_OPTION = {
     true,                              //thread_pool_allow_spinning
     false,                             //use_deterministic_compute
     {},                                //session_configurations
+    {},                                //initializers_to_share_map
 };
 
 TrainingRunner::TrainingRunner(Parameters params, const Environment& env)
@@ -214,7 +215,7 @@ Status TrainingRunner::Initialize() {
     pipeline_context_.pipeline_tensor_names = config_result.pipeline_config_result.value().pipeline_tensor_names;
 
     // Create a local function to append non-empty name to fetch_names list.
-    auto append_non_empty_name = [&] (const std::string& name) {
+    auto append_non_empty_name = [&](const std::string& name) {
       if (!name.empty()) {
        fetch_names.push_back(name);
       }
@@ -277,7 +278,6 @@ Status TrainingRunner::Initialize() {
 Status TrainingRunner::Run(IDataLoader* training_data_loader, IDataLoader* test_data_loader,
                            const MapStringToString& mapped_dimensions) {
   if (MPIContext::GetInstance().GetWorldRank() == 0 && !params_.model_actual_running_graph_path.empty()) {
-
     session_.Save(params_.model_actual_running_graph_path, TrainingSession::SaveOption::NO_RELOAD);
   }
 
@@ -466,7 +466,7 @@ Status TrainingRunner::PrepareFetchNamesAndFetches(const SessionMode mode,
   // TODO: create a list of must-to-fetch tensors and pass it to all graph transformer.
   if (params_.pipeline_parallel_size > 1) {
     // Create a local function to append non-empty name to fetch_names list.
-    auto append_non_empty_name = [&] (const std::string& name) {
+    auto append_non_empty_name = [&](const std::string& name) {
       if (!name.empty()) {
         fetch_names.push_back(name);
       }